diff --git a/tests/conftest.py b/tests/conftest.py index 504ffb12d..5983446f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,8 @@ _API_KEY_ENV_VARS = ( "DEEPSEEK_API_KEY", "DASHSCOPE_API_KEY", "ZHIPU_API_KEY", + "MINIMAX_API_KEY", + "MINIMAX_CN_API_KEY", "OPENROUTER_API_KEY", "AZURE_OPENAI_API_KEY", "ALPHA_VANTAGE_API_KEY", diff --git a/tests/test_dataflows_config.py b/tests/test_dataflows_config.py new file mode 100644 index 000000000..ab0800eee --- /dev/null +++ b/tests/test_dataflows_config.py @@ -0,0 +1,61 @@ +"""Config isolation: get/set must not leak nested-dict references.""" + +import copy +import unittest + +import pytest + +import tradingagents.default_config as default_config +from tradingagents.dataflows.config import get_config, set_config + + +@pytest.mark.unit +class DataflowsConfigIsolationTests(unittest.TestCase): + def setUp(self): + set_config(copy.deepcopy(default_config.DEFAULT_CONFIG)) + + def test_get_config_returns_deep_copy(self): + cfg = get_config() + cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage" + cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage" + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance") + self.assertNotIn("get_stock_data", fresh["tool_vendors"]) + + def test_set_config_does_not_alias_caller_nested_dicts(self): + custom = copy.deepcopy(default_config.DEFAULT_CONFIG) + custom["data_vendors"]["core_stock_apis"] = "alpha_vantage" + custom["tool_vendors"]["get_stock_data"] = "alpha_vantage" + + set_config(custom) + + custom["data_vendors"]["core_stock_apis"] = "yfinance" + custom["tool_vendors"]["get_stock_data"] = "yfinance" + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage") + self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage") + + def test_partial_nested_update_preserves_existing_defaults(self): + set_config( + { + "data_vendors": { + "core_stock_apis": "alpha_vantage", + } + } + ) + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage") + self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance") + self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance") + self.assertEqual(fresh["data_vendors"]["news_data"], "yfinance") + + def test_nested_dict_updates_merge_one_level_deep(self): + set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}}) + set_config({"tool_vendors": {"get_news": "alpha_vantage"}}) + + fresh = get_config() + self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage") + self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage") diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index 5819494a3..6f3076aea 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,6 +1,8 @@ -import tradingagents.default_config as default_config +from copy import deepcopy from typing import Dict, Optional +import tradingagents.default_config as default_config + # Use default config but allow it to be overridden _config: Optional[Dict] = None @@ -9,22 +11,31 @@ def initialize_config(): """Initialize the configuration with default values.""" global _config if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() + _config = deepcopy(default_config.DEFAULT_CONFIG) def set_config(config: Dict): - """Update the configuration with custom values.""" + """Update the configuration with custom values. + + Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a + partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}`` + keeps the other nested keys from the default; scalar keys are replaced. + """ global _config - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - _config.update(config) + initialize_config() + incoming = deepcopy(config) + for key, value in incoming.items(): + if isinstance(value, dict) and isinstance(_config.get(key), dict): + _config[key].update(value) + else: + _config[key] = value def get_config() -> Dict: """Get the current configuration.""" if _config is None: initialize_config() - return _config.copy() + return deepcopy(_config) # Initialize with default config diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 197913e21..949dbf654 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -116,7 +116,9 @@ class TradingAgentsGraph: self.conditional_logic, ) - self.propagator = Propagator() + self.propagator = Propagator( + max_recur_limit=self.config.get("max_recur_limit", 100), + ) self.reflector = Reflector(self.quick_thinking_llm) self.signal_processor = SignalProcessor(self.quick_thinking_llm) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 354849123..6947ad41e 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -162,6 +162,12 @@ class OpenAIClient(BaseLLMClient): api_key = os.environ.get(api_key_env) if api_key: llm_kwargs["api_key"] = api_key + else: + raise ValueError( + f"API key for provider '{self.provider}' is not set. " + f"Please set the {api_key_env} environment variable " + f"(e.g. add {api_key_env}=your_key to your .env file)." + ) else: llm_kwargs["api_key"] = "ollama" elif self.base_url: