feat(config): expose sampling temperature and document reproducibility

Add a cross-provider temperature config key (also settable through
TRADINGAGENTS_TEMPERATURE) that is forwarded to every LLM client when set,
so runs vary less on models that honor it.

Add a README "Reproducibility" section that separates the sources of
run-to-run variation into what users can control (temperature,
non-reasoning model, pinned date) and what is inherent to LLM-driven
analysis. It also notes that the identity and verified-data fixes already
removed the "different companies / fabricated prices" variance.

#178 #168
Yijia-Xiao
2026-05-31 03:51:50 +00:00
parent 47cbb321fe
commit 8a22594607
9 changed files with 123 additions and 4 deletions

@@ -30,3 +30,7 @@ OPENROUTER_API_KEY=
#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
#TRADINGAGENTS_MAX_RISK_ROUNDS=1
#TRADINGAGENTS_CHECKPOINT_ENABLED=false
# Sampling temperature (lower = less run-to-run variation on models that
# honor it). Unset leaves each provider at its default. See the README
# "Reproducibility" note — no setting makes LLM output fully deterministic.
#TRADINGAGENTS_TEMPERATURE=0.0

@@ -253,6 +253,28 @@ ta = TradingAgentsGraph(config=config)
_, decision = ta.propagate("NVDA", "2026-01-15")
```
## Reproducibility

TradingAgents is LLM-driven, so two runs of the same ticker and date can differ. This is expected behavior for a research tool built on language models, not a defect. The variation comes from a few distinct sources, and it helps to separate them.

Language model sampling is non-deterministic. Even at a fixed temperature, providers do not guarantee byte-identical output across calls, and reasoning models (the default GPT-5.x family, and any thinking-mode model) vary the most because their internal reasoning is itself sampled.

Live data moves. News, StockTwits, and Reddit return different content as time passes, so a run today sees different inputs than a run last week, even for the same historical trade date. Pinning the analysis date holds the price and indicator window fixed, but the social and news sources still reflect "now".

To reduce variation, lower the sampling temperature. Set `temperature` in your config (or `TRADINGAGENTS_TEMPERATURE` in `.env`); lower values make models that honor it more repeatable. Reasoning models largely ignore temperature, so for tighter reproducibility pair a low temperature with a non-reasoning model such as `gpt-4.1`.

```python
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "openai"
config["deep_think_llm"] = "gpt-4.1" # non-reasoning model honors temperature
config["quick_think_llm"] = "gpt-4.1"
config["temperature"] = 0.0
```
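The same value can also arrive as a string through `TRADINGAGENTS_TEMPERATURE`. As a rough sketch of how such a value might be normalized before it reaches a client, assuming the rule that unset or empty values are omitted and anything else is converted with `float()` (the helper name `normalize_temperature` is hypothetical and not part of TradingAgents):

```python
from typing import Optional, Union


def normalize_temperature(value: Union[str, float, None]) -> Optional[float]:
    """Hypothetical helper: return a float temperature, or None when unset."""
    # None or an empty string means "leave the provider at its default".
    if value is None or value == "":
        return None
    # Env values arrive as strings such as "0.2"; float() accepts both forms.
    return float(value)
```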
What no longer varies: the analyzed company identity is resolved deterministically from the ticker before any agent runs, and the market analyst grounds exact price and indicator claims in a verified data snapshot. Earlier reports of "different companies" or fabricated price levels across runs are addressed by these two mechanisms.

Backtest results are not guaranteed to match any published figure. Returns depend on the model, the temperature, the date range, data quality, and the sampling variation described above. Treat the framework as a research scaffold for studying multi-agent analysis, not as a strategy with a fixed, replicable return.
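
One way to see how much a given configuration varies is to repeat a run and tally the outcomes. The sketch below is illustrative only: `decision_agreement` is a hypothetical helper, and the labels are placeholder strings, not output captured from TradingAgents.

```python
from collections import Counter
from typing import Sequence, Tuple


def decision_agreement(decisions: Sequence[str]) -> Tuple[str, float]:
    """Return the most common decision and the fraction of runs that chose it."""
    if not decisions:
        raise ValueError("need at least one decision")
    label, count = Counter(decisions).most_common(1)[0]
    return label, count / len(decisions)


# Placeholder labels standing in for the decisions of five repeated runs.
sample = ["BUY", "BUY", "HOLD", "BUY", "HOLD"]
print(decision_agreement(sample))  # ('BUY', 0.6)
```

In practice each label would presumably come from a separate `ta.propagate(...)` call with the same ticker and date; a fraction well below 1.0 suggests the configuration is still sensitive to sampling or live inputs.
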
## Contributing
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).

@@ -0,0 +1,80 @@
"""Tests for the configurable sampling temperature (#178/#168).

Temperature is a cross-provider knob: when set it must reach the underlying
chat client; when unset the provider keeps its own default.
"""

import importlib

import pytest

from tradingagents.llm_clients.factory import create_llm_client


@pytest.mark.unit
class TestTemperatureForwarding:
    @pytest.mark.parametrize(
        "provider,model",
        [
            ("openai", "gpt-4.1"),
            ("anthropic", "claude-sonnet-4-6"),
            ("google", "gemini-2.5-flash"),
            ("deepseek", "deepseek-chat"),
        ],
    )
    def test_temperature_reaches_client_when_set(self, provider, model):
        llm = create_llm_client(
            provider=provider, model=model, temperature=0.0, api_key="placeholder"
        ).get_llm()
        assert llm.temperature == 0.0

    def test_temperature_omitted_leaves_provider_default(self):
        # Not passing temperature must not force it to a value.
        llm = create_llm_client(
            provider="openai", model="gpt-4.1", api_key="placeholder"
        ).get_llm()
        # langchain's default is unset/None, not 0.0
        assert llm.temperature is None


@pytest.mark.unit
class TestTemperatureEnvOverlay:
    def test_env_sets_temperature(self, monkeypatch):
        import tradingagents.default_config as dc

        monkeypatch.setenv("TRADINGAGENTS_TEMPERATURE", "0.2")
        importlib.reload(dc)
        # Stored on config (string from env is fine; consumed via float()).
        assert dc.DEFAULT_CONFIG["temperature"] in ("0.2", 0.2)
        assert float(dc.DEFAULT_CONFIG["temperature"]) == 0.2
        monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
        importlib.reload(dc)

    def test_default_temperature_is_none(self, monkeypatch):
        import tradingagents.default_config as dc

        monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
        importlib.reload(dc)
        assert dc.DEFAULT_CONFIG["temperature"] is None


@pytest.mark.unit
class TestProviderKwargsTemperature:
    """_get_provider_kwargs float-coerces and forwards temperature, or omits it."""

    def _kwargs_for(self, temperature):
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Call the method without constructing the full graph.
        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
        graph.config = {"llm_provider": "openai", "temperature": temperature}
        return TradingAgentsGraph._get_provider_kwargs(graph)

    def test_float_string_coerced(self):
        assert self._kwargs_for("0.3")["temperature"] == 0.3

    def test_float_passthrough(self):
        assert self._kwargs_for(0.0)["temperature"] == 0.0

    def test_none_omitted(self):
        assert "temperature" not in self._kwargs_for(None)

    def test_empty_string_omitted(self):
        assert "temperature" not in self._kwargs_for("")

@@ -17,6 +17,7 @@ _ENV_OVERRIDES = {
    "TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
    "TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
    "TRADINGAGENTS_BENCHMARK_TICKER": "benchmark_ticker",
    "TRADINGAGENTS_TEMPERATURE": "temperature",
}
@@ -64,6 +65,11 @@ DEFAULT_CONFIG = _apply_env_overrides({
    "google_thinking_level": None,      # "high", "minimal", etc.
    "openai_reasoning_effort": None,    # "medium", "high", "low"
    "anthropic_effort": None,           # "high", "medium", "low"
    # Sampling temperature, forwarded to every provider when set. None leaves
    # each provider at its own default. Lower values reduce run-to-run
    # variation on models that honor it; reasoning models largely ignore it
    # and no setting makes LLM output bit-identical across runs (see README).
    "temperature": None,
    # Checkpoint/resume: when True, LangGraph saves state after each node
    # so a crashed run can resume from the last successful step.
    "checkpoint_enabled": False,

@@ -155,6 +155,13 @@ class TradingAgentsGraph:
            if effort:
                kwargs["effort"] = effort

        # Sampling temperature is cross-provider: forward it whenever set.
        # float() here so a value coming from a TRADINGAGENTS_TEMPERATURE env
        # string ("0.2") works the same as a programmatic float.
        temperature = self.config.get("temperature")
        if temperature is not None and temperature != "":
            kwargs["temperature"] = float(temperature)

        return kwargs

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:

@@ -7,7 +7,7 @@ from .base_client import BaseLLMClient, normalize_content
 from .validators import validate_model

 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "api_key", "max_tokens",
+    "timeout", "max_retries", "api_key", "max_tokens", "temperature",
     "callbacks", "http_client", "http_async_client", "effort",
 )

@@ -7,7 +7,7 @@ from .base_client import BaseLLMClient, normalize_content
 from .validators import validate_model

 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "api_key", "reasoning_effort",
+    "timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
     "callbacks", "http_client", "http_async_client",
 )

@@ -31,7 +31,7 @@ class GoogleClient(BaseLLMClient):
         if self.base_url:
             llm_kwargs["base_url"] = self.base_url

-        for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
+        for key in ("timeout", "max_retries", "temperature", "callbacks", "http_client", "http_async_client"):
             if key in self.kwargs:
                 llm_kwargs[key] = self.kwargs[key]

@@ -138,7 +138,7 @@ class MinimaxChatOpenAI(NormalizedChatOpenAI):
 # Kwargs forwarded from user config to ChatOpenAI
 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "reasoning_effort",
+    "timeout", "max_retries", "reasoning_effort", "temperature",
     "api_key", "callbacks", "http_client", "http_async_client",
 )