mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
fix: bundle config/recursion/missing-key fixes
- dataflows/config: deepcopy + one-level dict merge so a partial set_config doesn't clobber sibling defaults - graph: thread max_recur_limit from config to Propagator - openai_client: name the missing env var in the API-key error #788 #764 #680
This commit is contained in:
@@ -19,6 +19,8 @@ _API_KEY_ENV_VARS = (
|
|||||||
"DEEPSEEK_API_KEY",
|
"DEEPSEEK_API_KEY",
|
||||||
"DASHSCOPE_API_KEY",
|
"DASHSCOPE_API_KEY",
|
||||||
"ZHIPU_API_KEY",
|
"ZHIPU_API_KEY",
|
||||||
|
"MINIMAX_API_KEY",
|
||||||
|
"MINIMAX_CN_API_KEY",
|
||||||
"OPENROUTER_API_KEY",
|
"OPENROUTER_API_KEY",
|
||||||
"AZURE_OPENAI_API_KEY",
|
"AZURE_OPENAI_API_KEY",
|
||||||
"ALPHA_VANTAGE_API_KEY",
|
"ALPHA_VANTAGE_API_KEY",
|
||||||
|
|||||||
61
tests/test_dataflows_config.py
Normal file
61
tests/test_dataflows_config.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""Config isolation: get/set must not leak nested-dict references."""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows.config import get_config, set_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class DataflowsConfigIsolationTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
set_config(copy.deepcopy(default_config.DEFAULT_CONFIG))
|
||||||
|
|
||||||
|
def test_get_config_returns_deep_copy(self):
|
||||||
|
cfg = get_config()
|
||||||
|
cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||||
|
cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance")
|
||||||
|
self.assertNotIn("get_stock_data", fresh["tool_vendors"])
|
||||||
|
|
||||||
|
def test_set_config_does_not_alias_caller_nested_dicts(self):
|
||||||
|
custom = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
custom["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||||
|
custom["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||||
|
|
||||||
|
set_config(custom)
|
||||||
|
|
||||||
|
custom["data_vendors"]["core_stock_apis"] = "yfinance"
|
||||||
|
custom["tool_vendors"]["get_stock_data"] = "yfinance"
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||||
|
|
||||||
|
def test_partial_nested_update_preserves_existing_defaults(self):
|
||||||
|
set_config(
|
||||||
|
{
|
||||||
|
"data_vendors": {
|
||||||
|
"core_stock_apis": "alpha_vantage",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["news_data"], "yfinance")
|
||||||
|
|
||||||
|
def test_nested_dict_updates_merge_one_level_deep(self):
|
||||||
|
set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}})
|
||||||
|
set_config({"tool_vendors": {"get_news": "alpha_vantage"}})
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage")
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import tradingagents.default_config as default_config
|
from copy import deepcopy
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
|
|
||||||
@@ -9,22 +11,31 @@ def initialize_config():
|
|||||||
"""Initialize the configuration with default values."""
|
"""Initialize the configuration with default values."""
|
||||||
global _config
|
global _config
|
||||||
if _config is None:
|
if _config is None:
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: Dict):
|
||||||
"""Update the configuration with custom values."""
|
"""Update the configuration with custom values.
|
||||||
|
|
||||||
|
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||||
|
partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}``
|
||||||
|
keeps the other nested keys from the default; scalar keys are replaced.
|
||||||
|
"""
|
||||||
global _config
|
global _config
|
||||||
if _config is None:
|
initialize_config()
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
incoming = deepcopy(config)
|
||||||
_config.update(config)
|
for key, value in incoming.items():
|
||||||
|
if isinstance(value, dict) and isinstance(_config.get(key), dict):
|
||||||
|
_config[key].update(value)
|
||||||
|
else:
|
||||||
|
_config[key] = value
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> Dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
return _config.copy()
|
return deepcopy(_config)
|
||||||
|
|
||||||
|
|
||||||
# Initialize with default config
|
# Initialize with default config
|
||||||
|
|||||||
@@ -116,7 +116,9 @@ class TradingAgentsGraph:
|
|||||||
self.conditional_logic,
|
self.conditional_logic,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.propagator = Propagator()
|
self.propagator = Propagator(
|
||||||
|
max_recur_limit=self.config.get("max_recur_limit", 100),
|
||||||
|
)
|
||||||
self.reflector = Reflector(self.quick_thinking_llm)
|
self.reflector = Reflector(self.quick_thinking_llm)
|
||||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,12 @@ class OpenAIClient(BaseLLMClient):
|
|||||||
api_key = os.environ.get(api_key_env)
|
api_key = os.environ.get(api_key_env)
|
||||||
if api_key:
|
if api_key:
|
||||||
llm_kwargs["api_key"] = api_key
|
llm_kwargs["api_key"] = api_key
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"API key for provider '{self.provider}' is not set. "
|
||||||
|
f"Please set the {api_key_env} environment variable "
|
||||||
|
f"(e.g. add {api_key_env}=your_key to your .env file)."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
llm_kwargs["api_key"] = "ollama"
|
llm_kwargs["api_key"] = "ollama"
|
||||||
elif self.base_url:
|
elif self.base_url:
|
||||||
|
|||||||
Reference in New Issue
Block a user