mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(config): expose sampling temperature and document reproducibility
Adds a cross-provider temperature config (and TRADINGAGENTS_TEMPERATURE), forwarded to every LLM client when set, so runs can be made less variable on models that honor it. Adds a README "Reproducibility" section that separates the sources of run-to-run variation, what users can control (temperature, non-reasoning model, pinned date), and what is inherent to LLM-driven analysis, and notes that the identity and verified-data fixes already removed the "different companies / fabricated prices" variance. #178 #168
@@ -30,3 +30,7 @@ OPENROUTER_API_KEY=
 #TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
 #TRADINGAGENTS_MAX_RISK_ROUNDS=1
 #TRADINGAGENTS_CHECKPOINT_ENABLED=false
+# Sampling temperature (lower = less run-to-run variation on models that
+# honor it). Unset leaves each provider at its default. See the README
+# "Reproducibility" note — no setting makes LLM output fully deterministic.
+#TRADINGAGENTS_TEMPERATURE=0.0
README.md (22 lines changed)
@@ -253,6 +253,28 @@ ta = TradingAgentsGraph(config=config)
 _, decision = ta.propagate("NVDA", "2026-01-15")
 ```

+## Reproducibility
+
+TradingAgents is LLM-driven, so two runs of the same ticker and date can differ. This is expected for a research tool built on language models, not a defect. The variation comes from a few distinct sources, and it helps to separate them.
+
+Language model sampling is non-deterministic. Even at a fixed temperature, providers do not guarantee byte-identical output across calls, and reasoning models (the default GPT-5.x family, and any thinking-mode model) vary the most because their internal reasoning is itself sampled.
+
+Live data moves. News, StockTwits, and Reddit return different content as time passes, so a run today sees different inputs than a run last week even for the same historical trade date. Pin the analysis date to hold the price and indicator window fixed, but the social and news sources still reflect "now".
+
+To reduce variation you can lower the sampling temperature. Set `temperature` in your config (or `TRADINGAGENTS_TEMPERATURE` in `.env`); lower values make models that honor it more repeatable. Reasoning models largely ignore temperature, so for tighter reproducibility pair a low temperature with a non-reasoning model such as `gpt-4.1`.
+
+```python
+config = DEFAULT_CONFIG.copy()
+config["llm_provider"] = "openai"
+config["deep_think_llm"] = "gpt-4.1"  # non-reasoning model honors temperature
+config["quick_think_llm"] = "gpt-4.1"
+config["temperature"] = 0.0
+```
+
+What does not vary anymore: the analyzed company identity is resolved deterministically from the ticker before any agent runs, and the market analyst grounds exact price and indicator claims in a verified data snapshot. Earlier reports of "different companies" or fabricated price levels across runs are addressed by these two mechanisms.
+
+Backtest results are not guaranteed to match any published figure. Returns depend on the model, the temperature, the date range, data quality, and the sampling above. Treat the framework as a research scaffold for studying multi-agent analysis, not as a strategy with a fixed, replicable return.
+
 ## Contributing

 We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
tests/test_temperature_config.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+"""Tests for the configurable sampling temperature (#178/#168).
+
+Temperature is a cross-provider knob: when set it must reach the underlying
+chat client; when unset the provider keeps its own default.
+"""
+
+import importlib
+
+import pytest
+
+from tradingagents.llm_clients.factory import create_llm_client
+
+
+@pytest.mark.unit
+class TestTemperatureForwarding:
+    @pytest.mark.parametrize(
+        "provider,model",
+        [
+            ("openai", "gpt-4.1"),
+            ("anthropic", "claude-sonnet-4-6"),
+            ("google", "gemini-2.5-flash"),
+            ("deepseek", "deepseek-chat"),
+        ],
+    )
+    def test_temperature_reaches_client_when_set(self, provider, model):
+        llm = create_llm_client(
+            provider=provider, model=model, temperature=0.0, api_key="placeholder"
+        ).get_llm()
+        assert llm.temperature == 0.0
+
+    def test_temperature_omitted_leaves_provider_default(self):
+        # Not passing temperature must not force it to a value.
+        llm = create_llm_client(
+            provider="openai", model="gpt-4.1", api_key="placeholder"
+        ).get_llm()
+        # langchain's default is unset/None, not 0.0
+        assert llm.temperature is None
+
+
+@pytest.mark.unit
+class TestTemperatureEnvOverlay:
+    def test_env_sets_temperature(self, monkeypatch):
+        import tradingagents.default_config as dc
+        monkeypatch.setenv("TRADINGAGENTS_TEMPERATURE", "0.2")
+        importlib.reload(dc)
+        # Stored on config (string from env is fine; consumed via float()).
+        assert dc.DEFAULT_CONFIG["temperature"] in ("0.2", 0.2)
+        assert float(dc.DEFAULT_CONFIG["temperature"]) == 0.2
+        monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
+        importlib.reload(dc)
+
+    def test_default_temperature_is_none(self, monkeypatch):
+        import tradingagents.default_config as dc
+        monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
+        importlib.reload(dc)
+        assert dc.DEFAULT_CONFIG["temperature"] is None
+
+
+@pytest.mark.unit
+class TestProviderKwargsTemperature:
+    """_get_provider_kwargs float-coerces and forwards temperature, or omits it."""
+
+    def _kwargs_for(self, temperature):
+        from tradingagents.graph.trading_graph import TradingAgentsGraph
+        # Call the method without constructing the full graph.
+        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
+        graph.config = {"llm_provider": "openai", "temperature": temperature}
+        return TradingAgentsGraph._get_provider_kwargs(graph)
+
+    def test_float_string_coerced(self):
+        assert self._kwargs_for("0.3")["temperature"] == 0.3
+
+    def test_float_passthrough(self):
+        assert self._kwargs_for(0.0)["temperature"] == 0.0
+
+    def test_none_omitted(self):
+        assert "temperature" not in self._kwargs_for(None)
+
+    def test_empty_string_omitted(self):
+        assert "temperature" not in self._kwargs_for("")
@@ -17,6 +17,7 @@ _ENV_OVERRIDES = {
     "TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
     "TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
     "TRADINGAGENTS_BENCHMARK_TICKER": "benchmark_ticker",
+    "TRADINGAGENTS_TEMPERATURE": "temperature",
 }


@@ -64,6 +65,11 @@ DEFAULT_CONFIG = _apply_env_overrides({
     "google_thinking_level": None,  # "high", "minimal", etc.
     "openai_reasoning_effort": None,  # "medium", "high", "low"
     "anthropic_effort": None,  # "high", "medium", "low"
+    # Sampling temperature, forwarded to every provider when set. None leaves
+    # each provider at its own default. Lower values reduce run-to-run
+    # variation on models that honor it; reasoning models largely ignore it
+    # and no setting makes LLM output bit-identical across runs (see README).
+    "temperature": None,
     # Checkpoint/resume: when True, LangGraph saves state after each node
     # so a crashed run can resume from the last successful step.
     "checkpoint_enabled": False,
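The two hunks above register `TRADINGAGENTS_TEMPERATURE` in `_ENV_OVERRIDES` and add a `None` default, but the body of `_apply_env_overrides` is not part of this diff. The following is only a minimal sketch of how such an overlay could behave, assuming it copies the defaults and stores the raw environment string (which is what the new test's `in ("0.2", 0.2)` assertion tolerates). The names `ENV_OVERRIDES` and `apply_env_overrides` are hypothetical stand-ins, not the project's code.

```python
from typing import Any, Dict, Mapping

# Hypothetical one-entry stand-in for the project's _ENV_OVERRIDES mapping.
ENV_OVERRIDES = {"TRADINGAGENTS_TEMPERATURE": "temperature"}


def apply_env_overrides(
    defaults: Dict[str, Any], environ: Mapping[str, str]
) -> Dict[str, Any]:
    """Return a copy of defaults with any set environment values overlaid."""
    config = dict(defaults)
    for env_name, config_key in ENV_OVERRIDES.items():
        raw = environ.get(env_name)
        if raw is not None:
            # Assumption: the value stays a string here and is coerced
            # with float() at the point where it is consumed.
            config[config_key] = raw
    return config


overlaid = apply_env_overrides(
    {"temperature": None}, {"TRADINGAGENTS_TEMPERATURE": "0.2"}
)
untouched = apply_env_overrides({"temperature": None}, {})
print(overlaid["temperature"], untouched["temperature"])
```

Passing the environment as an argument, rather than reading `os.environ` directly, keeps the sketch testable without reloading a module; the real module is exercised through `importlib.reload` in the tests above.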
@@ -155,6 +155,13 @@ class TradingAgentsGraph:
             if effort:
                 kwargs["effort"] = effort

+        # Sampling temperature is cross-provider: forward it whenever set.
+        # float() here so a value coming from a TRADINGAGENTS_TEMPERATURE env
+        # string ("0.2") works the same as a programmatic float.
+        temperature = self.config.get("temperature")
+        if temperature is not None and temperature != "":
+            kwargs["temperature"] = float(temperature)
+
         return kwargs

     def _create_tool_nodes(self) -> Dict[str, ToolNode]:
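The guard in the `_get_provider_kwargs` hunk above compares against `None` and `""` explicitly instead of relying on truthiness. A standalone sketch of the same logic makes the reason visible: `0.0` is falsy in Python, so a plain `if temperature:` would silently drop the most common reproducibility setting. The function name `temperature_kwargs` is hypothetical; only the condition and the `float()` call are taken from the diff.

```python
from typing import Any, Dict, Mapping


def temperature_kwargs(config: Mapping[str, Any]) -> Dict[str, float]:
    """Mirror the diff's guard as a free function (hypothetical name)."""
    temperature = config.get("temperature")
    # Explicit None / "" checks: a truthiness test would also discard 0.0.
    if temperature is not None and temperature != "":
        return {"temperature": float(temperature)}
    return {}


for value in ("0.3", 0.0, None, ""):
    print(repr(value), "->", temperature_kwargs({"temperature": value}))
```

One consequence of coercing with `float()` is that a non-numeric string such as `"low"` raises `ValueError` rather than being ignored, so a typo in `.env` surfaces at client-construction time.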
@@ -7,7 +7,7 @@ from .base_client import BaseLLMClient, normalize_content
 from .validators import validate_model

 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "api_key", "max_tokens",
+    "timeout", "max_retries", "api_key", "max_tokens", "temperature",
     "callbacks", "http_client", "http_async_client", "effort",
 )

@@ -7,7 +7,7 @@ from .base_client import BaseLLMClient, normalize_content
 from .validators import validate_model

 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "api_key", "reasoning_effort",
+    "timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
     "callbacks", "http_client", "http_async_client",
 )

@@ -31,7 +31,7 @@ class GoogleClient(BaseLLMClient):
         if self.base_url:
             llm_kwargs["base_url"] = self.base_url

-        for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
+        for key in ("timeout", "max_retries", "temperature", "callbacks", "http_client", "http_async_client"):
             if key in self.kwargs:
                 llm_kwargs[key] = self.kwargs[key]

@@ -138,7 +138,7 @@ class MinimaxChatOpenAI(NormalizedChatOpenAI):

 # Kwargs forwarded from user config to ChatOpenAI
 _PASSTHROUGH_KWARGS = (
-    "timeout", "max_retries", "reasoning_effort",
+    "timeout", "max_retries", "reasoning_effort", "temperature",
     "api_key", "callbacks", "http_client", "http_async_client",
 )

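Each client hunk above adds `"temperature"` to an allow-list of keys copied from user kwargs into the chat model's constructor arguments. The `GoogleClient` hunk shows the copy loop explicitly; the sketch below restates that pattern as a standalone function to show why an unset temperature leaves the provider default alone. `PASSTHROUGH_KWARGS` here is a shortened illustrative tuple and `forward_kwargs` is a hypothetical helper, not a function from the repository.

```python
from typing import Any, Dict, Iterable

# Shortened, illustrative allow-list; the real tuples differ per client.
PASSTHROUGH_KWARGS = ("timeout", "max_retries", "api_key", "temperature")


def forward_kwargs(
    user_kwargs: Dict[str, Any],
    allowed: Iterable[str] = PASSTHROUGH_KWARGS,
) -> Dict[str, Any]:
    """Copy only allow-listed keys that the caller actually supplied."""
    return {key: user_kwargs[key] for key in allowed if key in user_kwargs}


with_temp = forward_kwargs({"temperature": 0.0, "not_allowed": 1})
without_temp = forward_kwargs({"timeout": 30})
print(with_temp, without_temp)
```

Because the filter tests key membership rather than value truthiness, `temperature=0.0` is forwarded, while a missing key is never passed, which matches the behaviour the new tests assert for the set and unset cases.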