From 7004dfe5540a5a82166927eaf11645ed9a6dc1e6 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:07:53 +0000 Subject: [PATCH 01/44] fix: remove hardcoded Google endpoint that caused 404 (#493, #496) --- cli/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index 62b50c9c3..15c4a056b 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -187,12 +187,11 @@ def select_deep_thinking_agent(provider) -> str: return choice -def select_llm_provider() -> tuple[str, str]: - """Select the OpenAI api url using interactive selection.""" - # Define OpenAI api options with their corresponding endpoints +def select_llm_provider() -> tuple[str, str | None]: + """Select the LLM provider and its API endpoint.""" BASE_URLS = [ ("OpenAI", "https://api.openai.com/v1"), - ("Google", "https://generativelanguage.googleapis.com/v1"), + ("Google", None), # google-genai SDK manages its own endpoint ("Anthropic", "https://api.anthropic.com/"), ("xAI", "https://api.x.ai/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), From 28d5cc661fc706b4711d15d5257884bdc4600b01 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:14:10 +0000 Subject: [PATCH 02/44] fix: add missing pandas import in y_finance.py (#488) --- tradingagents/dataflows/y_finance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 8b4b93f57..8f9bfe711 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,6 +1,7 @@ from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import pandas as pd import yfinance as yf import os from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date From 7269f877c1276f2e45c1e3455ed499f5e6746d6e Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:22:01 +0000 Subject: [PATCH 
03/44] fix: portfolio manager reads trader's proposal and research plan (#503) --- tradingagents/agents/managers/portfolio_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 970efb46a..6c69ae9fd 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -12,7 +12,8 @@ def create_portfolio_manager(llm, memory): news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] sentiment_report = state["sentiment_report"] - trader_plan = state["investment_plan"] + research_plan = state["investment_plan"] + trader_plan = state["trader_investment_plan"] curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) @@ -35,7 +36,8 @@ def create_portfolio_manager(llm, memory): - **Sell**: Exit position or avoid entry **Context:** -- Trader's proposed plan: **{trader_plan}** +- Research Manager's investment plan: **{research_plan}** +- Trader's transaction proposal: **{trader_plan}** - Lessons from past decisions: **{past_memory_str}** **Required Output Structure:** From 78fb66aed1a5664d163489e756f52750ecfc0ac2 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:23:31 +0000 Subject: [PATCH 04/44] fix: normalize indicator names to lowercase (#490) --- tradingagents/agents/utils/technical_indicators_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tradingagents/agents/utils/technical_indicators_tools.py b/tradingagents/agents/utils/technical_indicators_tools.py index dc9825808..a3dda5a51 100644 --- a/tradingagents/agents/utils/technical_indicators_tools.py +++ b/tradingagents/agents/utils/technical_indicators_tools.py @@ -22,7 +22,7 @@ def get_indicators( """ # LLMs sometimes pass multiple indicators as 
a comma-separated string; # split and process each individually. - indicators = [i.strip() for i in indicator.split(",") if i.strip()] + indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()] results = [] for ind in indicators: try: From bdc5fc62d32d048dc79bd124c5099fba84bb2971 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:28:03 +0000 Subject: [PATCH 05/44] chore: bump langchain-google-genai minimum to 4.0.0 for thought signature support --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0decedb0f..98385e32e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "backtrader>=1.9.78.123", "langchain-anthropic>=0.3.15", "langchain-experimental>=0.3.4", - "langchain-google-genai>=2.1.5", + "langchain-google-genai>=4.0.0", "langchain-openai>=0.3.23", "langgraph>=0.4.8", "pandas>=2.3.0", From bdb9c29d44a2f97eede350567cef654ff93031ce Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:35:35 +0000 Subject: [PATCH 06/44] refactor: remove stale imports, use configurable results path (#499) --- tradingagents/agents/analysts/fundamentals_analyst.py | 2 -- tradingagents/agents/analysts/market_analyst.py | 2 -- tradingagents/agents/analysts/news_analyst.py | 2 -- tradingagents/agents/analysts/social_media_analyst.py | 2 -- tradingagents/agents/managers/research_manager.py | 2 -- tradingagents/agents/researchers/bear_researcher.py | 3 --- tradingagents/agents/researchers/bull_researcher.py | 3 --- tradingagents/agents/risk_mgmt/aggressive_debator.py | 2 -- .../agents/risk_mgmt/conservative_debator.py | 3 --- tradingagents/agents/risk_mgmt/neutral_debator.py | 2 -- tradingagents/agents/trader/trader.py | 2 -- tradingagents/agents/utils/agent_states.py | 10 +++------- tradingagents/graph/reflection.py | 5 ++--- tradingagents/graph/setup.py | 9 ++++----- tradingagents/graph/signal_processing.py | 4 ++-- 
tradingagents/graph/trading_graph.py | 11 ++++------- 16 files changed, 15 insertions(+), 49 deletions(-) diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 3f70c734e..6aa49cf3b 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,6 +1,4 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_balance_sheet, diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 680f90190..fef8f7519 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,6 +1,4 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_indicators, diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 42fc7a619..e0fe93c5c 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,6 +1,4 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_global_news, diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 67d78f4c2..34a53c462 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,6 +1,4 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import build_instrument_context, 
get_language_instruction, get_news from tradingagents.dataflows.config import get_config diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 3ac4b1500..5b4b4fdc5 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,5 +1,3 @@ -import time -import json from tradingagents.agents.utils.agent_utils import build_instrument_context diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index 6634490a5..a44212dc4 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,6 +1,3 @@ -from langchain_core.messages import AIMessage -import time -import json def create_bear_researcher(llm, memory): diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index b03ef7553..d23d4d76e 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,6 +1,3 @@ -from langchain_core.messages import AIMessage -import time -import json def create_bull_researcher(llm, memory): diff --git a/tradingagents/agents/risk_mgmt/aggressive_debator.py b/tradingagents/agents/risk_mgmt/aggressive_debator.py index 651114a73..2dab1152a 100644 --- a/tradingagents/agents/risk_mgmt/aggressive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggressive_debator.py @@ -1,5 +1,3 @@ -import time -import json def create_aggressive_debator(llm): diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index 7c8c0fd1e..99a8315e0 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,6 +1,3 @@ -from langchain_core.messages import AIMessage -import time -import json def 
create_conservative_debator(llm): diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index 9ed490da4..e99ff0af1 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,5 +1,3 @@ -import time -import json def create_neutral_debator(llm): diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 6298f239d..07e9f262c 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,6 +1,4 @@ import functools -import time -import json from tradingagents.agents.utils.agent_utils import build_instrument_context diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 813b00ee1..6423b9363 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,10 +1,6 @@ -from typing import Annotated, Sequence -from datetime import date, timedelta, datetime -from typing_extensions import TypedDict, Optional -from langchain_openai import ChatOpenAI -from tradingagents.agents import * -from langgraph.prebuilt import ToolNode -from langgraph.graph import END, StateGraph, START, MessagesState +from typing import Annotated +from typing_extensions import TypedDict +from langgraph.graph import MessagesState # Researcher team state diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 85438595d..2a680038c 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -1,13 +1,12 @@ # TradingAgents/graph/reflection.py -from typing import Dict, Any -from langchain_openai import ChatOpenAI +from typing import Any, Dict class Reflector: """Handles reflection on decisions and updating memory.""" - def __init__(self, quick_thinking_llm: ChatOpenAI): + def __init__(self, quick_thinking_llm: Any): """Initialize the reflector with an 
LLM.""" self.quick_thinking_llm = quick_thinking_llm self.reflection_system_prompt = self._get_reflection_prompt() diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index e0771c656..ae90489c1 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,8 +1,7 @@ # TradingAgents/graph/setup.py -from typing import Dict, Any -from langchain_openai import ChatOpenAI -from langgraph.graph import END, StateGraph, START +from typing import Any, Dict +from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode from tradingagents.agents import * @@ -16,8 +15,8 @@ class GraphSetup: def __init__( self, - quick_thinking_llm: ChatOpenAI, - deep_thinking_llm: ChatOpenAI, + quick_thinking_llm: Any, + deep_thinking_llm: Any, tool_nodes: Dict[str, ToolNode], bull_memory, bear_memory, diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index f96c1efa3..5ac66c1dd 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -1,12 +1,12 @@ # TradingAgents/graph/signal_processing.py -from langchain_openai import ChatOpenAI +from typing import Any class SignalProcessor: """Processes trading signals to extract actionable decisions.""" - def __init__(self, quick_thinking_llm: ChatOpenAI): + def __init__(self, quick_thinking_llm: Any): """Initialize with an LLM for processing.""" self.quick_thinking_llm = quick_thinking_llm diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c8cd74929..8e18f9c48 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -259,15 +259,12 @@ class TradingAgentsGraph: } # Save to file - directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") + directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs" directory.mkdir(parents=True, exist_ok=True) - with open( - 
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", - "w", - encoding="utf-8", - ) as f: - json.dump(self.log_states_dict, f, indent=4) + log_path = directory / f"full_states_log_{trade_date}.json" + with open(log_path, "w", encoding="utf-8") as f: + json.dump(self.log_states_dict[str(trade_date)], f, indent=4) def reflect_and_remember(self, returns_losses): """Reflect on decisions and update memory based on returns.""" From 4f965bf46af0c1de294993334af3aaf7a5bc79bd Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 07:56:44 +0000 Subject: [PATCH 07/44] feat: dynamic OpenRouter model selection with search (#482, #337) --- cli/utils.py | 46 ++++++++++++++++++++++ tradingagents/llm_clients/model_catalog.py | 12 +----- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index 15c4a056b..e071ce068 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -134,9 +134,52 @@ def select_research_depth() -> int: return choice +def _fetch_openrouter_models() -> List[Tuple[str, str]]: + """Fetch available models from the OpenRouter API.""" + import requests + try: + resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10) + resp.raise_for_status() + models = resp.json().get("data", []) + return [(m.get("name") or m["id"], m["id"]) for m in models] + except Exception as e: + console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]") + return [] + + +def select_openrouter_model() -> str: + """Select an OpenRouter model from the newest available, or enter a custom ID.""" + models = _fetch_openrouter_models() + + choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]] + choices.append(questionary.Choice("Custom model ID", value="custom")) + + choice = questionary.select( + "Select OpenRouter Model (latest available):", + choices=choices, + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style([ + 
("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ]), + ).ask() + + if choice is None or choice == "custom": + return questionary.text( + "Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):", + validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.", + ).ask().strip() + + return choice + + def select_shallow_thinking_agent(provider) -> str: """Select shallow thinking llm engine using an interactive selection.""" + if provider.lower() == "openrouter": + return select_openrouter_model() + choice = questionary.select( "Select Your [Quick-Thinking LLM Engine]:", choices=[ @@ -165,6 +208,9 @@ def select_shallow_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" + if provider.lower() == "openrouter": + return select_openrouter_model() + choice = questionary.select( "Select Your [Deep-Thinking LLM Engine]:", choices=[ diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 91e1659ce..fd91c66db 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -63,16 +63,8 @@ MODEL_OPTIONS: ProviderModeOptions = { ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), ], }, - "openrouter": { - "quick": [ - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ], - "deep": [ - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ], - }, + # OpenRouter models are fetched dynamically at CLI runtime. + # No static entries needed; any model ID is accepted by the validator. 
"ollama": { "quick": [ ("Qwen3:latest (8B, local)", "qwen3:latest"), From 10c136f49c82e11f0e324c9c50cda1638a8ed5a7 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 4 Apr 2026 08:14:01 +0000 Subject: [PATCH 08/44] feat: add Docker support for cross-platform deployment --- .dockerignore | 15 +++++++++++++++ Dockerfile | 27 +++++++++++++++++++++++++++ README.md | 13 +++++++++++++ docker-compose.yml | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 docker-compose.yml diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..cac710188 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,15 @@ +.git +.venv +.env +.claude +.idea +.vscode +.DS_Store +__pycache__ +*.egg-info +build +dist +results +eval_results +Dockerfile +docker-compose.yml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..940609d35 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +FROM python:3.12-slim AS builder + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +WORKDIR /build +COPY . . +RUN pip install --no-cache-dir . + +FROM python:3.12-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +COPY --from=builder /opt/venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN useradd --create-home appuser +USER appuser +WORKDIR /home/appuser/app + +COPY --from=builder --chown=appuser:appuser /build . + +ENTRYPOINT ["tradingagents"] diff --git a/README.md b/README.md index 4cfeb4e52..9a92bff99 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,19 @@ Install the package and its dependencies: pip install . 
``` +### Docker + +Alternatively, run with Docker: +```bash +cp .env.example .env # add your API keys +docker compose run --rm tradingagents +``` + +For local models with Ollama: +```bash +docker compose --profile ollama run --rm tradingagents-ollama +``` + ### Required APIs TradingAgents supports multiple LLM providers. Set the API key for your chosen provider: diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..3a5d4e299 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,34 @@ +services: + tradingagents: + build: . + env_file: + - .env + volumes: + - ./results:/home/appuser/app/results + tty: true + stdin_open: true + + ollama: + image: ollama/ollama:latest + volumes: + - ollama_data:/root/.ollama + profiles: + - ollama + + tradingagents-ollama: + build: . + env_file: + - .env + environment: + - LLM_PROVIDER=ollama + volumes: + - ./results:/home/appuser/app/results + depends_on: + - ollama + tty: true + stdin_open: true + profiles: + - ollama + +volumes: + ollama_data: From 59d6b2152d846a94304c4cbb30ba1d6b30e71c05 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 13 Apr 2026 05:26:04 +0000 Subject: [PATCH 09/44] fix: use ~/.tradingagents/ for cache and logs, resolving Docker permission issue (#519) --- docker-compose.yml | 5 +++-- tradingagents/default_config.py | 9 ++++----- tradingagents/graph/trading_graph.py | 6 ++---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 3a5d4e299..d28135b3c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ services: env_file: - .env volumes: - - ./results:/home/appuser/app/results + - tradingagents_data:/home/appuser/.tradingagents tty: true stdin_open: true @@ -22,7 +22,7 @@ services: environment: - LLM_PROVIDER=ollama volumes: - - ./results:/home/appuser/app/results + - tradingagents_data:/home/appuser/.tradingagents depends_on: - ollama tty: true @@ -31,4 +31,5 @@ services: - ollama volumes: + 
tradingagents_data: ollama_data: diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 26a4e4d28..a9b75e4be 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -1,12 +1,11 @@ import os +_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents") + DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), - "data_cache_dir": os.path.join( - os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "dataflows/data_cache", - ), + "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")), + "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")), # LLM settings "llm_provider": "openai", "deep_think_llm": "gpt-5.4", diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 8e18f9c48..78bc13e5f 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -66,10 +66,8 @@ class TradingAgentsGraph: set_config(self.config) # Create necessary directories - os.makedirs( - os.path.join(self.config["project_dir"], "dataflows/data_cache"), - exist_ok=True, - ) + os.makedirs(self.config["data_cache_dir"], exist_ok=True) + os.makedirs(self.config["results_dir"], exist_ok=True) # Initialize LLMs with provider-specific thinking configuration llm_kwargs = self._get_provider_kwargs() From b0f6058299ed0e31677a1e3eb31b662aca10abd4 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 13 Apr 2026 07:12:07 +0000 Subject: [PATCH 10/44] feat: add DeepSeek, Qwen, GLM, and Azure OpenAI provider support --- .env.enterprise.example | 5 ++ .env.example | 3 + README.md | 5 ++ cli/main.py | 3 +- cli/utils.py | 92 +++++++++++----------- tradingagents/llm_clients/azure_client.py | 52 ++++++++++++ tradingagents/llm_clients/factory.py | 22 +++--- 
tradingagents/llm_clients/model_catalog.py | 39 ++++++++- tradingagents/llm_clients/openai_client.py | 3 + 9 files changed, 163 insertions(+), 61 deletions(-) create mode 100644 .env.enterprise.example create mode 100644 tradingagents/llm_clients/azure_client.py diff --git a/.env.enterprise.example b/.env.enterprise.example new file mode 100644 index 000000000..4f7bda3dd --- /dev/null +++ b/.env.enterprise.example @@ -0,0 +1,5 @@ +# Azure OpenAI +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/ +AZURE_OPENAI_DEPLOYMENT_NAME= +# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API diff --git a/.env.example b/.env.example index 1328b838f..be9bf13eb 100644 --- a/.env.example +++ b/.env.example @@ -3,4 +3,7 @@ OPENAI_API_KEY= GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= +DEEPSEEK_API_KEY= +DASHSCOPE_API_KEY= +ZHIPU_API_KEY= OPENROUTER_API_KEY= diff --git a/README.md b/README.md index 9a92bff99..97cbde486 100644 --- a/README.md +++ b/README.md @@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT) export GOOGLE_API_KEY=... # Google (Gemini) export ANTHROPIC_API_KEY=... # Anthropic (Claude) export XAI_API_KEY=... # xAI (Grok) +export DEEPSEEK_API_KEY=... # DeepSeek +export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope) +export ZHIPU_API_KEY=... # GLM (Zhipu) export OPENROUTER_API_KEY=... # OpenRouter export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage ``` +For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials. + For local models, configure Ollama with `llm_provider: "ollama"` in your config. 
Alternatively, copy `.env.example` to `.env` and fill in your keys: diff --git a/cli/main.py b/cli/main.py index 29294d8d8..52e8a3327 100644 --- a/cli/main.py +++ b/cli/main.py @@ -6,8 +6,9 @@ from functools import wraps from rich.console import Console from dotenv import load_dotenv -# Load environment variables from .env file +# Load environment variables load_dotenv() +load_dotenv(".env.enterprise", override=False) from rich.panel import Panel from rich.spinner import Spinner from rich.live import Live diff --git a/cli/utils.py b/cli/utils.py index e071ce068..85c282edd 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -174,17 +174,30 @@ def select_openrouter_model() -> str: return choice -def select_shallow_thinking_agent(provider) -> str: - """Select shallow thinking llm engine using an interactive selection.""" +def _prompt_custom_model_id() -> str: + """Prompt user to type a custom model ID.""" + return questionary.text( + "Enter model ID:", + validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.", + ).ask().strip() + +def _select_model(provider: str, mode: str) -> str: + """Select a model for the given provider and mode (quick/deep).""" if provider.lower() == "openrouter": return select_openrouter_model() + if provider.lower() == "azure": + return questionary.text( + f"Enter Azure deployment name ({mode}-thinking):", + validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.", + ).ask().strip() + choice = questionary.select( - "Select Your [Quick-Thinking LLM Engine]:", + f"Select Your [{mode.title()}-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in get_model_options(provider, "quick") + for display, value in get_model_options(provider, mode) ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -197,58 +210,45 @@ def select_shallow_thinking_agent(provider) -> str: ).ask() if choice is None: - console.print( - "\n[red]No shallow 
thinking llm engine selected. Exiting...[/red]" - ) + console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]") exit(1) + if choice == "custom": + return _prompt_custom_model_id() + return choice +def select_shallow_thinking_agent(provider) -> str: + """Select shallow thinking llm engine using an interactive selection.""" + return _select_model(provider, "quick") + + def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" - - if provider.lower() == "openrouter": - return select_openrouter_model() - - choice = questionary.select( - "Select Your [Deep-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in get_model_options(provider, "deep") - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]No deep thinking llm engine selected. 
Exiting...[/red]") - exit(1) - - return choice + return _select_model(provider, "deep") def select_llm_provider() -> tuple[str, str | None]: """Select the LLM provider and its API endpoint.""" - BASE_URLS = [ - ("OpenAI", "https://api.openai.com/v1"), - ("Google", None), # google-genai SDK manages its own endpoint - ("Anthropic", "https://api.anthropic.com/"), - ("xAI", "https://api.x.ai/v1"), - ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + # (display_name, provider_key, base_url) + PROVIDERS = [ + ("OpenAI", "openai", "https://api.openai.com/v1"), + ("Google", "google", None), + ("Anthropic", "anthropic", "https://api.anthropic.com/"), + ("xAI", "xai", "https://api.x.ai/v1"), + ("DeepSeek", "deepseek", "https://api.deepseek.com"), + ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"), + ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"), + ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"), + ("Azure OpenAI", "azure", None), + ("Ollama", "ollama", "http://localhost:11434/v1"), ] - + choice = questionary.select( "Select your LLM Provider:", choices=[ - questionary.Choice(display, value=(display, value)) - for display, value in BASE_URLS + questionary.Choice(display, value=(provider_key, url)) + for display, provider_key, url in PROVIDERS ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -261,13 +261,11 @@ def select_llm_provider() -> tuple[str, str | None]: ).ask() if choice is None: - console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") + console.print("\n[red]No LLM provider selected. 
Exiting...[/red]") exit(1) - - display_name, url = choice - print(f"You selected: {display_name}\tURL: {url}") - return display_name, url + provider, url = choice + return provider, url def ask_openai_reasoning_effort() -> str: diff --git a/tradingagents/llm_clients/azure_client.py b/tradingagents/llm_clients/azure_client.py new file mode 100644 index 000000000..0c0ae5a44 --- /dev/null +++ b/tradingagents/llm_clients/azure_client.py @@ -0,0 +1,52 @@ +import os +from typing import Any, Optional + +from langchain_openai import AzureChatOpenAI + +from .base_client import BaseLLMClient, normalize_content +from .validators import validate_model + +_PASSTHROUGH_KWARGS = ( + "timeout", "max_retries", "api_key", "reasoning_effort", + "callbacks", "http_client", "http_async_client", +) + + +class NormalizedAzureChatOpenAI(AzureChatOpenAI): + """AzureChatOpenAI with normalized content output.""" + + def invoke(self, input, config=None, **kwargs): + return normalize_content(super().invoke(input, config, **kwargs)) + + +class AzureOpenAIClient(BaseLLMClient): + """Client for Azure OpenAI deployments. + + Requires environment variables: + AZURE_OPENAI_API_KEY: API key + AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://.openai.azure.com/) + AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name + OPENAI_API_VERSION: API version (e.g. 
2025-03-01-preview) + """ + + def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + super().__init__(model, base_url, **kwargs) + + def get_llm(self) -> Any: + """Return configured AzureChatOpenAI instance.""" + self.warn_if_unknown_model() + + llm_kwargs = { + "model": self.model, + "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model), + } + + for key in _PASSTHROUGH_KWARGS: + if key in self.kwargs: + llm_kwargs[key] = self.kwargs[key] + + return NormalizedAzureChatOpenAI(**llm_kwargs) + + def validate_model(self) -> bool: + """Azure accepts any deployed model name.""" + return True diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d34..a9a7e83d8 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,6 +4,12 @@ from .base_client import BaseLLMClient from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient +from .azure_client import AzureOpenAIClient + +# Providers that use the OpenAI-compatible chat completions API +_OPENAI_COMPATIBLE = ( + "openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter", +) def create_llm_client( @@ -15,16 +21,10 @@ def create_llm_client( """Create an LLM client for the specified provider. 
Args: - provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) + provider: LLM provider name model: Model name/identifier base_url: Optional base URL for API endpoint **kwargs: Additional provider-specific arguments - - http_client: Custom httpx.Client for SSL proxy or certificate customization - - http_async_client: Custom httpx.AsyncClient for async operations - - timeout: Request timeout in seconds - - max_retries: Maximum retry attempts - - api_key: API key for the provider - - callbacks: LangChain callbacks Returns: Configured BaseLLMClient instance @@ -34,16 +34,16 @@ def create_llm_client( """ provider_lower = provider.lower() - if provider_lower in ("openai", "ollama", "openrouter"): + if provider_lower in _OPENAI_COMPATIBLE: return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) - if provider_lower == "xai": - return OpenAIClient(model, base_url, provider="xai", **kwargs) - if provider_lower == "anthropic": return AnthropicClient(model, base_url, **kwargs) if provider_lower == "google": return GoogleClient(model, base_url, **kwargs) + if provider_lower == "azure": + return AzureOpenAIClient(model, base_url, **kwargs) + raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index fd91c66db..a2c57ed89 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = { ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), ], }, - # OpenRouter models are fetched dynamically at CLI runtime. - # No static entries needed; any model ID is accepted by the validator. 
+ "deepseek": { + "quick": [ + ("DeepSeek V3.2", "deepseek-chat"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"), + ("DeepSeek V3.2", "deepseek-chat"), + ("Custom model ID", "custom"), + ], + }, + "qwen": { + "quick": [ + ("Qwen 3.5 Flash", "qwen3.5-flash"), + ("Qwen Plus", "qwen-plus"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("Qwen 3.6 Plus", "qwen3.6-plus"), + ("Qwen 3.5 Plus", "qwen3.5-plus"), + ("Qwen 3 Max", "qwen3-max"), + ("Custom model ID", "custom"), + ], + }, + "glm": { + "quick": [ + ("GLM-4.7", "glm-4.7"), + ("GLM-5", "glm-5"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("GLM-5.1", "glm-5.1"), + ("GLM-5", "glm-5"), + ("Custom model ID", "custom"), + ], + }, + # OpenRouter: fetched dynamically. Azure: any deployed model name. "ollama": { "quick": [ ("Qwen3:latest (8B, local)", "qwen3:latest"), diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 4f2e1b32b..f943124a9 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = ( # Provider base URLs and API key env vars _PROVIDER_CONFIG = { "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), + "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"), + "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"), + "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "ollama": ("http://localhost:11434/v1", None), } From fa4d01c23acef4882fd74dd5be75dd3c7a4bc5f7 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 13 Apr 2026 07:21:33 +0000 Subject: [PATCH 11/44] fix: process all chunk messages for tool call logging, harden memory score normalization (#534, #531) --- cli/main.py | 40 +++++++++++++--------------- tradingagents/agents/utils/memory.py | 2 +- 2 files changed, 19 insertions(+), 23 
deletions(-) diff --git a/cli/main.py b/cli/main.py index 52e8a3327..33d110fb6 100644 --- a/cli/main.py +++ b/cli/main.py @@ -80,7 +80,7 @@ class MessageBuffer: self.current_agent = None self.report_sections = {} self.selected_analysts = [] - self._last_message_id = None + self._processed_message_ids = set() def init_for_analysis(self, selected_analysts): """Initialize agent status and report sections based on selected analysts. @@ -115,7 +115,7 @@ class MessageBuffer: self.current_agent = None self.messages.clear() self.tool_calls.clear() - self._last_message_id = None + self._processed_message_ids.clear() def get_completed_reports_count(self): """Count reports that are finalized (their finalizing agent is completed). @@ -1053,28 +1053,24 @@ def run_analysis(): # Stream the analysis trace = [] for chunk in graph.graph.stream(init_agent_state, **args): - # Process messages if present (skip duplicates via message ID) - if len(chunk["messages"]) > 0: - last_message = chunk["messages"][-1] - msg_id = getattr(last_message, "id", None) + # Process all messages in chunk, deduplicating by message ID + for message in chunk.get("messages", []): + msg_id = getattr(message, "id", None) + if msg_id is not None: + if msg_id in message_buffer._processed_message_ids: + continue + message_buffer._processed_message_ids.add(msg_id) - if msg_id != message_buffer._last_message_id: - message_buffer._last_message_id = msg_id + msg_type, content = classify_message_type(message) + if content and content.strip(): + message_buffer.add_message(msg_type, content) - # Add message to buffer - msg_type, content = classify_message_type(last_message) - if content and content.strip(): - message_buffer.add_message(msg_type, content) - - # Handle tool calls - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict): - message_buffer.add_tool_call( - tool_call["name"], tool_call["args"] - ) - else: - 
message_buffer.add_tool_call(tool_call.name, tool_call.args) + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call, dict): + message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) # Update analyst statuses based on report state (runs on every chunk) update_analyst_statuses(message_buffer, chunk) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index d278b3c39..2aefa7a38 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -78,7 +78,7 @@ class FinancialSituationMemory: # Build results results = [] - max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores + max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0 for idx in top_indices: # Normalize score to 0-1 range for consistency From 8536ccacdd1cb05b2e8d2d4aa195d72c5aeb2e0d Mon Sep 17 00:00:00 2001 From: Zhigong Liu Date: Fri, 17 Apr 2026 18:57:31 -0400 Subject: [PATCH 12/44] chore: ignore CLAUDE.md (local AI assistant context file) Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 9a2904a9c..a8b1d6307 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Claude Code local files +CLAUDE.md + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] From 6abc768c1d0d2d65263140ba6d30f086b4ea98f6 Mon Sep 17 00:00:00 2001 From: Zhigong Liu Date: Sun, 19 Apr 2026 22:13:53 -0400 Subject: [PATCH 13/44] feat: replace per-agent BM25 memory with persistent append-only decision log Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 1 - tests/test_memory_log.py | 648 ++++++++++++++++++ tradingagents/agents/__init__.py | 2 - .../agents/managers/portfolio_manager.py | 15 +- .../agents/managers/research_manager.py | 18 +- .../agents/researchers/bear_researcher.py 
| 12 +- .../agents/researchers/bull_researcher.py | 12 +- tradingagents/agents/trader/trader.py | 18 +- tradingagents/agents/utils/agent_states.py | 1 + tradingagents/agents/utils/memory.py | 364 ++++++---- tradingagents/default_config.py | 1 + tradingagents/graph/propagation.py | 3 +- tradingagents/graph/reflection.py | 137 +--- tradingagents/graph/setup.py | 28 +- tradingagents/graph/trading_graph.py | 126 +++- 15 files changed, 1046 insertions(+), 340 deletions(-) create mode 100644 tests/test_memory_log.py diff --git a/pyproject.toml b/pyproject.toml index 98385e32e..a1dfcd75e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "parsel>=1.10.0", "pytz>=2025.2", "questionary>=2.1.0", - "rank-bm25>=0.2.2", "redis>=6.2.0", "requests>=2.32.4", "rich>=14.0.0", diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py new file mode 100644 index 000000000..c299e7fa5 --- /dev/null +++ b/tests/test_memory_log.py @@ -0,0 +1,648 @@ +"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal.""" + +import pytest +import pandas as pd +from unittest.mock import MagicMock, patch + +from tradingagents.agents.utils.memory import TradingMemoryLog +from tradingagents.graph.reflection import Reflector +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.graph.propagation import Propagator +from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager + +_SEP = TradingMemoryLog._SEPARATOR + +DECISION_BUY = "Rating: Buy\nEnter at $189-192, 6% portfolio cap." +DECISION_OVERWEIGHT = ( + "Rating: Overweight\n" + "Executive Summary: Moderate position, await confirmation.\n" + "Investment Thesis: Strong fundamentals but near-term headwinds." +) +DECISION_SELL = "Rating: Sell\nExit position immediately." 
+DECISION_NO_RATING = ( + "Executive Summary: Complex situation with multiple competing factors.\n" + "Investment Thesis: No clear directional signal at this time." +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def make_log(tmp_path, filename="trading_memory.md"): + config = {"memory_log_path": str(tmp_path / filename)} + return TradingMemoryLog(config) + + +def _seed_completed(tmp_path, ticker, date, decision_text, reflection_text, filename="trading_memory.md"): + """Write a completed entry directly to file, bypassing the API.""" + entry = ( + f"[{date} | {ticker} | Buy | +1.0% | +0.5% | 5d]\n\n" + f"DECISION:\n{decision_text}\n\n" + f"REFLECTION:\n{reflection_text}" + + _SEP + ) + with open(tmp_path / filename, "a", encoding="utf-8") as f: + f.write(entry) + + +def _resolve_entry(log, ticker, date, decision, reflection="Good call."): + """Store a decision then immediately resolve it via the API.""" + log.store_decision(ticker, date, decision) + log.update_with_outcome(ticker, date, 0.05, 0.02, 5, reflection) + + +def _price_df(prices): + """Minimal DataFrame matching yfinance .history() output shape.""" + return pd.DataFrame({"Close": prices}) + + +def _make_pm_state(past_context=""): + """Minimal AgentState dict for portfolio_manager_node.""" + return { + "company_of_interest": "NVDA", + "past_context": past_context, + "risk_debate_state": { + "history": "Risk debate history.", + "aggressive_history": "", + "conservative_history": "", + "neutral_history": "", + "judge_decision": "", + "current_aggressive_response": "", + "current_conservative_response": "", + "current_neutral_response": "", + "count": 1, + }, + "market_report": "Market report.", + "sentiment_report": "Sentiment report.", + "news_report": "News report.", + "fundamentals_report": "Fundamentals report.", + "investment_plan": "Research plan.", + 
"trader_investment_plan": "Trader plan.", + } + + +# --------------------------------------------------------------------------- +# Core: storage and read path +# --------------------------------------------------------------------------- + +class TestTradingMemoryLogCore: + + def test_store_creates_file(self, tmp_path): + log = make_log(tmp_path) + assert not (tmp_path / "trading_memory.md").exists() + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert (tmp_path / "trading_memory.md").exists() + + def test_store_appends_not_overwrites(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT) + entries = log.load_entries() + assert len(entries) == 2 + assert entries[0]["ticker"] == "NVDA" + assert entries[1]["ticker"] == "AAPL" + + def test_store_decision_idempotent(self, tmp_path): + """Calling store_decision twice with same (ticker, date) stores only one entry.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert len(log.load_entries()) == 1 + + def test_batch_update_resolves_multiple_entries(self, tmp_path): + """batch_update_with_outcomes resolves multiple pending entries in one write.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-05", DECISION_BUY) + log.store_decision("NVDA", "2026-01-12", DECISION_SELL) + + updates = [ + {"ticker": "NVDA", "trade_date": "2026-01-05", + "raw_return": 0.05, "alpha_return": 0.02, "holding_days": 5, + "reflection": "First correct."}, + {"ticker": "NVDA", "trade_date": "2026-01-12", + "raw_return": -0.03, "alpha_return": -0.01, "holding_days": 5, + "reflection": "Second correct."}, + ] + log.batch_update_with_outcomes(updates) + + entries = log.load_entries() + assert len(entries) == 2 + assert all(not e["pending"] for e in entries) + assert entries[0]["reflection"] == "First correct." 
+ assert entries[1]["reflection"] == "Second correct." + + def test_pending_tag_format(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8") + assert "[2026-01-10 | NVDA | Buy | pending]" in text + + # Rating parsing + + def test_rating_parsed_buy(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert log.load_entries()[0]["rating"] == "Buy" + + def test_rating_parsed_overweight(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT) + assert log.load_entries()[0]["rating"] == "Overweight" + + def test_rating_fallback_hold(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING) + assert log.load_entries()[0]["rating"] == "Hold" + + def test_rating_priority_over_prose(self, tmp_path): + """'Rating: X' label wins even when an opposing rating word appears earlier in prose.""" + decision = ( + "The sell thesis is weak. The hold case is marginal.\n\n" + "Rating: Buy\n\n" + "Executive Summary: Strong fundamentals support the position." + ) + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", decision) + assert log.load_entries()[0]["rating"] == "Buy" + + # Delimiter robustness + + def test_decision_with_markdown_separator(self, tmp_path): + """LLM decision containing '---' must not corrupt the entry.""" + decision = "Rating: Buy\n\n---\n\nRisk: elevated volatility." 
+ log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", decision) + entries = log.load_entries() + assert len(entries) == 1 + assert "Risk: elevated volatility" in entries[0]["decision"] + + # load_entries + + def test_load_entries_empty_file(self, tmp_path): + log = make_log(tmp_path) + assert log.load_entries() == [] + + def test_load_entries_single(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + entries = log.load_entries() + assert len(entries) == 1 + e = entries[0] + assert e["date"] == "2026-01-10" + assert e["ticker"] == "NVDA" + assert e["rating"] == "Buy" + assert e["pending"] is True + assert e["raw"] is None + + def test_load_entries_multiple(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT) + log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING) + entries = log.load_entries() + assert len(entries) == 3 + assert [e["ticker"] for e in entries] == ["NVDA", "AAPL", "MSFT"] + + def test_decision_content_preserved(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert log.load_entries()[0]["decision"] == DECISION_BUY.strip() + + # get_pending_entries + + def test_get_pending_returns_pending_only(self, tmp_path): + log = make_log(tmp_path) + _seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA.", "Correct.") + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + pending = log.get_pending_entries() + assert len(pending) == 1 + assert pending[0]["ticker"] == "NVDA" + assert pending[0]["date"] == "2026-01-10" + + # get_past_context + + def test_get_past_context_empty(self, tmp_path): + log = make_log(tmp_path) + assert log.get_past_context("NVDA") == "" + + def test_get_past_context_pending_excluded(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert 
log.get_past_context("NVDA") == "" + + def test_get_past_context_same_ticker(self, tmp_path): + log = make_log(tmp_path) + _seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA — AI capex thesis intact.", "Directionally correct.") + ctx = log.get_past_context("NVDA") + assert "Past analyses of NVDA" in ctx + assert "Buy NVDA" in ctx + + def test_get_past_context_cross_ticker(self, tmp_path): + log = make_log(tmp_path) + _seed_completed(tmp_path, "AAPL", "2026-01-05", "Buy AAPL — Services growth.", "Correct.") + ctx = log.get_past_context("NVDA") + assert "Recent cross-ticker lessons" in ctx + assert "Past analyses of NVDA" not in ctx + + def test_n_same_limit_respected(self, tmp_path): + """Only the n_same most recent same-ticker entries are included.""" + log = make_log(tmp_path) + for i in range(6): + _seed_completed(tmp_path, "NVDA", f"2026-01-{i+1:02d}", f"Buy entry {i}.", "Correct.") + ctx = log.get_past_context("NVDA", n_same=5) + assert "Buy entry 0" not in ctx + assert "Buy entry 5" in ctx + + def test_n_cross_limit_respected(self, tmp_path): + """Only the n_cross most recent cross-ticker entries are included.""" + log = make_log(tmp_path) + for i, ticker in enumerate(["AAPL", "MSFT", "GOOG", "META"]): + _seed_completed(tmp_path, ticker, f"2026-01-{i+1:02d}", f"Buy {ticker}.", "Correct.") + ctx = log.get_past_context("NVDA", n_cross=3) + assert "AAPL" not in ctx + assert "META" in ctx + + # No-op when config is None + + def test_no_log_path_is_noop(self): + log = TradingMemoryLog(config=None) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + assert log.load_entries() == [] + assert log.get_past_context("NVDA") == "" + + # Rating parsing: markdown bold and numbered list formats + + def test_rating_parsed_from_bold_markdown(self, tmp_path): + """**Rating**: Buy — markdown bold wrapper must not prevent parsing.""" + decision = "**Rating**: Buy\nEnter at $190." 
+ log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", decision) + assert log.load_entries()[0]["rating"] == "Buy" + + def test_rating_parsed_from_numbered_list(self, tmp_path): + """1. Rating: Buy — numbered list prefix must not prevent parsing.""" + decision = "1. Rating: Buy\nEnter at $190." + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", decision) + assert log.load_entries()[0]["rating"] == "Buy" + + +# --------------------------------------------------------------------------- +# Deferred reflection: update_with_outcome, Reflector, _fetch_returns +# --------------------------------------------------------------------------- + +class TestDeferredReflection: + + # update_with_outcome + + def test_update_replaces_pending_tag(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.") + text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8") + assert "[2026-01-10 | NVDA | Buy | pending]" not in text + assert "+4.2%" in text + assert "+2.1%" in text + assert "5d" in text + + def test_update_appends_reflection(self, tmp_path): + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.") + entries = log.load_entries() + assert len(entries) == 1 + e = entries[0] + assert e["pending"] is False + assert e["reflection"] == "Momentum confirmed." 
+ assert e["decision"] == DECISION_BUY.strip() + + def test_update_preserves_other_entries(self, tmp_path): + """Only the matching entry is modified; all other entries remain unchanged.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.store_decision("AAPL", "2026-01-11", "Rating: Hold\nHold AAPL.") + log.store_decision("MSFT", "2026-01-12", DECISION_SELL) + log.update_with_outcome("AAPL", "2026-01-11", 0.01, -0.01, 5, "Neutral result.") + entries = log.load_entries() + assert len(entries) == 3 + nvda, aapl, msft = entries + assert nvda["ticker"] == "NVDA" and nvda["pending"] is True + assert aapl["ticker"] == "AAPL" and aapl["pending"] is False + assert aapl["reflection"] == "Neutral result." + assert msft["ticker"] == "MSFT" and msft["pending"] is True + + def test_update_atomic_write(self, tmp_path): + """A pre-existing .tmp file is overwritten; the log is correctly updated.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + stale_tmp = tmp_path / "trading_memory.tmp" + stale_tmp.write_text("GARBAGE CONTENT — should be overwritten", encoding="utf-8") + log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Correct.") + assert not stale_tmp.exists() + entries = log.load_entries() + assert len(entries) == 1 + assert entries[0]["reflection"] == "Correct." 
+ assert entries[0]["pending"] is False + + def test_update_noop_when_no_log_path(self): + log = TradingMemoryLog(config=None) + log.update_with_outcome("NVDA", "2026-01-10", 0.05, 0.02, 5, "Reflection") + + def test_formatting_roundtrip_after_update(self, tmp_path): + """All fields intact and blank line between tag and DECISION preserved after update.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-10", DECISION_BUY) + log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.") + entries = log.load_entries() + assert len(entries) == 1 + e = entries[0] + assert e["pending"] is False + assert e["decision"] == DECISION_BUY.strip() + assert e["reflection"] == "Momentum confirmed." + assert e["raw"] == "+4.2%" + assert e["alpha"] == "+2.1%" + assert e["holding"] == "5d" + raw_text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8") + assert "[2026-01-10 | NVDA | Buy | +4.2% | +2.1% | 5d]\n\nDECISION:" in raw_text + + # Reflector.reflect_on_final_decision + + def test_reflect_on_final_decision_returns_llm_output(self): + mock_llm = MagicMock() + mock_llm.invoke.return_value.content = "Directionally correct. Thesis confirmed." + reflector = Reflector(mock_llm) + result = reflector.reflect_on_final_decision( + final_decision=DECISION_BUY, raw_return=0.042, alpha_return=0.021 + ) + assert result == "Directionally correct. Thesis confirmed." + mock_llm.invoke.assert_called_once() + + def test_reflect_on_final_decision_includes_returns_in_prompt(self): + """Return figures are present in the human message sent to the LLM.""" + mock_llm = MagicMock() + mock_llm.invoke.return_value.content = "Incorrect call." 
+ reflector = Reflector(mock_llm) + reflector.reflect_on_final_decision( + final_decision=DECISION_SELL, raw_return=-0.08, alpha_return=-0.05 + ) + messages = mock_llm.invoke.call_args[0][0] + human_content = next(content for role, content in messages if role == "human") + assert "-8.0%" in human_content + assert "-5.0%" in human_content + assert "Exit position immediately." in human_content + + # TradingAgentsGraph._fetch_returns + + def test_fetch_returns_valid_ticker(self): + stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0] + spy_prices = [400.0, 402.0, 404.0, 403.0, 405.0, 406.0] + mock_graph = MagicMock(spec=TradingAgentsGraph) + with patch("yfinance.Ticker") as mock_ticker_cls: + def _make_ticker(sym): + m = MagicMock() + m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices) + return m + mock_ticker_cls.side_effect = _make_ticker + raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05") + assert raw is not None and alpha is not None and days is not None + assert isinstance(raw, float) and isinstance(alpha, float) and isinstance(days, int) + assert days == 5 + + def test_fetch_returns_too_recent(self): + """Only 1 data point available → returns (None, None, None), no crash.""" + mock_graph = MagicMock(spec=TradingAgentsGraph) + with patch("yfinance.Ticker") as mock_ticker_cls: + m = MagicMock() + m.history.return_value = _price_df([100.0]) + mock_ticker_cls.return_value = m + raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-04-19") + assert raw is None and alpha is None and days is None + + def test_fetch_returns_delisted(self): + """Empty DataFrame → returns (None, None, None), no crash.""" + mock_graph = MagicMock(spec=TradingAgentsGraph) + with patch("yfinance.Ticker") as mock_ticker_cls: + m = MagicMock() + m.history.return_value = pd.DataFrame({"Close": []}) + mock_ticker_cls.return_value = m + raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, 
"XXXXXFAKE", "2026-01-10") + assert raw is None and alpha is None and days is None + + def test_fetch_returns_spy_shorter_than_stock(self): + """SPY having fewer rows than the stock must not raise IndexError.""" + stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0] + spy_prices = [400.0, 402.0, 403.0] + mock_graph = MagicMock(spec=TradingAgentsGraph) + with patch("yfinance.Ticker") as mock_ticker_cls: + def _make_ticker(sym): + m = MagicMock() + m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices) + return m + mock_ticker_cls.side_effect = _make_ticker + raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05") + assert raw is not None and alpha is not None and days is not None + assert days == 2 + + # TradingAgentsGraph._resolve_pending_entries + + def test_resolve_skips_other_tickers(self, tmp_path): + """Pending AAPL entry is not resolved when the run is for NVDA.""" + log = make_log(tmp_path) + log.store_decision("AAPL", "2026-01-10", DECISION_BUY) + mock_graph = MagicMock(spec=TradingAgentsGraph) + mock_graph.memory_log = log + mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5)) + TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA") + mock_graph._fetch_returns.assert_not_called() + assert len(log.get_pending_entries()) == 1 + + def test_resolve_marks_entry_completed(self, tmp_path): + """After resolve, get_pending_entries() is empty and the entry has a REFLECTION.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-05", DECISION_BUY) + mock_reflector = MagicMock() + mock_reflector.reflect_on_final_decision.return_value = "Momentum confirmed." 
+ mock_graph = MagicMock(spec=TradingAgentsGraph) + mock_graph.memory_log = log + mock_graph.reflector = mock_reflector + mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5)) + TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA") + assert log.get_pending_entries() == [] + entries = log.load_entries() + assert len(entries) == 1 + assert entries[0]["pending"] is False + assert entries[0]["reflection"] == "Momentum confirmed." + assert "+5.0%" in entries[0]["raw"] + assert "+2.0%" in entries[0]["alpha"] + + +# --------------------------------------------------------------------------- +# Portfolio Manager injection: past_context in state and prompt +# --------------------------------------------------------------------------- + +class TestPortfolioManagerInjection: + + # past_context in initial state + + def test_past_context_in_initial_state(self): + propagator = Propagator() + state = propagator.create_initial_state("NVDA", "2026-01-10", past_context="some context") + assert "past_context" in state + assert state["past_context"] == "some context" + + def test_past_context_defaults_to_empty(self): + propagator = Propagator() + state = propagator.create_initial_state("NVDA", "2026-01-10") + assert state["past_context"] == "" + + # PM prompt + + def test_pm_prompt_includes_past_context(self): + captured = {} + mock_llm = MagicMock() + mock_llm.invoke.side_effect = lambda prompt: ( + captured.__setitem__("prompt", prompt) or MagicMock(content="Rating: Hold\nHold.") + ) + pm_node = create_portfolio_manager(mock_llm) + state = _make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call.") + pm_node(state) + assert "Past decisions on this stock" in captured["prompt"] + assert "Great call." 
in captured["prompt"] + + def test_pm_no_past_context_no_section(self): + """PM prompt omits the lessons section entirely when past_context is empty.""" + captured = {} + mock_llm = MagicMock() + mock_llm.invoke.side_effect = lambda prompt: ( + captured.__setitem__("prompt", prompt) or MagicMock(content="Rating: Hold\nHold.") + ) + pm_node = create_portfolio_manager(mock_llm) + state = _make_pm_state(past_context="") + pm_node(state) + assert "Past decisions on this stock" not in captured["prompt"] + assert "lessons learned" not in captured["prompt"] + + # get_past_context ordering and limits + + def test_same_ticker_prioritised(self, tmp_path): + """Same-ticker entries in same-ticker section; cross-ticker entries in cross-ticker section.""" + log = make_log(tmp_path) + _resolve_entry(log, "NVDA", "2026-01-05", DECISION_BUY, "Momentum confirmed.") + _resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued.") + result = log.get_past_context("NVDA") + assert "Past analyses of NVDA" in result + assert "Recent cross-ticker lessons" in result + same_block, cross_block = result.split("Recent cross-ticker lessons") + assert "NVDA" in same_block + assert "AAPL" in cross_block + + def test_cross_ticker_reflection_only(self, tmp_path): + """Cross-ticker entries show only the REFLECTION text, not the full DECISION.""" + log = make_log(tmp_path) + _resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued correction.") + result = log.get_past_context("NVDA") + assert "Overvalued correction." in result + assert "Exit position immediately." not in result + + def test_n_same_limit_respected(self, tmp_path): + """More than 5 same-ticker completed entries → only 5 injected.""" + log = make_log(tmp_path) + for i in range(7): + _resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.") + result = log.get_past_context("NVDA", n_same=5) + lessons_present = sum(1 for i in range(7) if f"Lesson {i}." 
in result) + assert lessons_present == 5 + + def test_n_cross_limit_respected(self, tmp_path): + """More than 3 cross-ticker completed entries → only 3 injected.""" + log = make_log(tmp_path) + tickers = ["AAPL", "MSFT", "TSLA", "AMZN", "GOOG"] + for i, ticker in enumerate(tickers): + _resolve_entry(log, ticker, f"2026-01-{i+1:02d}", DECISION_BUY, f"{ticker} lesson.") + result = log.get_past_context("NVDA", n_cross=3) + cross_count = sum(result.count(f"{t} lesson.") for t in tickers) + assert cross_count == 3 + + # Full A→B→C integration cycle + + def test_full_cycle_store_resolve_inject(self, tmp_path): + """store pending → resolve with outcome → past_context non-empty for PM.""" + log = make_log(tmp_path) + log.store_decision("NVDA", "2026-01-05", DECISION_BUY) + assert len(log.get_pending_entries()) == 1 + assert log.get_past_context("NVDA") == "" + log.update_with_outcome("NVDA", "2026-01-05", 0.05, 0.02, 5, "Correct call.") + assert log.get_pending_entries() == [] + past_ctx = log.get_past_context("NVDA") + assert past_ctx != "" + assert "NVDA" in past_ctx + assert "Correct call." 
in past_ctx + assert "DECISION:" in past_ctx + assert "REFLECTION:" in past_ctx + + +# --------------------------------------------------------------------------- +# Legacy removal: BM25 / FinancialSituationMemory fully gone +# --------------------------------------------------------------------------- + +class TestLegacyRemoval: + + def test_financial_situation_memory_removed(self): + """FinancialSituationMemory must not be importable from the memory module.""" + import tradingagents.agents.utils.memory as m + assert not hasattr(m, "FinancialSituationMemory") + + def test_bm25_not_imported(self): + """rank_bm25 must not be present in the memory module namespace.""" + import tradingagents.agents.utils.memory as m + assert not hasattr(m, "BM25Okapi") + + def test_reflect_and_remember_removed(self): + """TradingAgentsGraph must not expose reflect_and_remember.""" + assert not hasattr(TradingAgentsGraph, "reflect_and_remember") + + def test_portfolio_manager_no_memory_param(self): + """create_portfolio_manager accepts only llm; passing memory= raises TypeError.""" + mock_llm = MagicMock() + create_portfolio_manager(mock_llm) + with pytest.raises(TypeError): + create_portfolio_manager(mock_llm, memory=MagicMock()) + + def test_full_pipeline_no_regression(self, tmp_path): + """propagate() completes without AttributeError after legacy cleanup.""" + fake_state = { + "final_trade_decision": "Rating: Buy\nBuy NVDA.", + "company_of_interest": "NVDA", + "trade_date": "2026-01-10", + "market_report": "", + "sentiment_report": "", + "news_report": "", + "fundamentals_report": "", + "investment_debate_state": { + "bull_history": "", "bear_history": "", "history": "", + "current_response": "", "judge_decision": "", + }, + "investment_plan": "", + "trader_investment_plan": "", + "risk_debate_state": { + "aggressive_history": "", "conservative_history": "", + "neutral_history": "", "history": "", "judge_decision": "", + "current_aggressive_response": "", 
"current_conservative_response": "", + "current_neutral_response": "", "count": 1, "latest_speaker": "", + }, + } + mock_graph = MagicMock() + mock_graph.memory_log = TradingMemoryLog({"memory_log_path": str(tmp_path / "mem.md")}) + mock_graph.log_states_dict = {} + mock_graph.debug = False + mock_graph.config = {"results_dir": str(tmp_path)} + mock_graph.graph.invoke.return_value = fake_state + mock_graph.propagator.create_initial_state.return_value = fake_state + mock_graph.propagator.get_graph_args.return_value = {} + mock_graph.signal_processor.process_signal.return_value = "Buy" + TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10") + entries = mock_graph.memory_log.load_entries() + assert len(entries) == 1 + assert entries[0]["ticker"] == "NVDA" + assert entries[0]["pending"] is True diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 1f03642c6..2fb4e1bac 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,6 +1,5 @@ from .utils.agent_utils import create_msg_delete from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState -from .utils.memory import FinancialSituationMemory from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst @@ -20,7 +19,6 @@ from .managers.portfolio_manager import create_portfolio_manager from .trader.trader import create_trader __all__ = [ - "FinancialSituationMemory", "AgentState", "create_msg_delete", "InvestDebateState", diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 6c69ae9fd..215236156 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -1,7 +1,7 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction -def create_portfolio_manager(llm, memory): +def 
create_portfolio_manager(llm): def portfolio_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) @@ -15,12 +15,11 @@ def create_portfolio_manager(llm, memory): research_plan = state["investment_plan"] trader_plan = state["trader_investment_plan"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" + past_context = state.get("past_context", "") + lessons_line = ( + f"- Past decisions on this stock and lessons learned:\n{past_context}\n" + if past_context else "" + ) prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision. @@ -38,7 +37,7 @@ def create_portfolio_manager(llm, memory): **Context:** - Research Manager's investment plan: **{research_plan}** - Trader's transaction proposal: **{trader_plan}** -- Lessons from past decisions: **{past_memory_str}** +{lessons_line} **Required Output Structure:** 1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell. 
diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 5b4b4fdc5..3902a60c4 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -2,24 +2,13 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context -def create_research_manager(llm, memory): +def create_research_manager(llm): def research_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) history = state["investment_debate_state"].get("history", "") - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] investment_debate_state = state["investment_debate_state"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. @@ -29,10 +18,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc Your Recommendation: A decisive stance supported by the most convincing arguments. 
Rationale: An explanation of why these arguments lead to your conclusion. Strategic Actions: Concrete steps for implementing the recommendation. -Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. - -Here are your past reflections on mistakes: -\"{past_memory_str}\" +Present your analysis conversationally, as if speaking naturally, without special formatting. {instrument_context} diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index a44212dc4..9cde9d39c 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,6 +1,6 @@ -def create_bear_researcher(llm, memory): +def create_bear_researcher(llm): def bear_node(state) -> dict: investment_debate_state = state["investment_debate_state"] history = investment_debate_state.get("history", "") @@ -12,13 +12,6 @@ def create_bear_researcher(llm, memory): news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. 
Key points to focus on: @@ -37,8 +30,7 @@ Latest world affairs news: {news_report} Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bull argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. +Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. """ response = llm.invoke(prompt) diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index d23d4d76e..d16bc2371 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,6 +1,6 @@ -def create_bull_researcher(llm, memory): +def create_bull_researcher(llm): def bull_node(state) -> dict: investment_debate_state = state["investment_debate_state"] history = investment_debate_state.get("history", "") @@ -12,13 +12,6 @@ def create_bull_researcher(llm, memory): news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. 
Leverage the provided research and data to address concerns and counter bearish arguments effectively. Key points to focus on: @@ -35,8 +28,7 @@ Latest world affairs news: {news_report} Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bear argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. +Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. """ response = llm.invoke(prompt) diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 07e9f262c..0ecae8888 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -3,25 +3,11 @@ import functools from tradingagents.agents.utils.agent_utils import build_instrument_context -def create_trader(llm, memory): +def create_trader(llm): def trader_node(state, name): company_name = state["company_of_interest"] instrument_context = build_instrument_context(company_name) investment_plan = state["investment_plan"] - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - if past_memories: - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - else: - past_memory_str = "No past memories found." 
context = { "role": "user", @@ -31,7 +17,7 @@ def create_trader(llm, memory): messages = [ { "role": "system", - "content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""", + "content": "You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation.", }, context, ] diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 6423b9363..6151a3863 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -70,3 +70,4 @@ class AgentState(MessagesState): RiskDebateState, "Current state of the debate on evaluating risk" ] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] + past_context: Annotated[str, "Memory log context injected at run start (same-ticker decisions + cross-ticker lessons)"] diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 2aefa7a38..fd14449e3 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,144 +1,272 @@ -"""Financial situation memory using BM25 for lexical similarity matching. +"""Append-only markdown decision log for TradingAgents.""" -Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls, -no token limits, works offline with any LLM provider. 
-""" - -from rank_bm25 import BM25Okapi -from typing import List, Tuple +from typing import List, Optional +from pathlib import Path import re -class FinancialSituationMemory: - """Memory system for storing and retrieving financial situations using BM25.""" +class TradingMemoryLog: + """Append-only markdown log of trading decisions and reflections.""" - def __init__(self, name: str, config: dict = None): - """Initialize the memory system. + RATINGS = {"buy", "overweight", "hold", "underweight", "sell"} + # HTML comment: cannot appear in LLM prose output, safe as a hard delimiter + _SEPARATOR = "\n\n\n\n" + # Precompiled patterns — avoids re-compilation on every load_entries() call + _DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL) + _REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL) + _RATING_LABEL_RE = re.compile(r"rating.*?[:\-]\s*(\w+)", re.IGNORECASE) - Args: - name: Name identifier for this memory instance - config: Configuration dict (kept for API compatibility, not used for BM25) - """ - self.name = name - self.documents: List[str] = [] - self.recommendations: List[str] = [] - self.bm25 = None + def __init__(self, config: dict = None): + self._log_path = None + path = (config or {}).get("memory_log_path") + if path: + self._log_path = Path(path).expanduser() + self._log_path.parent.mkdir(parents=True, exist_ok=True) - def _tokenize(self, text: str) -> List[str]: - """Tokenize text for BM25 indexing. + # --- Write path (Phase A) --- - Simple whitespace + punctuation tokenization with lowercasing. - """ - # Lowercase and split on non-alphanumeric characters - tokens = re.findall(r'\b\w+\b', text.lower()) - return tokens + def store_decision( + self, + ticker: str, + trade_date: str, + final_trade_decision: str, + ) -> None: + """Append pending entry at end of propagate(). 
No LLM call.""" + if not self._log_path: + return + # Idempotency guard: fast raw-text scan instead of full parse + if self._log_path.exists(): + raw = self._log_path.read_text(encoding="utf-8") + for line in raw.splitlines(): + if line.startswith(f"[{trade_date} | {ticker} |") and line.endswith("| pending]"): + return + rating = self._parse_rating(final_trade_decision) + tag = f"[{trade_date} | {ticker} | {rating} | pending]" + entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}" + with open(self._log_path, "a", encoding="utf-8") as f: + f.write(entry) - def _rebuild_index(self): - """Rebuild the BM25 index after adding documents.""" - if self.documents: - tokenized_docs = [self._tokenize(doc) for doc in self.documents] - self.bm25 = BM25Okapi(tokenized_docs) - else: - self.bm25 = None + # --- Read path (Phase A) --- - def add_situations(self, situations_and_advice: List[Tuple[str, str]]): - """Add financial situations and their corresponding advice. - - Args: - situations_and_advice: List of tuples (situation, recommendation) - """ - for situation, recommendation in situations_and_advice: - self.documents.append(situation) - self.recommendations.append(recommendation) - - # Rebuild BM25 index with new documents - self._rebuild_index() - - def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]: - """Find matching recommendations using BM25 similarity. - - Args: - current_situation: The current financial situation to match against - n_matches: Number of top matches to return - - Returns: - List of dicts with matched_situation, recommendation, and similarity_score - """ - if not self.documents or self.bm25 is None: + def load_entries(self) -> List[dict]: + """Parse all entries from log. 
Returns list of dicts.""" + if not self._log_path or not self._log_path.exists(): return [] + text = self._log_path.read_text(encoding="utf-8") + raw_entries = [e.strip() for e in text.split(self._SEPARATOR) if e.strip()] + entries = [] + for raw in raw_entries: + parsed = self._parse_entry(raw) + if parsed: + entries.append(parsed) + return entries - # Tokenize query - query_tokens = self._tokenize(current_situation) + def get_pending_entries(self) -> List[dict]: + """Return entries with outcome:pending (for Phase B).""" + return [e for e in self.load_entries() if e.get("pending")] - # Get BM25 scores for all documents - scores = self.bm25.get_scores(query_tokens) + def get_past_context(self, ticker: str, n_same: int = 5, n_cross: int = 3) -> str: + """Return formatted past context string for agent prompt injection.""" + entries = [e for e in self.load_entries() if not e.get("pending")] + if not entries: + return "" - # Get top-n indices sorted by score (descending) - top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches] + same, cross = [], [] + for e in reversed(entries): + if len(same) >= n_same and len(cross) >= n_cross: + break + if e["ticker"] == ticker and len(same) < n_same: + same.append(e) + elif e["ticker"] != ticker and len(cross) < n_cross: + cross.append(e) - # Build results - results = [] - max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0 + if not same and not cross: + return "" - for idx in top_indices: - # Normalize score to 0-1 range for consistency - normalized_score = scores[idx] / max_score if max_score > 0 else 0 - results.append({ - "matched_situation": self.documents[idx], - "recommendation": self.recommendations[idx], - "similarity_score": normalized_score, - }) + parts = [] + if same: + parts.append(f"Past analyses of {ticker} (most recent first):") + parts.extend(self._format_full(e) for e in same) + if cross: + parts.append("Recent cross-ticker lessons:") + 
parts.extend(self._format_reflection_only(e) for e in cross) + return "\n\n".join(parts) - return results + # --- Update path (Phase B) --- - def clear(self): - """Clear all stored memories.""" - self.documents = [] - self.recommendations = [] - self.bm25 = None + def update_with_outcome( + self, + ticker: str, + trade_date: str, + raw_return: float, + alpha_return: float, + holding_days: int, + reflection: str, + ) -> None: + """Replace pending tag and append REFLECTION section using atomic write. + Finds the first pending entry matching (trade_date, ticker), updates + its tag with return figures, and appends a REFLECTION section. Uses + a temp-file + os.replace() so a crash mid-write never corrupts the log. + """ + if not self._log_path or not self._log_path.exists(): + return -if __name__ == "__main__": - # Example usage - matcher = FinancialSituationMemory("test_memory") + text = self._log_path.read_text(encoding="utf-8") + blocks = text.split(self._SEPARATOR) - # Example data - example_data = [ - ( - "High inflation rate with rising interest rates and declining consumer spending", - "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", - ), - ( - "Tech sector showing high volatility with increasing institutional selling pressure", - "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", - ), - ( - "Strong dollar affecting emerging markets with increasing forex volatility", - "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", - ), - ( - "Market showing signs of sector rotation with rising yields", - "Rebalance portfolio to maintain target allocations. 
Consider increasing exposure to sectors benefiting from higher rates.", - ), - ] + pending_prefix = f"[{trade_date} | {ticker} |" + raw_pct = f"{raw_return:+.1%}" + alpha_pct = f"{alpha_return:+.1%}" - # Add the example situations and recommendations - matcher.add_situations(example_data) + updated = False + new_blocks = [] + for block in blocks: + stripped = block.strip() + if not stripped: + new_blocks.append(block) + continue - # Example query - current_situation = """ - Market showing increased volatility in tech sector, with institutional investors - reducing positions and rising interest rates affecting growth stock valuations - """ + lines = stripped.splitlines() + tag_line = lines[0].strip() - try: - recommendations = matcher.get_memories(current_situation, n_matches=2) + if ( + not updated + and tag_line.startswith(pending_prefix) + and tag_line.endswith("| pending]") + ): + # Parse rating from the existing pending tag + fields = [f.strip() for f in tag_line[1:-1].split("|")] + rating = fields[2] + new_tag = ( + f"[{trade_date} | {ticker} | {rating}" + f" | {raw_pct} | {alpha_pct} | {holding_days}d]" + ) + rest = "\n".join(lines[1:]) + new_blocks.append( + f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{reflection}" + ) + updated = True + else: + new_blocks.append(block) - for i, rec in enumerate(recommendations, 1): - print(f"\nMatch {i}:") - print(f"Similarity Score: {rec['similarity_score']:.2f}") - print(f"Matched Situation: {rec['matched_situation']}") - print(f"Recommendation: {rec['recommendation']}") + if not updated: + return - except Exception as e: - print(f"Error during recommendation: {str(e)}") + new_text = self._SEPARATOR.join(new_blocks) + tmp_path = self._log_path.with_suffix(".tmp") + tmp_path.write_text(new_text, encoding="utf-8") + tmp_path.replace(self._log_path) + + def batch_update_with_outcomes(self, updates: List[dict]) -> None: + """Apply multiple outcome updates in a single read + atomic write. 
+ + Each element of updates must have keys: ticker, trade_date, + raw_return, alpha_return, holding_days, reflection. + """ + if not self._log_path or not self._log_path.exists() or not updates: + return + + text = self._log_path.read_text(encoding="utf-8") + blocks = text.split(self._SEPARATOR) + + # Build lookup keyed by (trade_date, ticker) for O(1) dispatch + update_map = {(u["trade_date"], u["ticker"]): u for u in updates} + + new_blocks = [] + for block in blocks: + stripped = block.strip() + if not stripped: + new_blocks.append(block) + continue + + lines = stripped.splitlines() + tag_line = lines[0].strip() + + matched = False + for (trade_date, ticker), upd in list(update_map.items()): + pending_prefix = f"[{trade_date} | {ticker} |" + if tag_line.startswith(pending_prefix) and tag_line.endswith("| pending]"): + fields = [f.strip() for f in tag_line[1:-1].split("|")] + rating = fields[2] + raw_pct = f"{upd['raw_return']:+.1%}" + alpha_pct = f"{upd['alpha_return']:+.1%}" + new_tag = ( + f"[{trade_date} | {ticker} | {rating}" + f" | {raw_pct} | {alpha_pct} | {upd['holding_days']}d]" + ) + rest = "\n".join(lines[1:]) + new_blocks.append( + f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{upd['reflection']}" + ) + del update_map[(trade_date, ticker)] + matched = True + break + + if not matched: + new_blocks.append(block) + + new_text = self._SEPARATOR.join(new_blocks) + tmp_path = self._log_path.with_suffix(".tmp") + tmp_path.write_text(new_text, encoding="utf-8") + tmp_path.replace(self._log_path) + + # --- Helpers --- + + def _parse_rating(self, text: str) -> str: + # First pass: explicit "Rating: X" label — search handles markdown bold/numbered lists + for line in text.splitlines(): + m = self._RATING_LABEL_RE.search(line) + if m and m.group(1).lower() in self.RATINGS: + return m.group(1).capitalize() + # Fallback: first rating word found anywhere in the text + for line in text.splitlines(): + for word in line.lower().split(): + clean = word.strip("*:.,") + 
if clean in self.RATINGS: + return clean.capitalize() + return "Hold" + + def _parse_entry(self, raw: str) -> Optional[dict]: + lines = raw.strip().splitlines() + if not lines: + return None + tag_line = lines[0].strip() + if not (tag_line.startswith("[") and tag_line.endswith("]")): + return None + fields = [f.strip() for f in tag_line[1:-1].split("|")] + if len(fields) < 4: + return None + entry = { + "date": fields[0], + "ticker": fields[1], + "rating": fields[2], + "pending": fields[3] == "pending", + "raw": fields[3] if fields[3] != "pending" else None, + "alpha": fields[4] if len(fields) > 4 else None, + "holding": fields[5] if len(fields) > 5 else None, + } + body = "\n".join(lines[1:]).strip() + decision_match = self._DECISION_RE.search(body) + reflection_match = self._REFLECTION_RE.search(body) + entry["decision"] = decision_match.group(1).strip() if decision_match else "" + entry["reflection"] = reflection_match.group(1).strip() if reflection_match else "" + return entry + + def _format_full(self, e: dict) -> str: + raw = e["raw"] or "n/a" + alpha = e["alpha"] or "n/a" + holding = e["holding"] or "n/a" + tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {raw} | {alpha} | {holding}]" + parts = [tag, f"DECISION:\n{e['decision']}"] + if e["reflection"]: + parts.append(f"REFLECTION:\n{e['reflection']}") + return "\n\n".join(parts) + + def _format_reflection_only(self, e: dict) -> str: + tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {e['raw'] or 'n/a'}]" + if e["reflection"]: + return f"{tag}\n{e['reflection']}" + text = e["decision"][:300] + suffix = "..." 
if len(e["decision"]) > 300 else "" + return f"{tag}\n{text}{suffix}" diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index a9b75e4be..274c29a51 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -6,6 +6,7 @@ DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")), "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")), + "memory_log_path": os.path.join(_TRADINGAGENTS_HOME, "memory", "trading_memory.md"), # LLM settings "llm_provider": "openai", "deep_think_llm": "gpt-5.4", diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 0fd10c0c3..2a5efb1ba 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -16,13 +16,14 @@ class Propagator: self.max_recur_limit = max_recur_limit def create_initial_state( - self, company_name: str, trade_date: str + self, company_name: str, trade_date: str, past_context: str = "" ) -> Dict[str, Any]: """Create the initial state for the agent graph.""" return { "messages": [("human", company_name)], "company_of_interest": company_name, "trade_date": str(trade_date), + "past_context": past_context, "investment_debate_state": InvestDebateState( { "bull_history": "", diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 2a680038c..813114428 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -1,120 +1,53 @@ # TradingAgents/graph/reflection.py -from typing import Any, Dict +from typing import Any class Reflector: - """Handles reflection on decisions and updating memory.""" + """Handles reflection on trading decisions.""" def __init__(self, quick_thinking_llm: Any): """Initialize the reflector with an LLM.""" self.quick_thinking_llm = quick_thinking_llm - 
self.reflection_system_prompt = self._get_reflection_prompt() + self.log_reflection_prompt = self._get_log_reflection_prompt() - def _get_reflection_prompt(self) -> str: - """Get the system prompt for reflection.""" - return """ -You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. -Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines: + def _get_log_reflection_prompt(self) -> str: + """Concise prompt for reflect_on_final_decision (Phase B log entries). -1. Reasoning: - - For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite. - - Analyze the contributing factors to each success or mistake. Consider: - - Market intelligence. - - Technical indicators. - - Technical signals. - - Price movement analysis. - - Overall market data analysis - - News analysis. - - Social media and sentiment analysis. - - Fundamental data analysis. - - Weight the importance of each factor in the decision-making process. + Produces 2-4 sentences of plain prose — compact enough to be re-injected + into future agent prompts without bloating the context window. + """ + return ( + "You are a trading analyst reviewing your own past decision now that the outcome is known.\n" + "Write exactly 2-4 sentences of plain prose (no bullets, no headers, no markdown).\n\n" + "Cover in order:\n" + "1. Was the directional call correct? (cite the alpha figure)\n" + "2. Which part of the investment thesis held or failed?\n" + "3. One concrete lesson to apply to the next similar analysis.\n\n" + "Be specific and terse. Your output will be stored verbatim in a decision log " + "and re-read by future analysts, so every word must earn its place." + ) -2. 
Improvement: - - For any incorrect decisions, propose revisions to maximize returns. - - Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date). - -3. Summary: - - Summarize the lessons learned from the successes and mistakes. - - Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained. - -4. Query: - - Extract key insights from the summary into a concise sentence of no more than 1000 tokens. - - Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference. - -Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis. -""" - - def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: - """Extract the current market situation from the state.""" - curr_market_report = current_state["market_report"] - curr_sentiment_report = current_state["sentiment_report"] - curr_news_report = current_state["news_report"] - curr_fundamentals_report = current_state["fundamentals_report"] - - return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" - - def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses + def reflect_on_final_decision( + self, + final_decision: str, + raw_return: float, + alpha_return: float, ) -> str: - """Generate reflection for a component.""" + """Single reflection call on the final trade decision with outcome context. + + Used by Phase B deferred reflection. The final_trade_decision already + synthesises all analyst insights, so no separate market context is needed. 
+ """ messages = [ - ("system", self.reflection_system_prompt), + ("system", self.log_reflection_prompt), ( "human", - f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}", + ( + f"Raw return: {raw_return:+.1%}\n" + f"Alpha vs SPY: {alpha_return:+.1%}\n\n" + f"Final Decision:\n{final_decision}" + ), ), ] - - result = self.quick_thinking_llm.invoke(messages).content - return result - - def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): - """Reflect on bull researcher's analysis and update memory.""" - situation = self._extract_current_situation(current_state) - bull_debate_history = current_state["investment_debate_state"]["bull_history"] - - result = self._reflect_on_component( - "BULL", bull_debate_history, situation, returns_losses - ) - bull_memory.add_situations([(situation, result)]) - - def reflect_bear_researcher(self, current_state, returns_losses, bear_memory): - """Reflect on bear researcher's analysis and update memory.""" - situation = self._extract_current_situation(current_state) - bear_debate_history = current_state["investment_debate_state"]["bear_history"] - - result = self._reflect_on_component( - "BEAR", bear_debate_history, situation, returns_losses - ) - bear_memory.add_situations([(situation, result)]) - - def reflect_trader(self, current_state, returns_losses, trader_memory): - """Reflect on trader's decision and update memory.""" - situation = self._extract_current_situation(current_state) - trader_decision = current_state["trader_investment_plan"] - - result = self._reflect_on_component( - "TRADER", trader_decision, situation, returns_losses - ) - trader_memory.add_situations([(situation, result)]) - - def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory): - """Reflect on investment judge's decision and update memory.""" - situation = self._extract_current_situation(current_state) - judge_decision = 
current_state["investment_debate_state"]["judge_decision"] - - result = self._reflect_on_component( - "INVEST JUDGE", judge_decision, situation, returns_losses - ) - invest_judge_memory.add_situations([(situation, result)]) - - def reflect_portfolio_manager(self, current_state, returns_losses, portfolio_manager_memory): - """Reflect on portfolio manager's decision and update memory.""" - situation = self._extract_current_situation(current_state) - judge_decision = current_state["risk_debate_state"]["judge_decision"] - - result = self._reflect_on_component( - "PORTFOLIO MANAGER", judge_decision, situation, returns_losses - ) - portfolio_manager_memory.add_situations([(situation, result)]) + return self.quick_thinking_llm.invoke(messages).content diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index ae90489c1..1686fc5b0 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -18,22 +18,12 @@ class GraphSetup: quick_thinking_llm: Any, deep_thinking_llm: Any, tool_nodes: Dict[str, ToolNode], - bull_memory, - bear_memory, - trader_memory, - invest_judge_memory, - portfolio_manager_memory, conditional_logic: ConditionalLogic, ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm self.deep_thinking_llm = deep_thinking_llm self.tool_nodes = tool_nodes - self.bull_memory = bull_memory - self.bear_memory = bear_memory - self.trader_memory = trader_memory - self.invest_judge_memory = invest_judge_memory - self.portfolio_manager_memory = portfolio_manager_memory self.conditional_logic = conditional_logic def setup_graph( @@ -85,24 +75,16 @@ class GraphSetup: tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] # Create researcher and manager nodes - bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory - ) - bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory - ) - research_manager_node = 
create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory - ) - trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) + bull_researcher_node = create_bull_researcher(self.quick_thinking_llm) + bear_researcher_node = create_bear_researcher(self.quick_thinking_llm) + research_manager_node = create_research_manager(self.deep_thinking_llm) + trader_node = create_trader(self.quick_thinking_llm) # Create risk analysis nodes aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm) neutral_analyst = create_neutral_debator(self.quick_thinking_llm) conservative_analyst = create_conservative_debator(self.quick_thinking_llm) - portfolio_manager_node = create_portfolio_manager( - self.deep_thinking_llm, self.portfolio_manager_memory - ) + portfolio_manager_node = create_portfolio_manager(self.deep_thinking_llm) # Create workflow workflow = StateGraph(AgentState) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5f..cf5662598 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,18 +1,23 @@ # TradingAgents/graph/trading_graph.py +import logging import os from pathlib import Path import json -from datetime import date +from datetime import datetime, timedelta from typing import Dict, Any, Tuple, List, Optional +import yfinance as yf + +logger = logging.getLogger(__name__) + from langgraph.prebuilt import ToolNode from tradingagents.llm_clients import create_llm_client from tradingagents.agents import * from tradingagents.default_config import DEFAULT_CONFIG -from tradingagents.agents.utils.memory import FinancialSituationMemory +from tradingagents.agents.utils.memory import TradingMemoryLog from tradingagents.agents.utils.agent_states import ( AgentState, InvestDebateState, @@ -92,12 +97,7 @@ class TradingAgentsGraph: self.deep_thinking_llm = deep_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() - # Initialize 
memories - self.bull_memory = FinancialSituationMemory("bull_memory", self.config) - self.bear_memory = FinancialSituationMemory("bear_memory", self.config) - self.trader_memory = FinancialSituationMemory("trader_memory", self.config) - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) - self.portfolio_manager_memory = FinancialSituationMemory("portfolio_manager_memory", self.config) + self.memory_log = TradingMemoryLog(self.config) # Create tool nodes self.tool_nodes = self._create_tool_nodes() @@ -111,11 +111,6 @@ class TradingAgentsGraph: self.quick_thinking_llm, self.deep_thinking_llm, self.tool_nodes, - self.bull_memory, - self.bear_memory, - self.trader_memory, - self.invest_judge_memory, - self.portfolio_manager_memory, self.conditional_logic, ) @@ -189,14 +184,90 @@ class TradingAgentsGraph: ), } + def _fetch_returns( + self, ticker: str, trade_date: str, holding_days: int = 5 + ) -> Tuple[Optional[float], Optional[float], Optional[int]]: + """Fetch raw and alpha return for ticker over holding_days from trade_date. + + Returns (raw_return, alpha_return, actual_holding_days) or + (None, None, None) if price data is unavailable (too recent, delisted, + or network error). 
+ """ + try: + start = datetime.strptime(trade_date, "%Y-%m-%d") + end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays + end_str = end.strftime("%Y-%m-%d") + + stock = yf.Ticker(ticker).history(start=trade_date, end=end_str) + spy = yf.Ticker("SPY").history(start=trade_date, end=end_str) + + if len(stock) < 2 or len(spy) < 2: + return None, None, None + + actual_days = min(holding_days, len(stock) - 1, len(spy) - 1) + raw = float( + (stock["Close"].iloc[actual_days] - stock["Close"].iloc[0]) + / stock["Close"].iloc[0] + ) + spy_ret = float( + (spy["Close"].iloc[actual_days] - spy["Close"].iloc[0]) + / spy["Close"].iloc[0] + ) + alpha = raw - spy_ret + return raw, alpha, actual_days + except Exception as e: + logger.debug("_fetch_returns failed for %s@%s: %s", ticker, trade_date, e) + return None, None, None + + def _resolve_pending_entries(self, ticker: str) -> None: + """Resolve pending log entries for ticker at the start of a new run. + + Fetches returns for each same-ticker pending entry, generates reflections, + then writes all updates in a single atomic batch write to avoid redundant I/O. + Skips entries whose price data is not yet available (too recent or delisted). + + Trade-off: only same-ticker entries are resolved per run. Entries for + other tickers accumulate until that ticker is run again. 
+ """ + pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker] + if not pending: + return + + updates = [] + for entry in pending: + raw, alpha, days = self._fetch_returns(ticker, entry["date"]) + if raw is None: + continue # price not available yet — try again next run + reflection = self.reflector.reflect_on_final_decision( + final_decision=entry.get("decision", ""), + raw_return=raw, + alpha_return=alpha, + ) + updates.append({ + "ticker": ticker, + "trade_date": entry["date"], + "raw_return": raw, + "alpha_return": alpha, + "holding_days": days, + "reflection": reflection, + }) + + if updates: + self.memory_log.batch_update_with_outcomes(updates) + def propagate(self, company_name, trade_date): """Run the trading agents graph for a company on a specific date.""" self.ticker = company_name - # Initialize state + # Resolve any pending log entries for this ticker before the pipeline runs. + # This adds the outcome + reflection from the previous run at zero latency cost. + self._resolve_pending_entries(company_name) + + # Initialize state — inject memory log context for PM + past_context = self.memory_log.get_past_context(company_name) init_agent_state = self.propagator.create_initial_state( - company_name, trade_date + company_name, trade_date, past_context=past_context ) args = self.propagator.get_graph_args() @@ -221,6 +292,13 @@ class TradingAgentsGraph: # Log state self._log_state(trade_date, final_state) + # Store decision for deferred reflection. 
+ self.memory_log.store_decision( + ticker=company_name, + trade_date=trade_date, + final_trade_decision=final_state["final_trade_decision"], + ) + # Return decision and processed signal return final_state, self.process_signal(final_state["final_trade_decision"]) @@ -264,24 +342,6 @@ class TradingAgentsGraph: with open(log_path, "w", encoding="utf-8") as f: json.dump(self.log_states_dict[str(trade_date)], f, indent=4) - def reflect_and_remember(self, returns_losses): - """Reflect on decisions and update memory based on returns.""" - self.reflector.reflect_bull_researcher( - self.curr_state, returns_losses, self.bull_memory - ) - self.reflector.reflect_bear_researcher( - self.curr_state, returns_losses, self.bear_memory - ) - self.reflector.reflect_trader( - self.curr_state, returns_losses, self.trader_memory - ) - self.reflector.reflect_invest_judge( - self.curr_state, returns_losses, self.invest_judge_memory - ) - self.reflector.reflect_portfolio_manager( - self.curr_state, returns_losses, self.portfolio_manager_memory - ) - def process_signal(self, full_signal): """Process a signal to extract the core decision.""" return self.signal_processor.process_signal(full_signal) From 872b063e6917929335bafc20ef1070fba9e54a69 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 07:25:32 +0000 Subject: [PATCH 14/44] fix: use explicit encoding="utf-8" for all file I/O so Windows users avoid cp1252 crashes (#543, #550, #576) --- cli/main.py | 34 ++++++++++----------- tradingagents/__init__.py | 2 -- tradingagents/dataflows/stockstats_utils.py | 4 +-- tradingagents/dataflows/utils.py | 2 +- 4 files changed, 20 insertions(+), 22 deletions(-) diff --git a/cli/main.py b/cli/main.py index 33d110fb6..6e838fc8b 100644 --- a/cli/main.py +++ b/cli/main.py @@ -463,7 +463,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non def get_user_selections(): """Get all user selections before starting the analysis display.""" # Display ASCII art 
welcome message - with open(Path(__file__).parent / "static" / "welcome.txt", "r") as f: + with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f: welcome_ascii = f.read() # Create welcome box content @@ -646,19 +646,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): analyst_parts = [] if final_state.get("market_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "market.md").write_text(final_state["market_report"]) + (analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8") analyst_parts.append(("Market Analyst", final_state["market_report"])) if final_state.get("sentiment_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"]) + (analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8") analyst_parts.append(("Social Analyst", final_state["sentiment_report"])) if final_state.get("news_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "news.md").write_text(final_state["news_report"]) + (analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8") analyst_parts.append(("News Analyst", final_state["news_report"])) if final_state.get("fundamentals_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"]) + (analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8") analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) if analyst_parts: content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts) @@ -671,15 +671,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): research_parts = [] if debate.get("bull_history"): research_dir.mkdir(exist_ok=True) - (research_dir / "bull.md").write_text(debate["bull_history"]) + (research_dir / 
"bull.md").write_text(debate["bull_history"], encoding="utf-8") research_parts.append(("Bull Researcher", debate["bull_history"])) if debate.get("bear_history"): research_dir.mkdir(exist_ok=True) - (research_dir / "bear.md").write_text(debate["bear_history"]) + (research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8") research_parts.append(("Bear Researcher", debate["bear_history"])) if debate.get("judge_decision"): research_dir.mkdir(exist_ok=True) - (research_dir / "manager.md").write_text(debate["judge_decision"]) + (research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8") research_parts.append(("Research Manager", debate["judge_decision"])) if research_parts: content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts) @@ -689,7 +689,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): if final_state.get("trader_investment_plan"): trading_dir = save_path / "3_trading" trading_dir.mkdir(exist_ok=True) - (trading_dir / "trader.md").write_text(final_state["trader_investment_plan"]) + (trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8") sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}") # 4. 
Risk Management @@ -699,15 +699,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): risk_parts = [] if risk.get("aggressive_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "aggressive.md").write_text(risk["aggressive_history"]) + (risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8") risk_parts.append(("Aggressive Analyst", risk["aggressive_history"])) if risk.get("conservative_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "conservative.md").write_text(risk["conservative_history"]) + (risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8") risk_parts.append(("Conservative Analyst", risk["conservative_history"])) if risk.get("neutral_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "neutral.md").write_text(risk["neutral_history"]) + (risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8") risk_parts.append(("Neutral Analyst", risk["neutral_history"])) if risk_parts: content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts) @@ -717,12 +717,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): if risk.get("judge_decision"): portfolio_dir = save_path / "5_portfolio" portfolio_dir.mkdir(exist_ok=True) - (portfolio_dir / "decision.md").write_text(risk["judge_decision"]) + (portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8") sections.append(f"## V. 
Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}") # Write consolidated report header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - (save_path / "complete_report.md").write_text(header + "\n\n".join(sections)) + (save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8") return save_path / "complete_report.md" @@ -980,7 +980,7 @@ def run_analysis(): func(*args, **kwargs) timestamp, message_type, content = obj.messages[-1] content = content.replace("\n", " ") # Replace newlines with spaces - with open(log_file, "a") as f: + with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [{message_type}] {content}\n") return wrapper @@ -991,7 +991,7 @@ def run_analysis(): func(*args, **kwargs) timestamp, tool_name, args = obj.tool_calls[-1] args_str = ", ".join(f"{k}={v}" for k, v in args.items()) - with open(log_file, "a") as f: + with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") return wrapper @@ -1005,7 +1005,7 @@ def run_analysis(): if content: file_name = f"{section_name}.md" text = "\n".join(str(item) for item in content) if isinstance(content, list) else content - with open(report_dir / file_name, "w") as f: + with open(report_dir / file_name, "w", encoding="utf-8") as f: f.write(text) return wrapper diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py index 43a2b4398..e69de29bb 100644 --- a/tradingagents/__init__.py +++ b/tradingagents/__init__.py @@ -1,2 +0,0 @@ -import os -os.environ.setdefault("PYTHONUTF8", "1") diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index 507478830..cb24c5d6a 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -67,7 +67,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame: ) if 
os.path.exists(data_file): - data = pd.read_csv(data_file, on_bad_lines="skip") + data = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8") else: data = yf_retry(lambda: yf.download( symbol, @@ -78,7 +78,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame: auto_adjust=True, )) data = data.reset_index() - data.to_csv(data_file, index=False) + data.to_csv(data_file, index=False, encoding="utf-8") data = _clean_dataframe(data) diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index 4523de19f..c99b777ab 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -8,7 +8,7 @@ SavePathType = Annotated[str, "File path to save data. If None, data is not save def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: if save_path: - data.to_csv(save_path) + data.to_csv(save_path, encoding="utf-8") print(f"{tag} saved to {save_path}") From 8e7654f0df735e89436de3adb1eab53ba33afeb8 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 07:41:36 +0000 Subject: [PATCH 15/44] fix: drop past-memory directive and placeholder from agent prompts when memory is empty (#572) --- .../agents/managers/portfolio_manager.py | 16 +++++++++++++--- .../agents/managers/research_manager.py | 13 +++++++------ .../agents/researchers/bear_researcher.py | 14 ++++++++++++-- .../agents/researchers/bull_researcher.py | 14 ++++++++++++-- tradingagents/agents/trader/trader.py | 10 +++++++--- 5 files changed, 51 insertions(+), 16 deletions(-) diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 6c69ae9fd..5d4631b82 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -22,6 +22,17 @@ def create_portfolio_manager(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" + lessons_line = ( + f"- Lessons 
from past decisions: **{past_memory_str.strip()}**\n" + if past_memories + else "" + ) + thesis_instruction = ( + "3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections." + if past_memories + else "3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate." + ) + prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision. {instrument_context} @@ -38,12 +49,11 @@ def create_portfolio_manager(llm, memory): **Context:** - Research Manager's investment plan: **{research_plan}** - Trader's transaction proposal: **{trader_plan}** -- Lessons from past decisions: **{past_memory_str}** - +{lessons_line} **Required Output Structure:** 1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell. 2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon. -3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections. +{thesis_instruction} --- diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 5b4b4fdc5..03e32e492 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -20,6 +20,12 @@ def create_research_manager(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" + past_memory_block = ( + f'Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. 
\n\nHere are your past reflections on mistakes:\n"{past_memory_str.strip()}"\n\n' + if past_memories + else "" + ) + prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. @@ -29,12 +35,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc Your Recommendation: A decisive stance supported by the most convincing arguments. Rationale: An explanation of why these arguments lead to your conclusion. Strategic Actions: Concrete steps for implementing the recommendation. -Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. 
- -Here are your past reflections on mistakes: -\"{past_memory_str}\" - -{instrument_context} +{past_memory_block}{instrument_context} Here is the debate: Debate History: diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index a44212dc4..e3922e8c3 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -19,6 +19,17 @@ def create_bear_researcher(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" + memory_section = ( + f"Reflections from similar situations and lessons learned: {past_memory_str.strip()}\n" + if past_memories + else "" + ) + memory_instruction = ( + " You must also address reflections and learn from lessons and mistakes you made in the past." + if past_memories + else "" + ) + prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. Key points to focus on: @@ -37,8 +48,7 @@ Latest world affairs news: {news_report} Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bull argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. 
+{memory_section}Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.{memory_instruction} """ response = llm.invoke(prompt) diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index d23d4d76e..9724415d6 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -19,6 +19,17 @@ def create_bull_researcher(llm, memory): for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" + memory_section = ( + f"Reflections from similar situations and lessons learned: {past_memory_str.strip()}\n" + if past_memories + else "" + ) + memory_instruction = ( + " You must also address reflections and learn from lessons and mistakes you made in the past." + if past_memories + else "" + ) + prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. Key points to focus on: @@ -35,8 +46,7 @@ Latest world affairs news: {news_report} Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bear argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. 
+{memory_section}Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.{memory_instruction} """ response = llm.invoke(prompt) diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 07e9f262c..964733467 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -20,8 +20,12 @@ def create_trader(llm, memory): if past_memories: for i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" - else: - past_memory_str = "No past memories found." + + memory_instruction = ( + f" Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str.strip()}" + if past_memories + else "" + ) context = { "role": "user", @@ -31,7 +35,7 @@ def create_trader(llm, memory): messages = [ { "role": "system", - "content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""", + "content": f"You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. 
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation.{memory_instruction}", }, context, ] From f85f5d9f5d9243549fe7860be8611698051fda29 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 07:41:36 +0000 Subject: [PATCH 16/44] test: lazy-load LLM provider clients and add API-key fixtures so the test suite runs cleanly without credentials (#588) --- pyproject.toml | 12 ++++++++ tests/conftest.py | 42 ++++++++++++++++++++++++++++ tests/test_google_api_key.py | 3 ++ tests/test_model_validation.py | 3 ++ tests/test_ticker_symbol_handling.py | 3 ++ tradingagents/llm_clients/factory.py | 12 +++++--- 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 tests/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 98385e32e..110611630 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,3 +40,15 @@ include = ["tradingagents*", "cli*"] [tool.setuptools.package-data] cli = ["static/*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-ra --strict-markers" +markers = [ + "unit: fast isolated unit tests", + "integration: tests requiring external services", + "smoke: quick sanity-check tests", +] +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..504ffb12d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,42 @@ +"""Shared pytest fixtures that prevent CI hangs when API keys are absent.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + + +def pytest_configure(config): + for marker in ("unit", "integration", "smoke"): + config.addinivalue_line("markers", f"{marker}: {marker}-level tests") + + +_API_KEY_ENV_VARS = ( + "OPENAI_API_KEY", + "GOOGLE_API_KEY", + "ANTHROPIC_API_KEY", + "XAI_API_KEY", + "DEEPSEEK_API_KEY", + "DASHSCOPE_API_KEY", + "ZHIPU_API_KEY", + "OPENROUTER_API_KEY", + "AZURE_OPENAI_API_KEY", + 
"ALPHA_VANTAGE_API_KEY", +) + + +@pytest.fixture(autouse=True) +def _dummy_api_keys(monkeypatch): + for env_var in _API_KEY_ENV_VARS: + monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder")) + + +@pytest.fixture() +def mock_llm_client(): + client = MagicMock() + client.get_llm.return_value = MagicMock() + with patch( + "tradingagents.llm_clients.factory.create_llm_client", + return_value=client, + ): + yield client diff --git a/tests/test_google_api_key.py b/tests/test_google_api_key.py index e1607c49a..53376ab10 100644 --- a/tests/test_google_api_key.py +++ b/tests/test_google_api_key.py @@ -1,9 +1,12 @@ import unittest from unittest.mock import patch +import pytest + from tradingagents.llm_clients.google_client import GoogleClient +@pytest.mark.unit class TestGoogleApiKeyStandardization(unittest.TestCase): """Verify GoogleClient accepts unified api_key parameter.""" diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py index 50f263182..5392d7cd9 100644 --- a/tests/test_model_validation.py +++ b/tests/test_model_validation.py @@ -1,6 +1,8 @@ import unittest import warnings +import pytest + from tradingagents.llm_clients.base_client import BaseLLMClient from tradingagents.llm_clients.model_catalog import get_known_models from tradingagents.llm_clients.validators import validate_model @@ -19,6 +21,7 @@ class DummyLLMClient(BaseLLMClient): return validate_model(self.provider, self.model) +@pytest.mark.unit class ModelValidationTests(unittest.TestCase): def test_cli_catalog_models_are_all_validator_approved(self): for provider, models in get_known_models().items(): diff --git a/tests/test_ticker_symbol_handling.py b/tests/test_ticker_symbol_handling.py index 858d26cd5..7fbe5315d 100644 --- a/tests/test_ticker_symbol_handling.py +++ b/tests/test_ticker_symbol_handling.py @@ -1,9 +1,12 @@ import unittest +import pytest + from cli.utils import normalize_ticker_symbol from tradingagents.agents.utils.agent_utils import 
build_instrument_context +@pytest.mark.unit class TickerSymbolHandlingTests(unittest.TestCase): def test_normalize_ticker_symbol_preserves_exchange_suffix(self): self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO") diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index a9a7e83d8..e1d24557e 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -1,10 +1,6 @@ from typing import Optional from .base_client import BaseLLMClient -from .openai_client import OpenAIClient -from .anthropic_client import AnthropicClient -from .google_client import GoogleClient -from .azure_client import AzureOpenAIClient # Providers that use the OpenAI-compatible chat completions API _OPENAI_COMPATIBLE = ( @@ -20,6 +16,10 @@ def create_llm_client( ) -> BaseLLMClient: """Create an LLM client for the specified provider. + Provider modules are imported lazily so that simply importing this + factory (e.g. during test collection) does not pull in heavy LLM SDKs + or fail when their API keys are absent. 
+ Args: provider: LLM provider name model: Model name/identifier @@ -35,15 +35,19 @@ def create_llm_client( provider_lower = provider.lower() if provider_lower in _OPENAI_COMPATIBLE: + from .openai_client import OpenAIClient return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) if provider_lower == "anthropic": + from .anthropic_client import AnthropicClient return AnthropicClient(model, base_url, **kwargs) if provider_lower == "google": + from .google_client import GoogleClient return GoogleClient(model, base_url, **kwargs) if provider_lower == "azure": + from .azure_client import AzureOpenAIClient return AzureOpenAIClient(model, base_url, **kwargs) raise ValueError(f"Unsupported LLM provider: {provider}") From 4cbd4b086fd94074d79987c8ca31daec8d33902c Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 08:39:27 +0000 Subject: [PATCH 17/44] feat: add LangGraph checkpoint resume for crash recovery (#594) Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node. How to use: - CLI: tradingagents analyze --checkpoint - CLI: tradingagents analyze --clear-checkpoints - Python: config["checkpoint_enabled"] = True Lifecycle: - propagate() recompiles the graph with a SqliteSaver when enabled and injects a deterministic thread_id derived from ticker+date so the same ticker+date resumes while a different date starts fresh. - On successful completion the per-thread checkpoint rows are cleared. - The context manager is closed in a try/finally so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode. Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db (override via TRADINGAGENTS_CACHE_DIR). 
The checkpointer module is new (tradingagents/graph/checkpointer.py) and the GraphSetup now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify the crash/resume cycle and that a different date starts fresh. --- README.md | 32 +++++- cli/main.py | 22 +++- pyproject.toml | 1 + tests/test_checkpoint_resume.py | 147 +++++++++++++++++++++++++++ tests/test_memory_log.py | 9 +- tradingagents/default_config.py | 3 + tradingagents/graph/checkpointer.py | 86 ++++++++++++++++ tradingagents/graph/setup.py | 3 +- tradingagents/graph/trading_graph.py | 67 +++++++++--- 9 files changed, 349 insertions(+), 21 deletions(-) create mode 100644 tests/test_checkpoint_resume.py create mode 100644 tradingagents/graph/checkpointer.py diff --git a/README.md b/README.md index 97cbde486..6c8f644ec 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama. +We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. 
### Python Usage @@ -207,7 +207,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama +config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks config["max_debate_rounds"] = 2 @@ -219,6 +219,34 @@ print(decision) See `tradingagents/default_config.py` for all configuration options. +## Persistence and Recovery + +TradingAgents persists two kinds of state across runs. + +### Decision log + +The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't. + +Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`. + +### Checkpoint resume + +Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion. + +Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run. 
+ +```bash +tradingagents analyze --checkpoint # enable for this run +tradingagents analyze --clear-checkpoints # reset before running +``` + +```python +config = DEFAULT_CONFIG.copy() +config["checkpoint_enabled"] = True +ta = TradingAgentsGraph(config=config) +_, decision = ta.propagate("NVDA", "2026-01-15") +``` + ## Contributing We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). diff --git a/cli/main.py b/cli/main.py index 6e838fc8b..534f50379 100644 --- a/cli/main.py +++ b/cli/main.py @@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str: return result[:max_length - 3] + "..." return result -def run_analysis(): +def run_analysis(checkpoint: bool = False): # First get all user selections selections = get_user_selections() @@ -943,6 +943,7 @@ def run_analysis(): config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["anthropic_effort"] = selections.get("anthropic_effort") config["output_language"] = selections.get("output_language", "English") + config["checkpoint_enabled"] = checkpoint # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() @@ -1197,8 +1198,23 @@ def run_analysis(): @app.command() -def analyze(): - run_analysis() +def analyze( + checkpoint: bool = typer.Option( + False, + "--checkpoint", + help="Enable checkpoint/resume: save state after each node so a crashed run can resume.", + ), + clear_checkpoints: bool = typer.Option( + False, + "--clear-checkpoints", + help="Delete all saved checkpoints before running (force fresh start).", + ), +): + if clear_checkpoints: + from tradingagents.graph.checkpointer import clear_all_checkpoints + n = 
clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"]) + console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]") + run_analysis(checkpoint=checkpoint) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index b3dbc6feb..b569504ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "langchain-google-genai>=4.0.0", "langchain-openai>=0.3.23", "langgraph>=0.4.8", + "langgraph-checkpoint-sqlite>=2.0.0", "pandas>=2.3.0", "parsel>=1.10.0", "pytz>=2025.2", diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py new file mode 100644 index 000000000..6f2692bd8 --- /dev/null +++ b/tests/test_checkpoint_resume.py @@ -0,0 +1,147 @@ +"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" + +import sqlite3 +import tempfile +import unittest +from pathlib import Path +from typing import TypedDict + +from langgraph.checkpoint.sqlite import SqliteSaver +from langgraph.graph import END, StateGraph + +from tradingagents.graph.checkpointer import ( + checkpoint_step, + clear_checkpoint, + get_checkpointer, + has_checkpoint, + thread_id, +) + +# Mutable flag to simulate crash on first run +_should_crash = False + + +class _SimpleState(TypedDict): + count: int + + +def _node_a(state: _SimpleState) -> dict: + return {"count": state["count"] + 1} + + +def _node_b(state: _SimpleState) -> dict: + if _should_crash: + raise RuntimeError("simulated mid-analysis crash") + return {"count": state["count"] + 10} + + +def _build_graph() -> StateGraph: + builder = StateGraph(_SimpleState) + builder.add_node("analyst", _node_a) + builder.add_node("trader", _node_b) + builder.set_entry_point("analyst") + builder.add_edge("analyst", "trader") + builder.add_edge("trader", END) + return builder + + +class TestCheckpointResume(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.ticker = "TEST" + self.date = "2026-04-20" + + def test_crash_and_resume(self): + """Crash at 
'trader' node, then resume from checkpoint.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Run 1: crash at trader node + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + # Checkpoint should exist at step 1 (analyst completed) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + step = checkpoint_step(self.tmpdir, self.ticker, self.date) + self.assertEqual(step, 1) + + # Run 2: resume — trader succeeds this time + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke(None, config=cfg) + + # analyst added 1, trader added 10 → 11 + self.assertEqual(result["count"], 11) + + def test_clear_checkpoint_allows_fresh_start(self): + """After clearing, the graph starts from scratch.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Create a checkpoint by crashing + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Clear it + clear_checkpoint(self.tmpdir, self.ticker, self.date) + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Fresh run succeeds from scratch + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config=cfg) + + self.assertEqual(result["count"], 11) + + + def test_different_date_starts_fresh(self): + """A different date 
must NOT resume from an existing checkpoint.""" + global _should_crash + builder = _build_graph() + date2 = "2026-04-21" + + # Run with date1 — crash to leave a checkpoint + _should_crash = True + tid1 = thread_id(self.ticker, self.date) + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}}) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # date2 should have no checkpoint + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2)) + + # Run with date2 — should start fresh and succeed + _should_crash = False + tid2 = thread_id(self.ticker, date2) + self.assertNotEqual(tid1, tid2) + + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}}) + + # Fresh run: analyst +1, trader +10 = 11 + self.assertEqual(result["count"], 11) + + # Original date checkpoint still exists (untouched) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py index 915bc3b15..ccd1ca7e7 100644 --- a/tests/test_memory_log.py +++ b/tests/test_memory_log.py @@ -629,7 +629,9 @@ class TestLegacyRemoval: create_portfolio_manager(mock_llm, memory=MagicMock()) def test_full_pipeline_no_regression(self, tmp_path): - """propagate() completes without AttributeError after legacy cleanup.""" + """propagate() completes and stores the decision after the redesign.""" + import functools + fake_state = { "final_trade_decision": "Rating: Buy\nBuy NVDA.", "company_of_interest": "NVDA", @@ -660,6 +662,11 @@ class TestLegacyRemoval: mock_graph.propagator.create_initial_state.return_value = fake_state mock_graph.propagator.get_graph_args.return_value = {} 
mock_graph.signal_processor.process_signal.return_value = "Buy" + # Bind the real _run_graph so propagate's call to self._run_graph executes + # the actual write path instead of the auto-MagicMock. + mock_graph._run_graph = functools.partial( + TradingAgentsGraph._run_graph, mock_graph + ) TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10") entries = mock_graph.memory_log.load_entries() assert len(entries) == 1 diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 19dbe1c7a..89b517659 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -16,6 +16,9 @@ DEFAULT_CONFIG = { "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" "anthropic_effort": None, # "high", "medium", "low" + # Checkpoint/resume: when True, LangGraph saves state after each node + # so a crashed run can resume from the last successful step. + "checkpoint_enabled": False, # Output language for analyst reports and final decision # Internal agent debate stays in English for reasoning quality "output_language": "English", diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py new file mode 100644 index 000000000..7a73ee446 --- /dev/null +++ b/tradingagents/graph/checkpointer.py @@ -0,0 +1,86 @@ +"""LangGraph checkpoint support for resumable analysis runs. + +Per-ticker SQLite databases so concurrent tickers don't contend. 
+""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Generator + +from langgraph.checkpoint.sqlite import SqliteSaver + + +def _db_path(data_dir: str | Path, ticker: str) -> Path: + """Return the SQLite checkpoint DB path for a ticker.""" + p = Path(data_dir) / "checkpoints" + p.mkdir(parents=True, exist_ok=True) + return p / f"{ticker.upper()}.db" + + +def thread_id(ticker: str, date: str) -> str: + """Deterministic thread ID for a ticker+date pair.""" + return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16] + + +@contextmanager +def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]: + """Context manager yielding a SqliteSaver backed by a per-ticker DB.""" + db = _db_path(data_dir, ticker) + conn = sqlite3.connect(str(db), check_same_thread=False) + try: + saver = SqliteSaver(conn) + saver.setup() + yield saver + finally: + conn.close() + + +def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: + """Check whether a resumable checkpoint exists for ticker+date.""" + return checkpoint_step(data_dir, ticker, date) is not None + + +def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None: + """Return the step number of the latest checkpoint, or None if none exists.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return None + tid = thread_id(ticker, date) + with get_checkpointer(data_dir, ticker) as saver: + config = {"configurable": {"thread_id": tid}} + cp = saver.get_tuple(config) + if cp is None: + return None + return cp.metadata.get("step") + + +def clear_all_checkpoints(data_dir: str | Path) -> int: + """Remove all checkpoint DBs. 
Returns number of files deleted.""" + cp_dir = Path(data_dir) / "checkpoints" + if not cp_dir.exists(): + return 0 + dbs = list(cp_dir.glob("*.db")) + for db in dbs: + db.unlink() + return len(dbs) + + +def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None: + """Remove checkpoint for a specific ticker+date by deleting the thread's rows.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return + tid = thread_id(ticker, date) + conn = sqlite3.connect(str(db)) + try: + for table in ("writes", "checkpoints"): + conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,)) + conn.commit() + except sqlite3.OperationalError: + pass + finally: + conn.close() diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 1686fc5b0..45d6bfd38 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -179,5 +179,4 @@ class GraphSetup: workflow.add_edge("Portfolio Manager", END) - # Compile and return - return workflow.compile() + return workflow diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 4f92b1889..bd6f1fc5c 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) +from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -123,8 +124,10 @@ class TradingAgentsGraph: self.ticker = None self.log_states_dict = {} # date to full state dict - # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) + # Set up the graph: keep the workflow for recompilation with a checkpointer. 
+ self.workflow = self.graph_setup.setup_graph(selected_analysts) + self.graph = self.workflow.compile() + self._checkpointer_ctx = None def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" @@ -259,23 +262,58 @@ class TradingAgentsGraph: self.memory_log.batch_update_with_outcomes(updates) def propagate(self, company_name, trade_date): - """Run the trading agents graph for a company on a specific date.""" + """Run the trading agents graph for a company on a specific date. + When ``checkpoint_enabled`` is set in config, the graph is recompiled + with a per-ticker SqliteSaver so a crashed run can resume from the last + successful node on a subsequent invocation with the same ticker+date. + """ self.ticker = company_name - # Resolve any pending log entries for this ticker before the pipeline runs. - # This adds the outcome + reflection from the previous run at zero latency cost. + # Resolve any pending memory-log entries for this ticker before the pipeline runs. self._resolve_pending_entries(company_name) - # Initialize state — inject memory log context for PM + # Recompile with a checkpointer if the user opted in. 
+ if self.config.get("checkpoint_enabled"): + self._checkpointer_ctx = get_checkpointer( + self.config["data_cache_dir"], company_name + ) + saver = self._checkpointer_ctx.__enter__() + self.graph = self.workflow.compile(checkpointer=saver) + + step = checkpoint_step( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + if step is not None: + logger.info( + "Resuming from step %d for %s on %s", step, company_name, trade_date + ) + else: + logger.info("Starting fresh for %s on %s", company_name, trade_date) + + try: + return self._run_graph(company_name, trade_date) + finally: + if self._checkpointer_ctx is not None: + self._checkpointer_ctx.__exit__(None, None, None) + self._checkpointer_ctx = None + self.graph = self.workflow.compile() + + def _run_graph(self, company_name, trade_date): + """Execute the graph and write the resulting state to disk and memory log.""" + # Initialize state — inject memory log context for PM. past_context = self.memory_log.get_past_context(company_name) init_agent_state = self.propagator.create_initial_state( company_name, trade_date, past_context=past_context ) args = self.propagator.get_graph_args() + # Inject thread_id so same ticker+date resumes, different date starts fresh. + if self.config.get("checkpoint_enabled"): + tid = thread_id(company_name, str(trade_date)) + args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid + if self.debug: - # Debug mode with tracing trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: @@ -283,26 +321,29 @@ class TradingAgentsGraph: else: chunk["messages"][-1].pretty_print() trace.append(chunk) - final_state = trace[-1] else: - # Standard mode without tracing final_state = self.graph.invoke(init_agent_state, **args) - # Store current state for reflection + # Store current state for reflection. self.curr_state = final_state - # Log state + # Log state to disk. 
self._log_state(trade_date, final_state) - # Store decision for deferred reflection. + # Store decision for deferred reflection on the next same-ticker run. self.memory_log.store_decision( ticker=company_name, trade_date=trade_date, final_trade_decision=final_state["final_trade_decision"], ) - # Return decision and processed signal + # Clear checkpoint on successful completion to avoid stale state. + if self.config.get("checkpoint_enabled"): + clear_checkpoint( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + return final_state, self.process_signal(final_state["final_trade_decision"]) def _log_state(self, trade_date, final_state): From 0fda24515f56be44bdf568a19d85b20f69ee60b6 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 19:57:26 +0000 Subject: [PATCH 18/44] feat: structured-output Portfolio Manager + 5-tier rating consistency (#434) Three related changes that take the rating pipeline from heuristic-only to type-safe at the source. 1) Research Manager prompt now uses the same 5-tier scale (Buy / Overweight / Hold / Underweight / Sell) as the Portfolio Manager, signal_processing, and the memory log. The prior 3-tier wording (Buy / Sell / Hold) was the only remaining inconsistency in the pipeline. 2) Centralise the 5-tier vocabulary and the heuristic prose-rating parser into tradingagents/agents/utils/rating.py. Both the memory log and the signal processor now share the same parser instead of duplicating regex and word-walker logic. 3) Make structured output a first-class part of the Portfolio Manager's primary call. The PM uses llm.with_structured_output(PortfolioDecision) so each provider's native structured-output mode (json_schema for OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic, function_calling for OpenAI-compatible providers) yields a typed Pydantic instance directly. 
A render helper turns that instance back into the same markdown shape downstream consumers (memory log, CLI display, saved reports) already expect, so no other code has to know the PM now produces structured output. Providers without structured support fall back gracefully to free-text + the deterministic heuristic. The previous SignalProcessor had been making a second LLM call to re-extract the rating from the PM's prose; that round-trip is now eliminated. SignalProcessor is a thin adapter over parse_rating(), makes zero LLM calls, and stays for backwards compatibility with process_signal() callers. Schema (PortfolioDecision) captures rating + executive_summary + investment_thesis + optional price_target + time_horizon, with field descriptions doubling as output instructions. Agent prose remains the primary artifact; structured output is layered onto the PM only because it is the one agent whose output has machine-readable downstream consumers. 15 new tests cover the heuristic parser (markdown-bold edge cases that had no coverage before), the structured PM happy path, the free-text fallback path, and that SignalProcessor never invokes the LLM. Full suite: 77 tests pass in ~2s without API keys. 
--- tests/test_memory_log.py | 70 +++++++++++--- tests/test_signal_processing.py | 90 ++++++++++++++++++ .../agents/managers/portfolio_manager.py | 83 ++++++++++++----- .../agents/managers/research_manager.py | 34 ++++--- tradingagents/agents/schemas.py | 93 +++++++++++++++++++ tradingagents/agents/utils/memory.py | 20 +--- tradingagents/agents/utils/rating.py | 50 ++++++++++ tradingagents/graph/signal_processing.py | 46 +++++---- 8 files changed, 399 insertions(+), 87 deletions(-) create mode 100644 tests/test_signal_processing.py create mode 100644 tradingagents/agents/schemas.py create mode 100644 tradingagents/agents/utils/rating.py diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py index ccd1ca7e7..e0da15efc 100644 --- a/tests/test_memory_log.py +++ b/tests/test_memory_log.py @@ -5,6 +5,7 @@ import pandas as pd from unittest.mock import MagicMock, patch from tradingagents.agents.utils.memory import TradingMemoryLog +from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating from tradingagents.graph.reflection import Reflector from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.propagation import Propagator @@ -82,6 +83,25 @@ def _make_pm_state(past_context=""): } +def _structured_pm_llm(captured: dict, decision: PortfolioDecision | None = None): + """Build a MagicMock LLM whose with_structured_output binding captures the + prompt and returns a real PortfolioDecision (so render_pm_decision works). 
+ """ + if decision is None: + decision = PortfolioDecision( + rating=PortfolioRating.HOLD, + executive_summary="Hold the position; await catalyst.", + investment_thesis="Balanced view; neither side carried the debate.", + ) + structured = MagicMock() + structured.invoke.side_effect = lambda prompt: ( + captured.__setitem__("prompt", prompt) or decision + ) + llm = MagicMock() + llm.with_structured_output.return_value = structured + return llm + + # --------------------------------------------------------------------------- # Core: storage and read path # --------------------------------------------------------------------------- @@ -518,29 +538,55 @@ class TestPortfolioManagerInjection: def test_pm_prompt_includes_past_context(self): captured = {} - mock_llm = MagicMock() - mock_llm.invoke.side_effect = lambda prompt: ( - captured.__setitem__("prompt", prompt) or MagicMock(content="Rating: Hold\nHold.") - ) - pm_node = create_portfolio_manager(mock_llm) + llm = _structured_pm_llm(captured) + pm_node = create_portfolio_manager(llm) state = _make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call.") pm_node(state) assert "Lessons from prior decisions and outcomes" in captured["prompt"] assert "Great call." 
in captured["prompt"] - assert "and the lessons from prior decisions" in captured["prompt"] def test_pm_no_past_context_no_section(self): """PM prompt omits the lessons section entirely when past_context is empty.""" captured = {} - mock_llm = MagicMock() - mock_llm.invoke.side_effect = lambda prompt: ( - captured.__setitem__("prompt", prompt) or MagicMock(content="Rating: Hold\nHold.") - ) - pm_node = create_portfolio_manager(mock_llm) + llm = _structured_pm_llm(captured) + pm_node = create_portfolio_manager(llm) state = _make_pm_state(past_context="") pm_node(state) assert "Lessons from prior decisions" not in captured["prompt"] - assert "and the lessons from prior decisions" not in captured["prompt"] + + def test_pm_returns_rendered_markdown_with_rating(self): + """The structured PortfolioDecision is rendered to markdown that + downstream consumers (memory log, signal processor, CLI display) + can parse without any extra LLM call.""" + captured = {} + decision = PortfolioDecision( + rating=PortfolioRating.OVERWEIGHT, + executive_summary="Build position gradually over the next two weeks.", + investment_thesis="AI capex cycle remains intact; institutional flows constructive.", + price_target=215.0, + time_horizon="3-6 months", + ) + llm = _structured_pm_llm(captured, decision) + pm_node = create_portfolio_manager(llm) + result = pm_node(_make_pm_state()) + md = result["final_trade_decision"] + assert "**Rating**: Overweight" in md + assert "**Executive Summary**: Build position gradually" in md + assert "**Investment Thesis**: AI capex cycle" in md + assert "**Price Target**: 215.0" in md + assert "**Time Horizon**: 3-6 months" in md + + def test_pm_falls_back_to_freetext_when_structured_unavailable(self): + """If a provider does not support with_structured_output, the agent + falls back to a plain invoke and returns whatever prose the model + produced, so the pipeline never blocks.""" + plain_response = "**Rating**: Sell\n\nExit ahead of guidance." 
+ llm = MagicMock() + llm.with_structured_output.side_effect = NotImplementedError("provider unsupported") + llm.invoke.return_value = MagicMock(content=plain_response) + pm_node = create_portfolio_manager(llm) + result = pm_node(_make_pm_state()) + assert result["final_trade_decision"] == plain_response # get_past_context ordering and limits diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py new file mode 100644 index 000000000..4bbfb7475 --- /dev/null +++ b/tests/test_signal_processing.py @@ -0,0 +1,90 @@ +"""Tests for the shared rating heuristic and the SignalProcessor adapter. + +The Portfolio Manager produces a typed PortfolioDecision via structured +output and renders it to markdown that always contains a ``**Rating**: X`` +header. The deterministic heuristic in ``tradingagents.agents.utils.rating`` +is therefore sufficient to extract the rating downstream — no second LLM +call is needed — and SignalProcessor is now a thin adapter that delegates +to it. +""" + +import pytest + +from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating +from tradingagents.graph.signal_processing import SignalProcessor + + +# --------------------------------------------------------------------------- +# Heuristic parser +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestParseRating: + def test_explicit_label_buy(self): + assert parse_rating("Rating: Buy\nReasoning here.") == "Buy" + + def test_explicit_label_overweight(self): + assert parse_rating("Rating: Overweight\nDetails.") == "Overweight" + + def test_explicit_label_with_markdown_bold_value(self): + # Regression: Rating: **Sell** — markdown around the value. 
+ assert parse_rating("Rating: **Sell**\nExit immediately.") == "Sell" + + def test_explicit_label_with_markdown_bold_label(self): + assert parse_rating("**Rating**: Underweight\nTrim exposure.") == "Underweight" + + def test_rendered_pm_markdown_shape(self): + # The exact shape produced by render_pm_decision must always parse. + text = ( + "**Rating**: Buy\n\n" + "**Executive Summary**: Enter at $189-192, 6% portfolio cap.\n\n" + "**Investment Thesis**: AI capex cycle intact; institutional flows constructive." + ) + assert parse_rating(text) == "Buy" + + def test_explicit_label_wins_over_prose_with_markdown(self): + text = ( + "The buy thesis is weakened by guidance.\n" + "Rating: **Sell**\n" + "Exit before earnings." + ) + assert parse_rating(text) == "Sell" + + def test_no_rating_returns_default(self): + assert parse_rating("No clear directional signal at this time.") == "Hold" + + def test_no_rating_custom_default(self): + assert parse_rating("Plain prose.", default="Underweight") == "Underweight" + + def test_all_five_tiers_recognised(self): + for r in RATINGS_5_TIER: + assert parse_rating(f"Rating: {r}") == r + + +# --------------------------------------------------------------------------- +# SignalProcessor: thin adapter over the heuristic +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSignalProcessor: + def test_returns_rating_from_pm_markdown(self): + sp = SignalProcessor() + md = "**Rating**: Overweight\n\n**Executive Summary**: Build gradually." 
+ assert sp.process_signal(md) == "Overweight" + + def test_makes_no_llm_calls(self): + """SignalProcessor must not invoke the LLM it was constructed with — + the rating is parseable from the rendered PM markdown directly.""" + from unittest.mock import MagicMock + + llm = MagicMock() + sp = SignalProcessor(llm) + sp.process_signal("Rating: Buy\nDetails.") + llm.invoke.assert_not_called() + llm.with_structured_output.assert_not_called() + + def test_default_when_no_rating_present(self): + sp = SignalProcessor() + assert sp.process_signal("Plain prose without a recommendation.") == "Hold" diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 6780c7dc1..38ab840af 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -1,29 +1,53 @@ -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction +"""Portfolio Manager: synthesises the risk-analyst debate into the final decision. + +Uses LangChain's ``with_structured_output`` so the LLM produces a typed +``PortfolioDecision`` directly, in a single call. The result is rendered +back to markdown for storage in ``final_trade_decision`` so memory log, +CLI display, and saved reports continue to consume the same shape they do +today. When a provider does not expose structured output, the agent falls +back to a free-text invocation and the existing heuristic rating parser. 
+""" + +from __future__ import annotations + +import logging + +from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_language_instruction, +) + +logger = logging.getLogger(__name__) def create_portfolio_manager(llm): - def portfolio_manager_node(state) -> dict: + # Wrap once at agent construction; if the provider does not support + # structured output we keep ``structured_llm`` as None and use the + # free-text fallback for every call. + try: + structured_llm = llm.with_structured_output(PortfolioDecision) + except (NotImplementedError, AttributeError) as exc: + logger.warning( + "Portfolio Manager: provider does not support with_structured_output (%s); " + "falling back to free-text generation", + exc, + ) + structured_llm = None + def portfolio_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) history = state["risk_debate_state"]["history"] risk_debate_state = state["risk_debate_state"] - market_research_report = state["market_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - sentiment_report = state["sentiment_report"] research_plan = state["investment_plan"] trader_plan = state["trader_investment_plan"] past_context = state.get("past_context", "") lessons_line = ( f"- Lessons from prior decisions and outcomes:\n{past_context}\n" - if past_context else "" - ) - thesis_instruction = ( - "3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and the lessons from prior decisions." if past_context - else "3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate." + else "" ) prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision. 
@@ -43,14 +67,6 @@ def create_portfolio_manager(llm): - Research Manager's investment plan: **{research_plan}** - Trader's transaction proposal: **{trader_plan}** {lessons_line} - -**Required Output Structure:** -1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell. -2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon. -{thesis_instruction} - ---- - **Risk Analysts Debate History:** {history} @@ -58,10 +74,10 @@ def create_portfolio_manager(llm): Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}""" - response = llm.invoke(prompt) + final_trade_decision = _invoke_pm(structured_llm, llm, prompt) new_risk_debate_state = { - "judge_decision": response.content, + "judge_decision": final_trade_decision, "history": risk_debate_state["history"], "aggressive_history": risk_debate_state["aggressive_history"], "conservative_history": risk_debate_state["conservative_history"], @@ -75,7 +91,30 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{ return { "risk_debate_state": new_risk_debate_state, - "final_trade_decision": response.content, + "final_trade_decision": final_trade_decision, } return portfolio_manager_node + + +def _invoke_pm(structured_llm, plain_llm, prompt: str) -> str: + """Run the PM call and return the markdown-rendered decision. + + Tries the structured-output path first; if it fails for any reason + (provider does not support it, model returns malformed JSON, network + glitch on the structured endpoint), falls back to the plain free-text + invocation so the pipeline still produces a result. 
+ """ + if structured_llm is not None: + try: + decision = structured_llm.invoke(prompt) + return render_pm_decision(decision) + except Exception as exc: + logger.warning( + "Portfolio Manager: structured-output invocation failed (%s); " + "retrying once as free text", + exc, + ) + + response = plain_llm.invoke(prompt) + return response.content diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 3902a60c4..020b719e4 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -9,21 +9,31 @@ def create_research_manager(llm): investment_debate_state = state["investment_debate_state"] - prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. - -Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. - -Additionally, develop a detailed investment plan for the trader. This should include: - -Your Recommendation: A decisive stance supported by the most convincing arguments. -Rationale: An explanation of why these arguments lead to your conclusion. -Strategic Actions: Concrete steps for implementing the recommendation. -Present your analysis conversationally, as if speaking naturally, without special formatting. + prompt = f"""As the Research Manager and debate facilitator, your role is to critically evaluate this round of debate and deliver a clear, actionable investment plan for the trader. 
{instrument_context} -Here is the debate: -Debate History: +--- + +**Rating Scale** (use exactly one): +- **Buy**: Strong conviction in the bull thesis; recommend taking or growing the position +- **Overweight**: Constructive view; recommend gradually increasing exposure +- **Hold**: Balanced view; recommend maintaining the current position +- **Underweight**: Cautious view; recommend trimming exposure +- **Sell**: Strong conviction in the bear thesis; recommend exiting or avoiding the position + +Commit to a clear stance whenever the debate's strongest arguments warrant one; reserve Hold for situations where the evidence on both sides is genuinely balanced. + +**Required Output Structure:** +1. **Recommendation**: State one of Buy / Overweight / Hold / Underweight / Sell. +2. **Rationale**: Summarise the key points from both sides and explain which arguments led to this recommendation. +3. **Strategic Actions**: Concrete steps for the trader to implement the recommendation, including position sizing guidance consistent with the rating. + +Present your analysis conversationally, as if speaking naturally to a teammate. + +--- + +**Debate History:** {history}""" response = llm.invoke(prompt) diff --git a/tradingagents/agents/schemas.py b/tradingagents/agents/schemas.py new file mode 100644 index 000000000..aadcaf0b6 --- /dev/null +++ b/tradingagents/agents/schemas.py @@ -0,0 +1,93 @@ +"""Pydantic schemas used by agents that produce structured output. + +The framework's primary artifact is still prose: each agent's natural-language +reasoning is what users read, what gets stored in the memory log, and what +gets saved as markdown reports. 
Structured output is layered onto agents +whose results have downstream machine-readable consumers (currently only +the Portfolio Manager) so that: + +- The rating is type-safe and never has to be regex-extracted +- Schema field descriptions become the model's output instructions +- Each provider's native structured-output mode is used (json_schema for + OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic) +- A render helper turns the parsed Pydantic instance back into the same + markdown shape the rest of the system already consumes, so display, + memory log, and saved reports keep working unchanged +""" + +from __future__ import annotations + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class PortfolioRating(str, Enum): + """5-tier portfolio rating used by the Research Manager and Portfolio Manager.""" + + BUY = "Buy" + OVERWEIGHT = "Overweight" + HOLD = "Hold" + UNDERWEIGHT = "Underweight" + SELL = "Sell" + + +class PortfolioDecision(BaseModel): + """Structured output produced by the Portfolio Manager. + + The model fills every field as part of its primary LLM call; no separate + extraction pass is required. Field descriptions double as the model's + output instructions, so the prompt body only needs to convey context and + the rating-scale guidance. + """ + + rating: PortfolioRating = Field( + description=( + "The final position rating. Exactly one of Buy / Overweight / Hold / " + "Underweight / Sell, picked based on the analysts' debate." + ), + ) + executive_summary: str = Field( + description=( + "A concise action plan covering entry strategy, position sizing, " + "key risk levels, and time horizon. Two to four sentences." + ), + ) + investment_thesis: str = Field( + description=( + "Detailed reasoning anchored in specific evidence from the analysts' " + "debate. If prior lessons are referenced in the prompt context, " + "incorporate them; otherwise rely solely on the current analysis." 
+ ), + ) + price_target: Optional[float] = Field( + default=None, + description="Optional target price in the instrument's quote currency.", + ) + time_horizon: Optional[str] = Field( + default=None, + description="Optional recommended holding period, e.g. '3-6 months'.", + ) + + +def render_pm_decision(decision: PortfolioDecision) -> str: + """Render a PortfolioDecision back to the markdown shape the rest of the system expects. + + Memory log, CLI display, and saved report files all read this markdown, + so the rendered output preserves the exact section headers (``**Rating**``, + ``**Executive Summary**``, ``**Investment Thesis**``) that downstream + parsers and the report writers already handle. + """ + parts = [ + f"**Rating**: {decision.rating.value}", + "", + f"**Executive Summary**: {decision.executive_summary}", + "", + f"**Investment Thesis**: {decision.investment_thesis}", + ] + if decision.price_target is not None: + parts.extend(["", f"**Price Target**: {decision.price_target}"]) + if decision.time_horizon: + parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"]) + return "\n".join(parts) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 199cb8946..fee5ac4a2 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -4,17 +4,17 @@ from typing import List, Optional from pathlib import Path import re +from tradingagents.agents.utils.rating import parse_rating + class TradingMemoryLog: """Append-only markdown log of trading decisions and reflections.""" - RATINGS = {"buy", "overweight", "hold", "underweight", "sell"} # HTML comment: cannot appear in LLM prose output, safe as a hard delimiter _SEPARATOR = "\n\n\n\n" # Precompiled patterns — avoids re-compilation on every load_entries() call _DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL) _REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL) - _RATING_LABEL_RE = 
re.compile(r"rating.*?[:\-][\s*]*(\w+)", re.IGNORECASE) def __init__(self, config: dict = None): self._log_path = None @@ -40,7 +40,7 @@ class TradingMemoryLog: for line in raw.splitlines(): if line.startswith(f"[{trade_date} | {ticker} |") and line.endswith("| pending]"): return - rating = self._parse_rating(final_trade_decision) + rating = parse_rating(final_trade_decision) tag = f"[{trade_date} | {ticker} | {rating} | pending]" entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}" with open(self._log_path, "a", encoding="utf-8") as f: @@ -213,20 +213,6 @@ class TradingMemoryLog: # --- Helpers --- - def _parse_rating(self, text: str) -> str: - # First pass: explicit "Rating: X" label — search handles markdown bold/numbered lists - for line in text.splitlines(): - m = self._RATING_LABEL_RE.search(line) - if m and m.group(1).lower() in self.RATINGS: - return m.group(1).capitalize() - # Fallback: first rating word found anywhere in the text - for line in text.splitlines(): - for word in line.lower().split(): - clean = word.strip("*:.,") - if clean in self.RATINGS: - return clean.capitalize() - return "Hold" - def _parse_entry(self, raw: str) -> Optional[dict]: lines = raw.strip().splitlines() if not lines: diff --git a/tradingagents/agents/utils/rating.py b/tradingagents/agents/utils/rating.py new file mode 100644 index 000000000..d5032346a --- /dev/null +++ b/tradingagents/agents/utils/rating.py @@ -0,0 +1,50 @@ +"""Shared 5-tier rating vocabulary and a deterministic heuristic parser. + +The same five-tier scale (Buy, Overweight, Hold, Underweight, Sell) is used by: +- The Research Manager (investment plan recommendation) +- The Portfolio Manager (final position decision) +- The signal processor (rating extracted for downstream consumers) +- The memory log (rating tag stored alongside each decision entry) + +Centralising it here avoids drift between those call sites. 
+""" + +from __future__ import annotations + +import re +from typing import Tuple + + +# Canonical, ordered 5-tier scale (most bullish to most bearish). +RATINGS_5_TIER: Tuple[str, ...] = ( + "Buy", "Overweight", "Hold", "Underweight", "Sell", +) + +_RATING_SET = {r.lower() for r in RATINGS_5_TIER} + +# Matches "Rating: X" / "rating - X" / "Rating: **X**" — tolerates markdown +# bold wrappers and either a colon or hyphen separator. +_RATING_LABEL_RE = re.compile(r"rating.*?[:\-][\s*]*(\w+)", re.IGNORECASE) + + +def parse_rating(text: str, default: str = "Hold") -> str: + """Heuristically extract a 5-tier rating from prose text. + + Two-pass strategy: + 1. Look for an explicit "Rating: X" label (tolerant of markdown bold). + 2. Fall back to the first 5-tier rating word found anywhere in the text. + + Returns a Title-cased rating string, or ``default`` if no rating word appears. + """ + for line in text.splitlines(): + m = _RATING_LABEL_RE.search(line) + if m and m.group(1).lower() in _RATING_SET: + return m.group(1).capitalize() + + for line in text.splitlines(): + for word in line.lower().split(): + clean = word.strip("*:.,") + if clean in _RATING_SET: + return clean.capitalize() + + return default diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index 5ac66c1dd..90fafd04b 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -1,33 +1,31 @@ -# TradingAgents/graph/signal_processing.py +"""Extract the 5-tier portfolio rating from the Portfolio Manager's decision. + +The Portfolio Manager produces a typed ``PortfolioDecision`` via structured +output and renders it to markdown that always carries a ``**Rating**: X`` +header (see :func:`tradingagents.agents.schemas.render_pm_decision`). The +deterministic heuristic in :mod:`tradingagents.agents.utils.rating` is more +than sufficient to extract that rating; no extra LLM call is needed. 
+ +This module exists for backwards compatibility with callers that expect a +``SignalProcessor.process_signal(text)`` interface. +""" + +from __future__ import annotations from typing import Any +from tradingagents.agents.utils.rating import parse_rating + class SignalProcessor: - """Processes trading signals to extract actionable decisions.""" + """Read the 5-tier rating out of a Portfolio Manager decision.""" - def __init__(self, quick_thinking_llm: Any): - """Initialize with an LLM for processing.""" + def __init__(self, quick_thinking_llm: Any = None): + # The LLM argument is accepted for backwards compatibility but no + # longer used: the PM's structured output guarantees the rating is + # parseable from the rendered markdown without a second LLM call. self.quick_thinking_llm = quick_thinking_llm def process_signal(self, full_signal: str) -> str: - """ - Process a full trading signal to extract the core decision. - - Args: - full_signal: Complete trading signal text - - Returns: - Extracted rating (BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, or SELL) - """ - messages = [ - ( - "system", - "You are an efficient assistant that extracts the trading decision from analyst reports. " - "Extract the rating as exactly one of: BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, SELL. " - "Output only the single rating word, nothing else.", - ), - ("human", full_signal), - ] - - return self.quick_thinking_llm.invoke(messages).content + """Return one of Buy / Overweight / Hold / Underweight / Sell.""" + return parse_rating(full_signal) From bba147798f9b440bcf08fe3f3db3ca8464f9c73a Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 20:27:23 +0000 Subject: [PATCH 19/44] feat: structured-output Trader and Research Manager (#434, finishes the trio) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the canonical structured-output pattern from the Portfolio Manager to the other two decision-making agents. 
Each of the three agents now returns a typed Pydantic instance via llm.with_structured_output() in a single primary call, and a render helper turns the result into the same markdown shape downstream agents and saved reports already consume. - ResearchPlan: 5-tier recommendation, conversational rationale, concrete strategic actions for the trader. - TraderProposal: 3-tier action (transaction direction is naturally Buy / Hold / Sell — position sizing happens later at the Portfolio Manager), reasoning, and optional entry_price / stop_loss / position_sizing. Rendered output preserves the trailing "FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**" line for backward compatibility with the analyst stop-signal text. - PortfolioDecision: 5-tier rating, executive summary, investment thesis, optional price_target / time_horizon (unchanged). The shared try-structured-then-fallback pattern is extracted into tradingagents/agents/utils/structured.py (bind_structured + invoke_structured_or_freetext) so all three agents go through the same code path and log the same warning when a provider lacks structured output and the agent falls back to free-text generation. Net effect for users: every saved markdown report (research/manager.md, trading/trader.md, portfolio/decision.md) now has consistent section headers across runs and providers, easier to scan. Net effect for the runtime: the rating extraction round-trip is gone — the rating comes from the structured response itself, not a second LLM call. SignalProcessor was already simplified to a heuristic adapter in the previous commit. 11 new tests in tests/test_structured_agents.py cover the Trader and Research Manager render functions, structured-output happy paths, and free-text fallback. Full suite: 88 tests pass in ~2s without API keys. 
--- tests/test_structured_agents.py | 232 ++++++++++++++++++ .../agents/managers/portfolio_manager.py | 54 +--- .../agents/managers/research_manager.py | 32 ++- tradingagents/agents/schemas.py | 149 ++++++++++- tradingagents/agents/trader/trader.py | 48 +++- tradingagents/agents/utils/structured.py | 73 ++++++ 6 files changed, 519 insertions(+), 69 deletions(-) create mode 100644 tests/test_structured_agents.py create mode 100644 tradingagents/agents/utils/structured.py diff --git a/tests/test_structured_agents.py b/tests/test_structured_agents.py new file mode 100644 index 000000000..ea771a4b0 --- /dev/null +++ b/tests/test_structured_agents.py @@ -0,0 +1,232 @@ +"""Tests for structured-output agents (Trader and Research Manager). + +The Portfolio Manager has its own coverage in tests/test_memory_log.py +(which exercises the full memory-log → PM injection cycle). This file +covers the parallel schemas, render functions, and graceful-fallback +behavior we added for the Trader and Research Manager so all three +decision-making agents share the same shape. +""" + +from unittest.mock import MagicMock + +import pytest + +from tradingagents.agents.managers.research_manager import create_research_manager +from tradingagents.agents.schemas import ( + PortfolioRating, + ResearchPlan, + TraderAction, + TraderProposal, + render_research_plan, + render_trader_proposal, +) +from tradingagents.agents.trader.trader import create_trader + + +# --------------------------------------------------------------------------- +# Render functions +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestRenderTraderProposal: + def test_minimal_required_fields(self): + p = TraderProposal(action=TraderAction.HOLD, reasoning="Balanced setup; no edge.") + md = render_trader_proposal(p) + assert "**Action**: Hold" in md + assert "**Reasoning**: Balanced setup; no edge." 
in md + # The trailing FINAL TRANSACTION PROPOSAL line is preserved for the + # analyst stop-signal text and any external code that greps for it. + assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in md + + def test_optional_fields_included_when_present(self): + p = TraderProposal( + action=TraderAction.BUY, + reasoning="Strong technicals + fundamentals.", + entry_price=189.5, + stop_loss=178.0, + position_sizing="6% of portfolio", + ) + md = render_trader_proposal(p) + assert "**Action**: Buy" in md + assert "**Entry Price**: 189.5" in md + assert "**Stop Loss**: 178.0" in md + assert "**Position Sizing**: 6% of portfolio" in md + assert "FINAL TRANSACTION PROPOSAL: **BUY**" in md + + def test_optional_fields_omitted_when_absent(self): + p = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.") + md = render_trader_proposal(p) + assert "Entry Price" not in md + assert "Stop Loss" not in md + assert "Position Sizing" not in md + assert "FINAL TRANSACTION PROPOSAL: **SELL**" in md + + +@pytest.mark.unit +class TestRenderResearchPlan: + def test_required_fields(self): + p = ResearchPlan( + recommendation=PortfolioRating.OVERWEIGHT, + rationale="Bull case carried; tailwinds intact.", + strategic_actions="Build position over two weeks; cap at 5%.", + ) + md = render_research_plan(p) + assert "**Recommendation**: Overweight" in md + assert "**Rationale**: Bull case carried" in md + assert "**Strategic Actions**: Build position" in md + + def test_all_5_tier_ratings_render(self): + for rating in PortfolioRating: + p = ResearchPlan( + recommendation=rating, + rationale="r", + strategic_actions="s", + ) + md = render_research_plan(p) + assert f"**Recommendation**: {rating.value}" in md + + +# --------------------------------------------------------------------------- +# Trader agent: structured happy path + fallback +# --------------------------------------------------------------------------- + + +def _make_trader_state(): + return { + "company_of_interest": 
"NVDA", + "investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...", + } + + +def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None): + """Build a MagicMock LLM whose with_structured_output binding captures the + prompt and returns a real TraderProposal so render_trader_proposal works. + """ + if proposal is None: + proposal = TraderProposal( + action=TraderAction.BUY, + reasoning="Strong setup.", + ) + structured = MagicMock() + structured.invoke.side_effect = lambda prompt: ( + captured.__setitem__("prompt", prompt) or proposal + ) + llm = MagicMock() + llm.with_structured_output.return_value = structured + return llm + + +@pytest.mark.unit +class TestTraderAgent: + def test_structured_path_produces_rendered_markdown(self): + captured = {} + proposal = TraderProposal( + action=TraderAction.BUY, + reasoning="AI capex cycle intact; institutional flows constructive.", + entry_price=189.5, + stop_loss=178.0, + position_sizing="6% of portfolio", + ) + llm = _structured_trader_llm(captured, proposal) + trader = create_trader(llm) + result = trader(_make_trader_state()) + plan = result["trader_investment_plan"] + assert "**Action**: Buy" in plan + assert "**Entry Price**: 189.5" in plan + assert "FINAL TRANSACTION PROPOSAL: **BUY**" in plan + # The same rendered markdown is also added to messages for downstream agents. + assert plan in result["messages"][0].content + + def test_prompt_includes_investment_plan(self): + captured = {} + llm = _structured_trader_llm(captured) + trader = create_trader(llm) + trader(_make_trader_state()) + # The investment plan is in the user message of the captured prompt. 
+ prompt = captured["prompt"] + assert any("Proposed Investment Plan" in m["content"] for m in prompt) + + def test_falls_back_to_freetext_when_structured_unavailable(self): + plain_response = ( + "**Action**: Sell\n\nGuidance cut hits margins.\n\n" + "FINAL TRANSACTION PROPOSAL: **SELL**" + ) + llm = MagicMock() + llm.with_structured_output.side_effect = NotImplementedError("provider unsupported") + llm.invoke.return_value = MagicMock(content=plain_response) + trader = create_trader(llm) + result = trader(_make_trader_state()) + assert result["trader_investment_plan"] == plain_response + + +# --------------------------------------------------------------------------- +# Research Manager agent: structured happy path + fallback +# --------------------------------------------------------------------------- + + +def _make_rm_state(): + return { + "company_of_interest": "NVDA", + "investment_debate_state": { + "history": "Bull and bear arguments here.", + "bull_history": "Bull says...", + "bear_history": "Bear says...", + "current_response": "", + "judge_decision": "", + "count": 1, + }, + } + + +def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None): + if plan is None: + plan = ResearchPlan( + recommendation=PortfolioRating.HOLD, + rationale="Balanced view across both sides.", + strategic_actions="Hold current position; reassess after earnings.", + ) + structured = MagicMock() + structured.invoke.side_effect = lambda prompt: ( + captured.__setitem__("prompt", prompt) or plan + ) + llm = MagicMock() + llm.with_structured_output.return_value = structured + return llm + + +@pytest.mark.unit +class TestResearchManagerAgent: + def test_structured_path_produces_rendered_markdown(self): + captured = {} + plan = ResearchPlan( + recommendation=PortfolioRating.OVERWEIGHT, + rationale="Bull case is stronger; AI tailwind intact.", + strategic_actions="Build position gradually over two weeks.", + ) + llm = _structured_rm_llm(captured, plan) + rm = 
create_research_manager(llm) + result = rm(_make_rm_state()) + ip = result["investment_plan"] + assert "**Recommendation**: Overweight" in ip + assert "**Rationale**: Bull case" in ip + assert "**Strategic Actions**: Build position" in ip + + def test_prompt_uses_5_tier_rating_scale(self): + """The RM prompt must list all five tiers so the schema enum matches user expectations.""" + captured = {} + llm = _structured_rm_llm(captured) + rm = create_research_manager(llm) + rm(_make_rm_state()) + prompt = captured["prompt"] + for tier in ("Buy", "Overweight", "Hold", "Underweight", "Sell"): + assert f"**{tier}**" in prompt, f"missing {tier} in prompt" + + def test_falls_back_to_freetext_when_structured_unavailable(self): + plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..." + llm = MagicMock() + llm.with_structured_output.side_effect = NotImplementedError("provider unsupported") + llm.invoke.return_value = MagicMock(content=plain_response) + rm = create_research_manager(llm) + result = rm(_make_rm_state()) + assert result["investment_plan"] == plain_response diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 38ab840af..0e7c18234 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -5,35 +5,24 @@ Uses LangChain's ``with_structured_output`` so the LLM produces a typed back to markdown for storage in ``final_trade_decision`` so memory log, CLI display, and saved reports continue to consume the same shape they do today. When a provider does not expose structured output, the agent falls -back to a free-text invocation and the existing heuristic rating parser. +back gracefully to free-text generation. 
""" from __future__ import annotations -import logging - from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_language_instruction, ) - -logger = logging.getLogger(__name__) +from tradingagents.agents.utils.structured import ( + bind_structured, + invoke_structured_or_freetext, +) def create_portfolio_manager(llm): - # Wrap once at agent construction; if the provider does not support - # structured output we keep ``structured_llm`` as None and use the - # free-text fallback for every call. - try: - structured_llm = llm.with_structured_output(PortfolioDecision) - except (NotImplementedError, AttributeError) as exc: - logger.warning( - "Portfolio Manager: provider does not support with_structured_output (%s); " - "falling back to free-text generation", - exc, - ) - structured_llm = None + structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager") def portfolio_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) @@ -74,7 +63,13 @@ def create_portfolio_manager(llm): Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}""" - final_trade_decision = _invoke_pm(structured_llm, llm, prompt) + final_trade_decision = invoke_structured_or_freetext( + structured_llm, + llm, + prompt, + render_pm_decision, + "Portfolio Manager", + ) new_risk_debate_state = { "judge_decision": final_trade_decision, @@ -95,26 +90,3 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{ } return portfolio_manager_node - - -def _invoke_pm(structured_llm, plain_llm, prompt: str) -> str: - """Run the PM call and return the markdown-rendered decision. 
- - Tries the structured-output path first; if it fails for any reason - (provider does not support it, model returns malformed JSON, network - glitch on the structured endpoint), falls back to the plain free-text - invocation so the pipeline still produces a result. - """ - if structured_llm is not None: - try: - decision = structured_llm.invoke(prompt) - return render_pm_decision(decision) - except Exception as exc: - logger.warning( - "Portfolio Manager: structured-output invocation failed (%s); " - "retrying once as free text", - exc, - ) - - response = plain_llm.invoke(prompt) - return response.content diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 020b719e4..0e2206b2e 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,8 +1,18 @@ +"""Research Manager: turns the bull/bear debate into a structured investment plan for the trader.""" +from __future__ import annotations + +from tradingagents.agents.schemas import ResearchPlan, render_research_plan from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.structured import ( + bind_structured, + invoke_structured_or_freetext, +) def create_research_manager(llm): + structured_llm = bind_structured(llm, ResearchPlan, "Research Manager") + def research_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) history = state["investment_debate_state"].get("history", "") @@ -24,31 +34,31 @@ def create_research_manager(llm): Commit to a clear stance whenever the debate's strongest arguments warrant one; reserve Hold for situations where the evidence on both sides is genuinely balanced. -**Required Output Structure:** -1. **Recommendation**: State one of Buy / Overweight / Hold / Underweight / Sell. -2. 
**Rationale**: Summarise the key points from both sides and explain which arguments led to this recommendation. -3. **Strategic Actions**: Concrete steps for the trader to implement the recommendation, including position sizing guidance consistent with the rating. - -Present your analysis conversationally, as if speaking naturally to a teammate. - --- **Debate History:** {history}""" - response = llm.invoke(prompt) + + investment_plan = invoke_structured_or_freetext( + structured_llm, + llm, + prompt, + render_research_plan, + "Research Manager", + ) new_investment_debate_state = { - "judge_decision": response.content, + "judge_decision": investment_plan, "history": investment_debate_state.get("history", ""), "bear_history": investment_debate_state.get("bear_history", ""), "bull_history": investment_debate_state.get("bull_history", ""), - "current_response": response.content, + "current_response": investment_plan, "count": investment_debate_state["count"], } return { "investment_debate_state": new_investment_debate_state, - "investment_plan": response.content, + "investment_plan": investment_plan, } return research_manager_node diff --git a/tradingagents/agents/schemas.py b/tradingagents/agents/schemas.py index aadcaf0b6..55f0e3cfb 100644 --- a/tradingagents/agents/schemas.py +++ b/tradingagents/agents/schemas.py @@ -1,15 +1,16 @@ """Pydantic schemas used by agents that produce structured output. The framework's primary artifact is still prose: each agent's natural-language -reasoning is what users read, what gets stored in the memory log, and what -gets saved as markdown reports. Structured output is layered onto agents -whose results have downstream machine-readable consumers (currently only -the Portfolio Manager) so that: +reasoning is what users read in the saved markdown reports and what the +downstream agents read as context. 
Structured output is layered onto the +three decision-making agents (Research Manager, Trader, Portfolio Manager) +so that: -- The rating is type-safe and never has to be regex-extracted -- Schema field descriptions become the model's output instructions +- Their outputs follow consistent section headers across runs and providers - Each provider's native structured-output mode is used (json_schema for OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic) +- Schema field descriptions become the model's output instructions, freeing + the prompt body to focus on context and the rating-scale guidance - A render helper turns the parsed Pydantic instance back into the same markdown shape the rest of the system already consumes, so display, memory log, and saved reports keep working unchanged @@ -23,8 +24,13 @@ from typing import Optional from pydantic import BaseModel, Field +# --------------------------------------------------------------------------- +# Shared rating types +# --------------------------------------------------------------------------- + + class PortfolioRating(str, Enum): - """5-tier portfolio rating used by the Research Manager and Portfolio Manager.""" + """5-tier rating used by the Research Manager and Portfolio Manager.""" BUY = "Buy" OVERWEIGHT = "Overweight" @@ -33,6 +39,135 @@ class PortfolioRating(str, Enum): SELL = "Sell" +class TraderAction(str, Enum): + """3-tier transaction direction used by the Trader. + + The Trader's job is to translate the Research Manager's investment plan + into a concrete transaction proposal: should the desk execute a Buy, a + Sell, or sit on Hold this round. Position sizing and the nuanced + Overweight / Underweight calls happen later at the Portfolio Manager. 
+ """ + + BUY = "Buy" + HOLD = "Hold" + SELL = "Sell" + + +# --------------------------------------------------------------------------- +# Research Manager +# --------------------------------------------------------------------------- + + +class ResearchPlan(BaseModel): + """Structured investment plan produced by the Research Manager. + + Hand-off to the Trader: the recommendation pins the directional view, + the rationale captures which side of the bull/bear debate carried the + argument, and the strategic actions translate that into concrete + instructions the trader can execute against. + """ + + recommendation: PortfolioRating = Field( + description=( + "The investment recommendation. Exactly one of Buy / Overweight / " + "Hold / Underweight / Sell. Reserve Hold for situations where the " + "evidence on both sides is genuinely balanced; otherwise commit to " + "the side with the stronger arguments." + ), + ) + rationale: str = Field( + description=( + "Conversational summary of the key points from both sides of the " + "debate, ending with which arguments led to the recommendation. " + "Speak naturally, as if to a teammate." + ), + ) + strategic_actions: str = Field( + description=( + "Concrete steps for the trader to implement the recommendation, " + "including position sizing guidance consistent with the rating." + ), + ) + + +def render_research_plan(plan: ResearchPlan) -> str: + """Render a ResearchPlan to markdown for storage and the trader's prompt context.""" + return "\n".join([ + f"**Recommendation**: {plan.recommendation.value}", + "", + f"**Rationale**: {plan.rationale}", + "", + f"**Strategic Actions**: {plan.strategic_actions}", + ]) + + +# --------------------------------------------------------------------------- +# Trader +# --------------------------------------------------------------------------- + + +class TraderProposal(BaseModel): + """Structured transaction proposal produced by the Trader. 
+ + The trader reads the Research Manager's investment plan and the analyst + reports, then turns them into a concrete transaction: what action to + take, the reasoning that justifies it, and the practical levels for + entry, stop-loss, and sizing. + """ + + action: TraderAction = Field( + description="The transaction direction. Exactly one of Buy / Hold / Sell.", + ) + reasoning: str = Field( + description=( + "The case for this action, anchored in the analysts' reports and " + "the research plan. Two to four sentences." + ), + ) + entry_price: Optional[float] = Field( + default=None, + description="Optional entry price target in the instrument's quote currency.", + ) + stop_loss: Optional[float] = Field( + default=None, + description="Optional stop-loss price in the instrument's quote currency.", + ) + position_sizing: Optional[str] = Field( + default=None, + description="Optional sizing guidance, e.g. '5% of portfolio'.", + ) + + +def render_trader_proposal(proposal: TraderProposal) -> str: + """Render a TraderProposal to markdown. + + The trailing ``FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`` line is + preserved for backward compatibility with the analyst stop-signal text + and any external code that greps for it. 
+ """ + parts = [ + f"**Action**: {proposal.action.value}", + "", + f"**Reasoning**: {proposal.reasoning}", + ] + if proposal.entry_price is not None: + parts.extend(["", f"**Entry Price**: {proposal.entry_price}"]) + if proposal.stop_loss is not None: + parts.extend(["", f"**Stop Loss**: {proposal.stop_loss}"]) + if proposal.position_sizing: + parts.extend(["", f"**Position Sizing**: {proposal.position_sizing}"]) + parts.extend([ + "", + f"FINAL TRANSACTION PROPOSAL: **{proposal.action.value.upper()}**", + ]) + return "\n".join(parts) + + +# --------------------------------------------------------------------------- +# Portfolio Manager +# --------------------------------------------------------------------------- + + class PortfolioDecision(BaseModel): """Structured output produced by the Portfolio Manager. diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 0ecae8888..ea3f6b232 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,32 +1,60 @@ +"""Trader: turns the Research Manager's investment plan into a concrete transaction proposal.""" + +from __future__ import annotations + import functools +from langchain_core.messages import AIMessage + +from tradingagents.agents.schemas import TraderProposal, render_trader_proposal from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.structured import ( + bind_structured, + invoke_structured_or_freetext, +) def create_trader(llm): + structured_llm = bind_structured(llm, TraderProposal, "Trader") + def trader_node(state, name): company_name = state["company_of_interest"] instrument_context = build_instrument_context(company_name) investment_plan = state["investment_plan"] - context = { - "role": "user", - "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. 
{instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", - } - messages = [ { "role": "system", - "content": "You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation.", + "content": ( + "You are a trading agent analyzing market data to make investment decisions. " + "Based on your analysis, provide a specific recommendation to buy, sell, or hold. " + "Anchor your reasoning in the analysts' reports and the research plan." + ), + }, + { + "role": "user", + "content": ( + f"Based on a comprehensive analysis by a team of analysts, here is an investment " + f"plan tailored for {company_name}. {instrument_context} This plan incorporates " + f"insights from current technical market trends, macroeconomic indicators, and " + f"social media sentiment. Use this plan as a foundation for evaluating your next " + f"trading decision.\n\nProposed Investment Plan: {investment_plan}\n\n" + f"Leverage these insights to make an informed and strategic decision." 
+ ), }, - context, ] - result = llm.invoke(messages) + trader_plan = invoke_structured_or_freetext( + structured_llm, + llm, + messages, + render_trader_proposal, + "Trader", + ) return { - "messages": [result], - "trader_investment_plan": result.content, + "messages": [AIMessage(content=trader_plan)], + "trader_investment_plan": trader_plan, "sender": name, } diff --git a/tradingagents/agents/utils/structured.py b/tradingagents/agents/utils/structured.py new file mode 100644 index 000000000..400e1a82b --- /dev/null +++ b/tradingagents/agents/utils/structured.py @@ -0,0 +1,73 @@ +"""Shared helpers for invoking an agent with structured output and a graceful fallback. + +The Portfolio Manager, Trader, and Research Manager all follow the same +canonical pattern: + +1. At agent creation, wrap the LLM with ``with_structured_output(Schema)`` + so the model returns a typed Pydantic instance. If the provider does + not support structured output (rare; mostly older Ollama models), the + wrap is skipped and the agent uses free-text generation instead. +2. At invocation, run the structured call and render the result back to + markdown. If the structured call itself fails for any reason + (malformed JSON from a weak model, transient provider issue), fall + back to a plain ``llm.invoke`` so the pipeline never blocks. + +Centralising the pattern here keeps the agent factories small and ensures +all three agents log the same warnings when fallback fires. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Optional, TypeVar + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]: + """Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported. 
+ + Logs a warning when the binding fails so the user understands the agent + will use free-text generation for every call instead of one-shot fallback. + """ + try: + return llm.with_structured_output(schema) + except (NotImplementedError, AttributeError) as exc: + logger.warning( + "%s: provider does not support with_structured_output (%s); " + "falling back to free-text generation", + agent_name, exc, + ) + return None + + +def invoke_structured_or_freetext( + structured_llm: Optional[Any], + plain_llm: Any, + prompt: Any, + render: Callable[[T], str], + agent_name: str, +) -> str: + """Run the structured call and render to markdown; fall back to free-text on any failure. + + ``prompt`` is whatever the underlying LLM accepts (a string for chat + invocations, a list of message dicts for chat models that take that + shape). The same value is forwarded to the free-text path so the + fallback sees the same input the structured call did. + """ + if structured_llm is not None: + try: + result = structured_llm.invoke(prompt) + return render(result) + except Exception as exc: + logger.warning( + "%s: structured-output invocation failed (%s); retrying once as free text", + agent_name, exc, + ) + + response = plain_llm.invoke(prompt) + return response.content From 4016fd4efa8d43ddde9d00cc59531490b9ac6a9a Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 20:54:19 +0000 Subject: [PATCH 20/44] fix: stop leaking OpenAI base_url into non-OpenAI provider clients Default config had backend_url='https://api.openai.com/v1' which was forwarded to every provider client, including Google. ChatGoogleGenerativeAI constructed requests against that base, producing malformed URLs like https://api.openai.com/v1/v1beta/models/gemini-2.5-flash:generateContent that 404 with empty body. Discovered while running propagate() against Gemini end-to-end. 
The structured-output smoke worked because that path constructed the LLM without going through the factory and without forwarding backend_url; propagate() goes through TradingAgentsGraph.__init__ which forwards config['backend_url'] to every provider. Fix: default to None. Each provider client falls back to its own endpoint (api.openai.com for OpenAI via _PROVIDER_CONFIG, Gemini's default for Google, and so on). The CLI flow already sets backend_url explicitly per provider when the user picks one, so that path is unchanged. Verified: full propagate() now passes end-to-end on both OpenAI gpt-5.4-mini and Gemini gemini-3-flash-preview, with all nine structure/log/signal checks green for each. --- tradingagents/default_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 89b517659..7498d1883 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -11,7 +11,12 @@ DEFAULT_CONFIG = { "llm_provider": "openai", "deep_think_llm": "gpt-5.4", "quick_think_llm": "gpt-5.4-mini", - "backend_url": "https://api.openai.com/v1", + # When None, each provider's client falls back to its own default endpoint + # (api.openai.com for OpenAI, generativelanguage.googleapis.com for Gemini, ...). + # The CLI overrides this per provider when the user picks one. Keeping a + # provider-specific URL here would leak (e.g. OpenAI's /v1 was previously + # being forwarded to Gemini, producing malformed request URLs). + "backend_url": None, # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. 
"openai_reasoning_effort": None, # "medium", "high", "low" From 7c37249f808f9c169ad2198dc384166e7ca7adf9 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 21:54:30 +0000 Subject: [PATCH 21/44] =?UTF-8?q?chore:=20release=20v0.2.4=20=E2=80=94=20s?= =?UTF-8?q?tructured=20agents,=20checkpoint,=20memory=20log,=20providers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This release bundles substantial work since v0.2.3: - Structured-output Research Manager, Trader, and Portfolio Manager (canonical with_structured_output pattern, single LLM call per agent, rendered markdown preserves the existing report shape). - LangGraph checkpoint resume for crash recovery (--checkpoint flag). - Persistent decision log replacing the per-agent BM25 memory, with deferred reflection driven by yfinance returns + alpha vs SPY. - DeepSeek, Qwen, GLM, and Azure OpenAI provider support; dynamic OpenRouter model selection. - Docker support; cache and logs moved to ~/.tradingagents/ to fix Docker permission issues. - Windows UTF-8 encoding fix on every file I/O site. - 5-tier rating consistency (Buy / Overweight / Hold / Underweight / Sell) across Research Manager, Portfolio Manager, signal processor, memory log. Plus the small quality items in this commit: 1. Suppress noisy Pydantic serializer warnings from OpenAI Responses-API parse path by defaulting structured-output to method="function_calling" (root-cause fix, not a warnings filter — same typed result, no warnings). 2. Ship scripts/smoke_structured_output.py so contributors can verify their provider's structured-output path with one command. 3. Add opt-in memory_log_max_entries config — when set, oldest resolved memory log entries are pruned once the cap is exceeded; pending entries (unresolved) are never pruned. 4. 
backend_url default changed from the OpenAI URL to None so the per-provider client falls back to its native endpoint instead of leaking OpenAI's URL into Gemini / other clients. CHANGELOG.md added with the full v0.2.4 entry. 92 tests pass without API keys. --- CHANGELOG.md | 266 +++++++++++++++++++++ README.md | 3 + pyproject.toml | 2 +- scripts/smoke_structured_output.py | 176 ++++++++++++++ tests/test_memory_log.py | 53 ++++ tradingagents/agents/utils/memory.py | 44 +++- tradingagents/default_config.py | 4 + tradingagents/llm_clients/openai_client.py | 16 ++ 8 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 scripts/smoke_structured_output.py diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..3fd6afa2b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,266 @@ +# Changelog + +All notable changes to TradingAgents are documented here. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +Breaking changes within the 0.x line are called out explicitly. + +## [0.2.4] — 2026-04-25 + +### Added + +- **Structured-output decision agents.** Research Manager, Trader, and Portfolio + Manager now use `llm.with_structured_output(Schema)` on their primary call + and return typed Pydantic instances. Each provider's native structured-output + mode is used (`json_schema` for OpenAI / xAI, `response_schema` for Gemini, + tool-use for Anthropic, function-calling for OpenAI-compatible providers). + Render helpers preserve the existing markdown shape so memory log, CLI + display, and saved reports keep working unchanged. (#434) +- **LangGraph checkpoint resume** — opt-in via `--checkpoint`. State is saved + after each node so crashed or interrupted runs resume from the last + successful step. Per-ticker SQLite databases under + `~/.tradingagents/cache/checkpoints/`. `--clear-checkpoints` resets them. 
(#594) +- **Persistent decision log** replacing the per-agent BM25 memory. Decisions + are stored automatically at the end of `propagate()`; the next same-ticker + run resolves prior pending entries with realised return, alpha vs SPY, and + a one-paragraph reflection. Override path with `TRADINGAGENTS_MEMORY_LOG_PATH`. + Optional `memory_log_max_entries` config caps resolved entries; pending + entries are never pruned. (#578, #563, #564, #579) +- **DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), and Azure OpenAI** + providers, plus dynamic OpenRouter model selection. +- **Docker support** — multi-stage build with separate dev and runtime images. +- **`scripts/smoke_structured_output.py`** — diagnostic that exercises the + three structured-output agents against any provider so contributors can + verify their setup with one command. +- **5-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell) used + consistently by Research Manager, Portfolio Manager, signal processor, and + the memory log; Trader keeps 3-tier (Buy / Hold / Sell) since transaction + direction is naturally ternary. +- **Pytest fixtures** — lazy LLM client imports plus placeholder API keys so + the test suite runs cleanly without credentials. (#588) + +### Changed + +- **`backend_url` default is now `None`** rather than the OpenAI URL. Each + provider client falls back to its native default. The previous default + leaked the OpenAI URL into non-OpenAI clients (e.g. Gemini), producing + malformed request URLs for Python users who switched providers without + overriding `backend_url`. The CLI flow is unaffected. +- All file I/O passes explicit `encoding="utf-8"` so Windows users no longer + hit `UnicodeEncodeError` with the cp1252 default. (#543, #550, #576) +- Cache and log directories moved to `~/.tradingagents/` to resolve Docker + permission issues. 
(#519) +- `SignalProcessor` reads the rating from the Portfolio Manager's rendered + markdown via a deterministic heuristic — no extra LLM call. +- OpenAI structured-output calls default to `method="function_calling"` to + avoid noisy `PydanticSerializationUnexpectedValue` warnings emitted by + langchain-openai's Responses-API parse path. Same typed result, no warnings. + +### Fixed + +- Empty memory no longer triggers fabricated past-lessons in agent prompts; + the memory-log redesign makes this structurally impossible since only the + Portfolio Manager consults memory and only when entries exist. (#572) +- Tool-call logging processes every chunk message, not just the last one, and + memory score normalization handles empty score arrays. (#534, #531) + +### Removed + +- `FinancialSituationMemory` (the per-agent BM25 system) and the dead + `reflect_and_remember()` plumbing; subsumed by the persistent decision log. +- Hardcoded Google endpoint that caused 404 when `langchain-google-genai` + changed its API path. 
(#493, #496) + +### Contributors + +Thanks to everyone who shaped this release through code, design, and reports: + +- [@claytonbrown](https://github.com/claytonbrown) — checkpoint resume (#594), test fixtures (#588), design feedback on cost tracking (#582) and structured validation (#583) +- [@Bcardo](https://github.com/Bcardo) — memory-log redesign (#579), empty-memory hallucination report (#572), encoding fix proposal (#570) +- [@voidborne-d](https://github.com/voidborne-d) — memory persistence design (#564), portfolio manager state fix (#503) +- [@mannubaveja007](https://github.com/mannubaveja007) — structured-output feature request (#434) +- [@kelder66](https://github.com/kelder66) — RAM-only memory issue (#563) +- [@Gujiassh](https://github.com/Gujiassh) — tool-call logging fix (#534), test stub PR (#533) +- [@iuyup](https://github.com/iuyup) — memory score normalization fix (#531) +- [@kaihg](https://github.com/kaihg) — Google base_url fix (#496) +- [@32ryh98yfe](https://github.com/32ryh98yfe) — Gemini 404 report (#493) +- [@uppb](https://github.com/uppb) — OpenRouter dynamic model selection (#482) +- [@guoz14](https://github.com/guoz14) — OpenRouter limited-model report (#337) +- [@samchenku](https://github.com/samchenku) — indicator name normalization (#490) +- [@JasonOA888](https://github.com/JasonOA888) — y_finance pandas import fix (#488) +- [@tiffanychum](https://github.com/tiffanychum) — stale import cleanup (#499) +- [@zaizou](https://github.com/zaizou) — Docker permission issue (#519) +- [@Stosman123](https://github.com/Stosman123), [@mauropuga](https://github.com/mauropuga), [@hotwind2015](https://github.com/hotwind2015) — Windows encoding bug reports (#543, #550, #576) +- [@nnishad](https://github.com/nnishad), [@atharvajoshi01](https://github.com/atharvajoshi01) — encoding fix proposals (#568, #549) + +## [0.2.3] — 2026-03-29 + +### Added + +- **Multi-language output** for analyst reports and final decisions, with a + CLI selector. 
Internal agent debate stays in English for reasoning quality. (#472) +- **GPT-5.4 family models** in the default catalog, with deep/quick model split. +- **Unified model catalog** as a single source of truth for CLI options and + provider validation. + +### Changed + +- `base_url` is forwarded to Google and Anthropic clients so corporate proxies + work consistently across providers. (#427) +- Standardised the Google `api_key` parameter to the unified `api_key` form. + +### Fixed + +- Backtesting fetchers no longer leak look-ahead data when `curr_date` is in + the middle of a fetched window. (#475) +- Invalid indicator names from the LLM are caught at the tool boundary instead + of crashing the run. (#429) +- yfinance news fetchers respect the same exponential-backoff retry as price + fetchers. (#445) + +### Contributors + +- [@ahmedk20](https://github.com/ahmedk20) — multi-language output (#472) +- [@CadeYu](https://github.com/CadeYu) — model catalog typing (#464) +- [@javierdejesusda](https://github.com/javierdejesusda) — unified Google API key parameter (#453) +- [@voidborne-d](https://github.com/voidborne-d) — yfinance news retry (#445) +- [@kostakost2](https://github.com/kostakost2) — look-ahead bias report (#475) +- [@lu-zhengda](https://github.com/lu-zhengda) — proxy/base_url support request (#427) +- [@VamsiKrishna2021](https://github.com/VamsiKrishna2021) — invalid indicator crash report (#429) + +## [0.2.2] — 2026-03-22 + +### Added + +- **Five-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell) + introduced for the Portfolio Manager. +- **Anthropic effort level** support for Claude models. +- **OpenAI Responses API** path for native OpenAI models. + +### Changed + +- `risk_manager` renamed to `portfolio_manager` to match the role description + shown in the CLI display. +- Exchange-qualified tickers (e.g. `7203.T`, `BRK.B`) preserved across all + agent prompts and tool calls. 
+- Process-level UTF-8 default attempted for cross-platform consistency + (note: this approach did not actually take effect; replaced in v0.2.4 with + explicit per-call `encoding="utf-8"` arguments). + +### Fixed + +- yfinance rate-limit errors are retried with exponential backoff. (#426) +- HTTP client SSL customisation is supported for environments that need + custom certificate bundles. (#379) +- Report-section writes handle list-of-string content gracefully. + +### Contributors + +- [@CadeYu](https://github.com/CadeYu) — exchange-qualified ticker preservation (#413) +- [@yang1002378395-cmyk](https://github.com/yang1002378395-cmyk) — HTTP client SSL customisation (#379) + +## [0.2.1] — 2026-03-15 + +### Security + +- Patched `langchain-core` vulnerability (LangGrinch). (#335) +- Removed `chainlit` dependency affected by CVE-2026-22218. + +### Added + +- `pyproject.toml` build-system configuration; the project now installs via + modern packaging tooling. + +### Removed + +- `setup.py` — dependencies consolidated to `pyproject.toml`. + +### Fixed + +- Risk manager reads the correct fundamental report source. (#341) +- All `open()` calls receive an explicit UTF-8 encoding (initial pass). +- `get_indicators` tool handles comma-separated indicator names from the LLM. (#368) +- `Propagation` initialises every debate-state field so risk debaters never + see missing keys. +- Stock data parsing tolerates malformed CSVs and NaN values. +- Conditional debate logic respects the configured round count. (#361) + +### Contributors + +- [@RinZ27](https://github.com/RinZ27) — `langchain-core` security patch (#335) +- [@Ljx-007](https://github.com/Ljx-007) — risk manager fundamental-report fix (#341) +- [@makk9](https://github.com/makk9) — debate-rounds config issue (#361) + +## [0.2.0] — 2026-02-04 + +This is the largest release since the initial public version. 
The framework +moved from single-provider to a multi-provider architecture and grew several +production-ready surfaces. + +### Added + +- **Multi-provider LLM support** (OpenAI, Google, Anthropic, xAI, OpenRouter, + Ollama) via a factory pattern, with provider-specific thinking configurations. +- **Alpha Vantage** integration as a configurable primary data provider, with + yfinance as a community-stability fallback. +- **Footer statistics** in the CLI: real-time tracking of LLM calls, tool + calls, and token usage via LangChain callbacks. +- **Post-analysis report saving** — the framework writes per-section markdown + files (analyst reports, debate transcripts, final decision) when a run + completes. +- **Announcements panel** — fetches updates from `api.tauric.ai/v1/announcements` + for the CLI welcome screen. +- **Tool fallbacks** so a single vendor outage does not stop the pipeline. + +### Changed + +- Risky / Safe risk debaters renamed to **Aggressive / Conservative** for + consistency with the displayed agent labels. +- Default data vendor switched to balance reliability and quota across + community deployments. +- Ollama and OpenRouter model lists updated; default endpoints clarified. + +### Fixed + +- Analyst status tracking and message deduplication in the live display. +- Infinite-loop guard in the agent loop; reflection and logging hardened. +- Various data-vendor implementation bugs and tool-signature mismatches. + +### Contributors + +This release is the first with substantial outside contributions; many community +PRs from late 2025 also landed here. 
+ +- [@luohy15](https://github.com/luohy15) — Alpha Vantage data-vendor integration (#235) +- [@EdwardoSunny](https://github.com/EdwardoSunny) — yfinance fetching optimisations (#245) +- [@Mirza-Samad-Ahmed-Baig](https://github.com/Mirza-Samad-Ahmed-Baig) — infinite-loop guard, reflection, and logging fixes (#89) +- [@ZeroAct](https://github.com/ZeroAct) — saved results path support (#29) +- [@Zhongyi-Lu](https://github.com/Zhongyi-Lu) — `.env` gitignore (#49) +- [@csoboy](https://github.com/csoboy) — local Ollama setup (#53) +- [@chauhang](https://github.com/chauhang) — initial Docker support attempt (#47, later reverted; the merged Docker support shipped in v0.2.4) + +## [0.1.1] — 2025-06-07 + +### Removed + +- Static site assets that had been bundled with v0.1.0; the public site now + lives separately. + +## [0.1.0] — 2025-06-05 + +### Added + +- **Initial public release** of the TradingAgents multi-agent trading + framework: market / sentiment / news / fundamentals analysts; bull and bear + researchers; trader; aggressive, conservative, and neutral risk debaters; + portfolio manager. LangGraph orchestration, yfinance data, per-agent + BM25 memory, single-provider OpenAI integration, interactive CLI. 
+ +[0.2.4]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.3...v0.2.4 +[0.2.3]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.2...v0.2.3 +[0.2.2]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.1...v0.2.2 +[0.2.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.0...v0.2.1 +[0.2.0]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.1...v0.2.0 +[0.1.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.0...v0.1.1 +[0.1.0]: https://github.com/TauricResearch/TradingAgents/releases/tag/v0.1.0 diff --git a/README.md b/README.md index 6c8f644ec..54af501a9 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ # TradingAgents: Multi-Agents LLM Financial Trading Framework ## News +- [2026-04] **TradingAgents v0.2.4** released with structured-output agents (Research Manager, Trader, Portfolio Manager), LangGraph checkpoint resume, persistent decision log, DeepSeek/Qwen/GLM/Azure provider support, Docker, and a Windows UTF-8 encoding fix. See [CHANGELOG.md](CHANGELOG.md) for the full list. - [2026-03] **TradingAgents v0.2.3** released with multi-language support, GPT-5.4 family models, unified model catalog, backtesting date fidelity, and proxy support. - [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability. - [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture. @@ -251,6 +252,8 @@ _, decision = ta.propagate("NVDA", "2026-01-15") We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). 
+Past contributions, including code, design feedback, and bug reports, are credited per release in [`CHANGELOG.md`](CHANGELOG.md). + ## Citation Please reference our work if you find *TradingAgents* provides you with some help :) diff --git a/pyproject.toml b/pyproject.toml index b569504ef..07cbbd3f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "tradingagents" -version = "0.2.3" +version = "0.2.4" description = "TradingAgents: Multi-Agents LLM Financial Trading Framework" readme = "README.md" requires-python = ">=3.10" diff --git a/scripts/smoke_structured_output.py b/scripts/smoke_structured_output.py new file mode 100644 index 000000000..1d3cf681c --- /dev/null +++ b/scripts/smoke_structured_output.py @@ -0,0 +1,176 @@ +"""End-to-end smoke for structured-output agents against a real LLM provider. + +Runs the three decision-making agents (Research Manager, Trader, Portfolio +Manager) directly with their structured-output bindings and prints the +typed Pydantic instance + the rendered markdown for each. Use this to +verify a provider's native structured-output mode (json_schema for +OpenAI / xAI / DeepSeek / Qwen / GLM, response_schema for Gemini, tool-use +for Anthropic) returns clean instances on the schemas we ship. + +Usage: + OPENAI_API_KEY=... python scripts/smoke_structured_output.py openai + GOOGLE_API_KEY=... python scripts/smoke_structured_output.py google + ANTHROPIC_API_KEY=... python scripts/smoke_structured_output.py anthropic + DEEPSEEK_API_KEY=... python scripts/smoke_structured_output.py deepseek + +The script does NOT call propagate(), to keep the surface tight and the +cost low — it exercises only the three structured-output calls we just +added, plus the heuristic SignalProcessor. 
+""" + +from __future__ import annotations + +import argparse +import os +import sys + +from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager +from tradingagents.agents.managers.research_manager import create_research_manager +from tradingagents.agents.trader.trader import create_trader +from tradingagents.graph.signal_processing import SignalProcessor +from tradingagents.llm_clients import create_llm_client + + +PROVIDER_DEFAULTS = { + "openai": ("gpt-5.4-mini", None), + "google": ("gemini-2.5-flash", None), + "anthropic": ("claude-sonnet-4-6", None), + "deepseek": ("deepseek-chat", None), + "qwen": ("qwen-plus", None), + "glm": ("glm-5", None), + "xai": ("grok-4", None), +} + + +# Minimal but realistic state for the three agents. +DEBATE_HISTORY = """ +Bull Analyst: NVDA's data-center revenue grew 60% YoY last quarter, driven by +Blackwell ramp; sovereign AI deals with multiple governments add a $40B+ +multi-year tailwind. Margins remain above peer average. + +Bear Analyst: Concentration risk is real — top three customers are >40% of +revenue. Any pause in hyperscaler capex would compress the multiple. China +export restrictions still cap a meaningful portion of demand. +""" + + +def _make_rm_state(): + return { + "company_of_interest": "NVDA", + "investment_debate_state": { + "history": DEBATE_HISTORY, + "bull_history": "Bull Analyst: NVDA's data-center revenue grew 60% YoY...", + "bear_history": "Bear Analyst: Concentration risk is real...", + "current_response": "", + "judge_decision": "", + "count": 1, + }, + } + + +def _make_trader_state(investment_plan: str): + return { + "company_of_interest": "NVDA", + "investment_plan": investment_plan, + } + + +def _make_pm_state(investment_plan: str, trader_plan: str): + return { + "company_of_interest": "NVDA", + "past_context": "", + "risk_debate_state": { + "history": "Aggressive: lean in. Conservative: trim. 
Neutral: balanced sizing.", + "aggressive_history": "Aggressive: ...", + "conservative_history": "Conservative: ...", + "neutral_history": "Neutral: ...", + "judge_decision": "", + "current_aggressive_response": "", + "current_conservative_response": "", + "current_neutral_response": "", + "count": 1, + }, + "market_report": "Market report.", + "sentiment_report": "Sentiment report.", + "news_report": "News report.", + "fundamentals_report": "Fundamentals report.", + "investment_plan": investment_plan, + "trader_investment_plan": trader_plan, + } + + +def _print_section(title: str, content: str) -> None: + bar = "=" * 70 + print(f"\n{bar}\n{title}\n{bar}\n{content}") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("provider", choices=list(PROVIDER_DEFAULTS.keys())) + parser.add_argument("--deep-model", default=None, help="Override deep_think_llm") + parser.add_argument("--quick-model", default=None, help="Override quick_think_llm") + args = parser.parse_args() + + default_model, _ = PROVIDER_DEFAULTS[args.provider] + deep_model = args.deep_model or default_model + quick_model = args.quick_model or default_model + + print(f"Provider: {args.provider}") + print(f"Deep model: {deep_model}") + print(f"Quick model: {quick_model}") + + # Build the LLM clients via the framework's factory. 
+ deep_client = create_llm_client(provider=args.provider, model=deep_model) + quick_client = create_llm_client(provider=args.provider, model=quick_model) + deep_llm = deep_client.get_llm() + quick_llm = quick_client.get_llm() + + # 1) Research Manager + rm = create_research_manager(deep_llm) + rm_result = rm(_make_rm_state()) + investment_plan = rm_result["investment_plan"] + _print_section("[1] Research Manager — investment_plan", investment_plan) + + # 2) Trader (consumes RM's plan) + trader = create_trader(quick_llm) + trader_result = trader(_make_trader_state(investment_plan)) + trader_plan = trader_result["trader_investment_plan"] + _print_section("[2] Trader — trader_investment_plan", trader_plan) + + # 3) Portfolio Manager (consumes both) + pm = create_portfolio_manager(deep_llm) + pm_result = pm(_make_pm_state(investment_plan, trader_plan)) + final_decision = pm_result["final_trade_decision"] + _print_section("[3] Portfolio Manager — final_trade_decision", final_decision) + + # 4) SignalProcessor extracts the rating with zero LLM calls. + sp = SignalProcessor() + rating = sp.process_signal(final_decision) + _print_section("[4] SignalProcessor → rating", rating) + + # 5) Lightweight checks: each rendered output should carry the expected + # section headers so downstream consumers (memory log, CLI display, + # saved reports) keep working. 
+ checks = [ + ("Research Manager", investment_plan, ["**Recommendation**:"]), + ("Trader", trader_plan, ["**Action**:", "FINAL TRANSACTION PROPOSAL:"]), + ("Portfolio Manager", final_decision, ["**Rating**:", "**Executive Summary**:", "**Investment Thesis**:"]), + ] + print("\n" + "=" * 70 + "\nStructure checks\n" + "=" * 70) + failures = 0 + for name, text, required in checks: + for marker in required: + ok = marker in text + print(f" {'PASS' if ok else 'FAIL'} {name}: contains {marker!r}") + failures += int(not ok) + + print() + if failures: + print(f"Smoke FAILED: {failures} structure check(s) missing.") + return 1 + print("Smoke PASSED: structured output → rendered markdown chain works for", args.provider) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py index e0da15efc..5d7f7f844 100644 --- a/tests/test_memory_log.py +++ b/tests/test_memory_log.py @@ -291,6 +291,59 @@ class TestTradingMemoryLogCore: assert log.load_entries() == [] assert log.get_past_context("NVDA") == "" + # Rotation: opt-in cap on resolved entries + + def test_rotation_disabled_by_default(self, tmp_path): + """Without max_entries, all resolved entries are kept.""" + log = make_log(tmp_path) + for i in range(7): + _resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.") + assert len(log.load_entries()) == 7 + + def test_rotation_prunes_oldest_resolved(self, tmp_path): + """When max_entries is set and exceeded, oldest resolved entries are pruned.""" + log = TradingMemoryLog({ + "memory_log_path": str(tmp_path / "trading_memory.md"), + "memory_log_max_entries": 3, + }) + # Resolve 5 entries; rotation should keep only the 3 most recent. + for i in range(5): + _resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.") + entries = log.load_entries() + assert len(entries) == 3 + # Confirm the OLDEST were dropped, not the newest. 
+ dates = [e["date"] for e in entries] + assert dates == ["2026-01-03", "2026-01-04", "2026-01-05"] + + def test_rotation_never_prunes_pending(self, tmp_path): + """Pending entries (unresolved) are kept regardless of the cap.""" + log = TradingMemoryLog({ + "memory_log_path": str(tmp_path / "trading_memory.md"), + "memory_log_max_entries": 2, + }) + # 3 resolved + 2 pending. With cap=2, only 2 resolved survive; both pending stay. + for i in range(3): + _resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Resolved {i}.") + log.store_decision("NVDA", "2026-02-01", DECISION_BUY) + log.store_decision("NVDA", "2026-02-02", DECISION_OVERWEIGHT) + # Trigger rotation by resolving one more entry — pending entries must stay. + _resolve_entry(log, "NVDA", "2026-01-04", DECISION_BUY, "Resolved 3.") + entries = log.load_entries() + pending = [e for e in entries if e["pending"]] + resolved = [e for e in entries if not e["pending"]] + assert len(pending) == 2, "pending entries must never be pruned" + assert len(resolved) == 2, f"expected 2 resolved after rotation, got {len(resolved)}" + + def test_rotation_under_cap_is_noop(self, tmp_path): + """No rotation when resolved count <= max_entries.""" + log = TradingMemoryLog({ + "memory_log_path": str(tmp_path / "trading_memory.md"), + "memory_log_max_entries": 10, + }) + for i in range(3): + _resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.") + assert len(log.load_entries()) == 3 + # Rating parsing: markdown bold and numbered list formats def test_rating_parsed_from_bold_markdown(self, tmp_path): diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index fee5ac4a2..c94717556 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -17,11 +17,14 @@ class TradingMemoryLog: _REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL) def __init__(self, config: dict = None): + cfg = config or {} self._log_path = None - 
path = (config or {}).get("memory_log_path") + path = cfg.get("memory_log_path") if path: self._log_path = Path(path).expanduser() self._log_path.parent.mkdir(parents=True, exist_ok=True) + # Optional cap on resolved entries. None disables rotation. + self._max_entries = cfg.get("memory_log_max_entries") # --- Write path (Phase A) --- @@ -153,6 +156,7 @@ class TradingMemoryLog: if not updated: return + new_blocks = self._apply_rotation(new_blocks) new_text = self._SEPARATOR.join(new_blocks) tmp_path = self._log_path.with_suffix(".tmp") tmp_path.write_text(new_text, encoding="utf-8") @@ -206,6 +210,7 @@ class TradingMemoryLog: if not matched: new_blocks.append(block) + new_blocks = self._apply_rotation(new_blocks) new_text = self._SEPARATOR.join(new_blocks) tmp_path = self._log_path.with_suffix(".tmp") tmp_path.write_text(new_text, encoding="utf-8") @@ -213,6 +218,43 @@ class TradingMemoryLog: # --- Helpers --- + def _apply_rotation(self, blocks: List[str]) -> List[str]: + """Drop oldest resolved blocks when their count exceeds max_entries. + + Pending blocks are always kept (they represent unprocessed work). + Returns ``blocks`` unchanged when rotation is disabled or under cap. + """ + if not self._max_entries or self._max_entries <= 0: + return blocks + + # Pair each block as (block, is_resolved) by parsing tag-line markers. 
+ decisions = [] + for block in blocks: + stripped = block.strip() + if not stripped: + decisions.append((block, False)) + continue + tag_line = stripped.splitlines()[0].strip() + is_resolved = ( + tag_line.startswith("[") + and tag_line.endswith("]") + and not tag_line.endswith("| pending]") + ) + decisions.append((block, is_resolved)) + + resolved_count = sum(1 for _, r in decisions if r) + if resolved_count <= self._max_entries: + return blocks + + to_drop = resolved_count - self._max_entries + kept: List[str] = [] + for block, is_resolved in decisions: + if is_resolved and to_drop > 0: + to_drop -= 1 + continue + kept.append(block) + return kept + def _parse_entry(self, raw: str) -> Optional[dict]: lines = raw.strip().splitlines() if not lines: diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 7498d1883..fa6d5742c 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -7,6 +7,10 @@ DEFAULT_CONFIG = { "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")), "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")), "memory_log_path": os.getenv("TRADINGAGENTS_MEMORY_LOG_PATH", os.path.join(_TRADINGAGENTS_HOME, "memory", "trading_memory.md")), + # Optional cap on the number of resolved memory log entries. When set, + # the oldest resolved entries are pruned once this limit is exceeded. + # Pending entries are never pruned. None disables rotation entirely. 
+ "memory_log_max_entries": None, # LLM settings "llm_provider": "openai", "deep_think_llm": "gpt-5.4", diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index f943124a9..bbfcd39e3 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -18,6 +18,22 @@ class NormalizedChatOpenAI(ChatOpenAI): def invoke(self, input, config=None, **kwargs): return normalize_content(super().invoke(input, config, **kwargs)) + def with_structured_output(self, schema, *, method=None, **kwargs): + """Wrap with structured output, defaulting to function_calling for OpenAI. + + langchain-openai's Responses-API-parse path (the default for json_schema + when use_responses_api=True) calls response.model_dump(...) on the OpenAI + SDK's union-typed parsed response, which makes Pydantic emit ~20 + PydanticSerializationUnexpectedValue warnings per call. The function-calling + path returns a plain tool-call shape that does not trigger that + serialization, so it is the cleaner choice for our combination of + use_responses_api=True + with_structured_output. Both paths use OpenAI's + strict mode and produce the same typed Pydantic instance. + """ + if method is None: + method = "function_calling" + return super().with_structured_output(schema, method=method, **kwargs) + # Kwargs forwarded from user config to ChatOpenAI _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "reasoning_effort", From 2c97bad45c773760db3ef7cb787a9fd1fbf7ac67 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Fri, 1 May 2026 18:56:36 +0000 Subject: [PATCH 22/44] fix(security): validate ticker before using as path component (#618) The ticker symbol reaches three filesystem-path construction sites (load_ohlcv cache filename, checkpointer DB path, _log_state results directory) without validation. A value containing path separators or "../" escapes the configured cache / checkpoints / results directory. 
Two attack vectors: - Programmatic callers passing arbitrary ticker to propagate() - Prompt injection via fetched news content steering the LLM into tool calls with attacker-chosen ticker Fix: new safe_ticker_component() validator in tradingagents/dataflows/ utils.py applied at all three sites. Allows the standard ticker character set ([A-Za-z0-9._\-\^], up to 32 chars) and explicitly rejects dot-only values like "." and ".." which would otherwise pass the regex but traverse parent directories. Seven test cases cover the accepted formats (BRK-B, 7203.T, ^GSPC, etc.) and the rejected inputs (path separators, null bytes, whitespace, empty values, overlong strings, dot-only values). Closes #618. --- tests/test_safe_ticker_component.py | 52 +++++++++++++++++++++ tradingagents/dataflows/stockstats_utils.py | 7 ++- tradingagents/dataflows/utils.py | 35 ++++++++++++++ tradingagents/graph/checkpointer.py | 6 ++- tradingagents/graph/trading_graph.py | 7 ++- 5 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 tests/test_safe_ticker_component.py diff --git a/tests/test_safe_ticker_component.py b/tests/test_safe_ticker_component.py new file mode 100644 index 000000000..3bdc02234 --- /dev/null +++ b/tests/test_safe_ticker_component.py @@ -0,0 +1,52 @@ +"""Tests for the ticker path-component validator that blocks directory traversal.""" + +import os +import unittest + +import pytest + +from tradingagents.dataflows.utils import safe_ticker_component + + +@pytest.mark.unit +class TestSafeTickerComponent(unittest.TestCase): + def test_accepts_common_ticker_formats(self): + for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"): + self.assertEqual(safe_ticker_component(ticker), ticker) + + def test_rejects_path_separators(self): + for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"): + with self.assertRaises(ValueError): + safe_ticker_component(bad) + + def test_rejects_null_byte_and_whitespace(self): + for bad in ("AAP L", 
"AAPL\x00", "AAPL\n", "\tAAPL"): + with self.assertRaises(ValueError): + safe_ticker_component(bad) + + def test_rejects_empty_or_non_string(self): + for bad in ("", None, 123, b"AAPL"): + with self.assertRaises(ValueError): + safe_ticker_component(bad) + + def test_rejects_overlong_input(self): + with self.assertRaises(ValueError): + safe_ticker_component("A" * 33) + + def test_rejects_dot_only_values(self): + # '.' and '..' pass the regex but traverse when used as a path + # component (e.g. ``Path(results_dir) / ticker / "logs"``). + for bad in (".", "..", "...", "...."): + with self.assertRaises(ValueError): + safe_ticker_component(bad) + + def test_traversal_string_does_not_escape_join(self): + """Sanity: sanitized values stay within base when joined.""" + base = os.path.realpath("/tmp/cache") + ticker = safe_ticker_component("AAPL") + joined = os.path.realpath(os.path.join(base, f"{ticker}.csv")) + self.assertTrue(joined.startswith(base + os.sep)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index cb24c5d6a..260ef73cd 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -8,6 +8,7 @@ from stockstats import wrap from typing import Annotated import os from .config import get_config +from .utils import safe_ticker_component logger = logging.getLogger(__name__) @@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame: subsequent calls the cache is reused. Rows after curr_date are filtered out so backtests never see future prices. """ + # Reject ticker values that would escape the cache directory when + # interpolated into the cache filename (e.g. ``../../tmp/x``). 
+ safe_symbol = safe_ticker_component(symbol) + config = get_config() curr_date_dt = pd.to_datetime(curr_date) @@ -63,7 +68,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame: os.makedirs(config["data_cache_dir"], exist_ok=True) data_file = os.path.join( config["data_cache_dir"], - f"{symbol}-YFin-data-{start_str}-{end_str}.csv", + f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv", ) if os.path.exists(data_file): diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index c99b777ab..3d8a45d81 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,4 +1,5 @@ import os +import re import json import pandas as pd from datetime import date, timedelta, datetime @@ -6,6 +7,40 @@ from typing import Annotated SavePathType = Annotated[str, "File path to save data. If None, data is not saved."] +# Tickers can contain letters, digits, dot, dash, underscore, and caret +# (for index symbols like ^GSPC). Anything else is rejected so the value +# never escapes a containing directory when interpolated into a path. +_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$") + + +def safe_ticker_component(value: str, *, max_len: int = 32) -> str: + """Validate ``value`` is safe to interpolate into a filesystem path. + + Tickers come from user CLI input or from LLM tool calls, both of which + can be influenced by attacker-controlled content (e.g. prompt injection + embedded in fetched news). Without validation, a value like + ``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and + escapes the configured cache, checkpoint, or results directory. + + Returns ``value`` unchanged when it matches the allowed pattern; raises + ``ValueError`` otherwise. 
+ """ + if not isinstance(value, str) or not value: + raise ValueError(f"ticker must be a non-empty string, got {value!r}") + if len(value) > max_len: + raise ValueError(f"ticker exceeds {max_len} chars: {value!r}") + if not _TICKER_PATH_RE.fullmatch(value): + raise ValueError( + f"ticker contains characters not allowed in a filesystem path: {value!r}" + ) + # The regex above allows '.', so values like '.', '..', '...' would pass, + # and as a path component they traverse the parent directory. Reject any + # value that's only dots. + if set(value) == {"."}: + raise ValueError(f"ticker cannot consist solely of dots: {value!r}") + return value + + def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: if save_path: data.to_csv(save_path, encoding="utf-8") diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py index 7a73ee446..3ba19726d 100644 --- a/tradingagents/graph/checkpointer.py +++ b/tradingagents/graph/checkpointer.py @@ -13,12 +13,16 @@ from typing import Generator from langgraph.checkpoint.sqlite import SqliteSaver +from tradingagents.dataflows.utils import safe_ticker_component + def _db_path(data_dir: str | Path, ticker: str) -> Path: """Return the SQLite checkpoint DB path for a ticker.""" + # Reject ticker values that would escape the checkpoints directory. 
+ safe = safe_ticker_component(ticker).upper() p = Path(data_dir) / "checkpoints" p.mkdir(parents=True, exist_ok=True) - return p / f"{ticker.upper()}.db" + return p / f"{safe}.db" def thread_id(ticker: str, date: str) -> str: diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index bd6f1fc5c..d7e8b5731 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -18,6 +18,7 @@ from tradingagents.llm_clients import create_llm_client from tradingagents.agents import * from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.agents.utils.memory import TradingMemoryLog +from tradingagents.dataflows.utils import safe_ticker_component from tradingagents.agents.utils.agent_states import ( AgentState, InvestDebateState, @@ -378,8 +379,10 @@ class TradingAgentsGraph: "final_trade_decision": final_state["final_trade_decision"], } - # Save to file - directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs" + # Save to file. Reject ticker values that would escape the + # results directory when joined as a path component. + safe_ticker = safe_ticker_component(self.ticker) + directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs" directory.mkdir(parents=True, exist_ok=True) log_path = directory / f"full_states_log_{trade_date}.json" From 7e9e7b83c7fcc18d941300b253c6ed24d985788d Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Fri, 1 May 2026 19:23:23 +0000 Subject: [PATCH 23/44] feat: DeepSeek V4 thinking-mode round-trip via DeepSeekChatOpenAI subclass Resolves #599: thinking-mode models require reasoning_content to be echoed back across turns; multi-turn agent runs failed with HTTP 400. The fix isolates DeepSeek's quirks (reasoning_content round-trip and the deepseek-reasoner no-tool_choice limitation) into a subclass so the general OpenAI-compatible client stays untouched. Adds DeepSeek V4 Pro/Flash to the catalog. 
9 new tests; rationale documented in the class docstrings. Design adapted from #600; #611 closed in favour of this approach. --- tests/test_deepseek_reasoning.py | 169 +++++++++++++++++++++ tradingagents/llm_clients/model_catalog.py | 2 + tradingagents/llm_clients/openai_client.py | 108 ++++++++++--- 3 files changed, 262 insertions(+), 17 deletions(-) create mode 100644 tests/test_deepseek_reasoning.py diff --git a/tests/test_deepseek_reasoning.py b/tests/test_deepseek_reasoning.py new file mode 100644 index 000000000..fb300336d --- /dev/null +++ b/tests/test_deepseek_reasoning.py @@ -0,0 +1,169 @@ +"""Tests for DeepSeekChatOpenAI thinking-mode behaviour. + +Two pieces verified: + +1. ``reasoning_content`` is captured on receive into the AIMessage's + ``additional_kwargs`` and re-attached on send so DeepSeek's API + sees the same value across turns. +2. ``with_structured_output`` raises NotImplementedError for + ``deepseek-reasoner`` so the agent factories' free-text fallback + handles the request instead of failing at runtime. 
+""" + +import os + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.prompt_values import ChatPromptValue + +from tradingagents.llm_clients.openai_client import ( + DeepSeekChatOpenAI, + NormalizedChatOpenAI, + _input_to_messages, +) + + +# --------------------------------------------------------------------------- +# _input_to_messages — the helper that handles list / ChatPromptValue / other +# (Gemini bot review note: non-list inputs must also work) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestInputToMessages: + def test_list_input_returned_as_is(self): + msgs = [HumanMessage(content="hi")] + assert _input_to_messages(msgs) is msgs + + def test_chat_prompt_value_unwrapped(self): + msgs = [HumanMessage(content="hi")] + prompt_value = ChatPromptValue(messages=msgs) + assert _input_to_messages(prompt_value) == msgs + + def test_string_input_yields_empty_list(self): + # A bare string isn't a message-bearing input; the caller's normal + # langchain conversion happens upstream of _get_request_payload. 
+ assert _input_to_messages("hello") == [] + + +# --------------------------------------------------------------------------- +# Reasoning content propagation across turns +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDeepSeekReasoningContent: + def _client(self): + os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder") + return DeepSeekChatOpenAI( + model="deepseek-v4-flash", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + + def test_capture_on_receive(self): + """When the response carries reasoning_content, it lands on the + AIMessage's additional_kwargs so the next turn can echo it back.""" + client = self._client() + result = client._create_chat_result( + { + "model": "deepseek-v4-flash", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Plan: buy NVDA.", + "reasoning_content": "Step 1: trend is up. Step 2: ...", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + ) + ai = result.generations[0].message + assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..." 
+ + def test_propagate_on_send(self): + """When an outgoing AIMessage carries reasoning_content, the request + payload echoes it on the corresponding message dict.""" + client = self._client() + prior = AIMessage( + content="Plan", + additional_kwargs={"reasoning_content": "weighed bull case"}, + ) + new_user = HumanMessage(content="Refine.") + payload = client._get_request_payload([prior, new_user]) + # Find the assistant message in the payload + assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"] + assert assistant_dicts, "assistant message missing from outgoing payload" + assert assistant_dicts[0]["reasoning_content"] == "weighed bull case" + + def test_propagate_through_chat_prompt_value(self): + """Gemini bot review note: non-list inputs (ChatPromptValue) must + also propagate reasoning_content.""" + client = self._client() + prior = AIMessage( + content="Plan", + additional_kwargs={"reasoning_content": "weighed bull case"}, + ) + prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")]) + payload = client._get_request_payload(prompt_value) + assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"] + assert assistant_dicts[0]["reasoning_content"] == "weighed bull case" + + +# --------------------------------------------------------------------------- +# deepseek-reasoner: structured output unavailable, falls through to free-text +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDeepSeekReasonerStructuredOutput: + def test_with_structured_output_raises_for_reasoner(self): + client = DeepSeekChatOpenAI( + model="deepseek-reasoner", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + from pydantic import BaseModel + + class _Sample(BaseModel): + answer: str + + with pytest.raises(NotImplementedError): + client.with_structured_output(_Sample) + + def test_with_structured_output_works_for_v4(self): + """V4 
models (non-reasoner) accept tool_choice; structured output works.""" + client = DeepSeekChatOpenAI( + model="deepseek-v4-flash", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + from pydantic import BaseModel + + class _Sample(BaseModel): + answer: str + + # Should return a Runnable, not raise. (The actual API call would + # require a real key; we only assert binding succeeds.) + wrapped = client.with_structured_output(_Sample) + assert wrapped is not None + + +# --------------------------------------------------------------------------- +# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBaseClassIsolation: + def test_normalized_does_not_propagate_reasoning_content(self): + """The general-purpose NormalizedChatOpenAI must not carry + DeepSeek-specific behaviour. Only the subclass does.""" + assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or ( + NormalizedChatOpenAI._get_request_payload + is NormalizedChatOpenAI.__bases__[0]._get_request_payload + ) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index a2c57ed89..9a723a8b9 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -65,10 +65,12 @@ MODEL_OPTIONS: ProviderModeOptions = { }, "deepseek": { "quick": [ + ("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"), ("DeepSeek V3.2", "deepseek-chat"), ("Custom model ID", "custom"), ], "deep": [ + ("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"), ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"), ("DeepSeek V3.2", "deepseek-chat"), ("Custom model ID", "custom"), diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index bbfcd39e3..b74e26ef4 100644 --- a/tradingagents/llm_clients/openai_client.py +++ 
b/tradingagents/llm_clients/openai_client.py @@ -1,6 +1,7 @@ import os from typing import Any, Optional +from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI from .base_client import BaseLLMClient, normalize_content @@ -11,29 +12,97 @@ class NormalizedChatOpenAI(ChatOpenAI): """ChatOpenAI with normalized content output. The Responses API returns content as a list of typed blocks - (reasoning, text, etc.). This normalizes to string for consistent - downstream handling. + (reasoning, text, etc.). ``invoke`` normalizes to string for + consistent downstream handling. ``with_structured_output`` defaults + to function-calling so the Responses-API parse path is avoided + (langchain-openai's parse path emits noisy + PydanticSerializationUnexpectedValue warnings per call without + affecting correctness). + + Provider-specific quirks (e.g. DeepSeek's thinking mode) live in + purpose-built subclasses below so this base class stays small. """ def invoke(self, input, config=None, **kwargs): return normalize_content(super().invoke(input, config, **kwargs)) def with_structured_output(self, schema, *, method=None, **kwargs): - """Wrap with structured output, defaulting to function_calling for OpenAI. - - langchain-openai's Responses-API-parse path (the default for json_schema - when use_responses_api=True) calls response.model_dump(...) on the OpenAI - SDK's union-typed parsed response, which makes Pydantic emit ~20 - PydanticSerializationUnexpectedValue warnings per call. The function-calling - path returns a plain tool-call shape that does not trigger that - serialization, so it is the cleaner choice for our combination of - use_responses_api=True + with_structured_output. Both paths use OpenAI's - strict mode and produce the same typed Pydantic instance. 
- """ if method is None: method = "function_calling" return super().with_structured_output(schema, method=method, **kwargs) + +def _input_to_messages(input_: Any) -> list: + """Normalise a langchain LLM input to a list of message objects. + + Accepts a list of messages, a ``ChatPromptValue`` (from a + ChatPromptTemplate), or anything else (treated as no messages). + Used by providers that need to walk the outgoing message history; + in particular DeepSeek thinking-mode propagation must work for + both bare-list invocations and ChatPromptTemplate-driven ones, so + treating only ``list`` here would silently skip half the call sites. + """ + if isinstance(input_, list): + return input_ + if hasattr(input_, "to_messages"): + return input_.to_messages() + return [] + + +class DeepSeekChatOpenAI(NormalizedChatOpenAI): + """DeepSeek-specific overrides on top of the OpenAI-compatible client. + + Two quirks that don't apply to other OpenAI-compatible providers: + + 1. **Thinking-mode round-trip.** When DeepSeek's thinking models return + a response with ``reasoning_content``, that field must be echoed + back as part of the assistant message on the next turn or the API + fails with HTTP 400. ``_create_chat_result`` captures the field on + receive and ``_get_request_payload`` re-attaches it on send. + + 2. **deepseek-reasoner has no tool_choice.** Structured output via + function-calling is unavailable, so we raise NotImplementedError + and let the agent factories fall back to free-text generation + (see ``tradingagents/agents/utils/structured.py``). 
+ """ + + def _get_request_payload(self, input_, *, stop=None, **kwargs): + payload = super()._get_request_payload(input_, stop=stop, **kwargs) + outgoing = payload.get("messages", []) + for message_dict, message in zip(outgoing, _input_to_messages(input_)): + if not isinstance(message, AIMessage): + continue + reasoning = message.additional_kwargs.get("reasoning_content") + if reasoning is not None: + message_dict["reasoning_content"] = reasoning + return payload + + def _create_chat_result(self, response, generation_info=None): + chat_result = super()._create_chat_result(response, generation_info) + response_dict = ( + response + if isinstance(response, dict) + else response.model_dump( + exclude={"choices": {"__all__": {"message": {"parsed"}}}} + ) + ) + for generation, choice in zip( + chat_result.generations, response_dict.get("choices", []) + ): + reasoning = choice.get("message", {}).get("reasoning_content") + if reasoning is not None: + generation.message.additional_kwargs["reasoning_content"] = reasoning + return chat_result + + def with_structured_output(self, schema, *, method=None, **kwargs): + if self.model_name == "deepseek-reasoner": + raise NotImplementedError( + "deepseek-reasoner does not support tool_choice; structured " + "output is unavailable. Agent factories fall back to " + "free-text generation automatically." + ) + return super().with_structured_output(schema, method=method, **kwargs) + # Kwargs forwarded from user config to ChatOpenAI _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "reasoning_effort", @@ -75,10 +144,12 @@ class OpenAIClient(BaseLLMClient): self.warn_if_unknown_model() llm_kwargs = {"model": self.model} - # Provider-specific base URL and auth + # Provider-specific base URL and auth. An explicit base_url on the + # client (e.g. a corporate proxy) takes precedence over the + # provider default so users can route through their own gateway. 
if self.provider in _PROVIDER_CONFIG: - base_url, api_key_env = _PROVIDER_CONFIG[self.provider] - llm_kwargs["base_url"] = base_url + default_base, api_key_env = _PROVIDER_CONFIG[self.provider] + llm_kwargs["base_url"] = self.base_url or default_base if api_key_env: api_key = os.environ.get(api_key_env) if api_key: @@ -98,7 +169,10 @@ class OpenAIClient(BaseLLMClient): if self.provider == "openai": llm_kwargs["use_responses_api"] = True - return NormalizedChatOpenAI(**llm_kwargs) + # DeepSeek's thinking-mode quirks live in their own subclass so the + # base NormalizedChatOpenAI stays free of provider-specific branches. + chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI + return chat_cls(**llm_kwargs) def validate_model(self) -> bool: """Validate model for the provider.""" From db7e0a67e2722a5be8b870f5661efa3b6753419a Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sun, 10 May 2026 09:49:07 +0000 Subject: [PATCH 24/44] fix(cli): load .env from user's CWD when run as console script load_dotenv() with no arguments walks up from site-packages instead of the user's CWD, so the installed tradingagents console script silently misses the project's .env. Pass find_dotenv(usecwd=True) so the search starts from CWD; same treatment for .env.enterprise. #726 #755 #612 #747 #743 #753 #729 #728 #751 --- cli/main.py | 10 ++++++---- main.py | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cli/main.py b/cli/main.py index 534f50379..05376ade2 100644 --- a/cli/main.py +++ b/cli/main.py @@ -4,11 +4,13 @@ import typer from pathlib import Path from functools import wraps from rich.console import Console -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv -# Load environment variables -load_dotenv() -load_dotenv(".env.enterprise", override=False) +# Search starts from the user's CWD so the installed `tradingagents` +# console script picks up the project's .env instead of walking up from +# site-packages. 
+load_dotenv(find_dotenv(usecwd=True)) +load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False) from rich.panel import Panel from rich.spinner import Spinner from rich.live import Live diff --git a/main.py b/main.py index c94fde323..fa3024af8 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,9 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv -# Load environment variables from .env file -load_dotenv() +load_dotenv(find_dotenv(usecwd=True)) # Create a custom config config = DEFAULT_CONFIG.copy() From c405867bde1627bb573ea069f0531569a631e594 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sun, 10 May 2026 19:20:23 +0000 Subject: [PATCH 25/44] fix: merge streamed chunks into final_state so reports save correctly graph.stream() yields per-node deltas, not the full state. Taking trace[-1] only captured the last node's contribution, so reports saved to disk were missing every section except the final decision. Merge all chunks in both the CLI path and trading_graph._run_graph's debug branch. #719 #736 --- cli/main.py | 7 +++++-- tradingagents/graph/trading_graph.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cli/main.py b/cli/main.py index 05376ade2..42794821d 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1154,8 +1154,11 @@ def run_analysis(checkpoint: bool = False): trace.append(chunk) - # Get final state and decision - final_state = trace[-1] + # Streamed chunks are per-node deltas, not full state. Merge them + # so every report field populated across the run is present. 
+ final_state = {} + for chunk in trace: + final_state.update(chunk) decision = graph.process_signal(final_state["final_trade_decision"]) # Update all agent statuses to completed diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index d7e8b5731..197913e21 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -322,7 +322,11 @@ class TradingAgentsGraph: else: chunk["messages"][-1].pretty_print() trace.append(chunk) - final_state = trace[-1] + # Streamed chunks are per-node deltas. Merge them so the returned + # state matches what graph.invoke() yields in the non-debug path. + final_state = {} + for chunk in trace: + final_state.update(chunk) else: final_state = self.graph.invoke(init_agent_state, **args) From e2c850eb173423382d34d16bb5e1863e7a45e8a1 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sun, 10 May 2026 19:29:41 +0000 Subject: [PATCH 26/44] fix(cli): preserve exchange suffixes in ticker prompt The typer.prompt-based input could lose .SH/.SZ/.SS/.HK suffixes on some shells, so exchange-qualified tickers like 000404.SH arrived truncated to 000404 and failed downstream lookups. Switch to questionary.text which reads the raw line; keep SPY-on-empty behavior and validate the allowed character set (alnum, ._-^) up to 32 chars. 
#770 --- cli/main.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/cli/main.py b/cli/main.py index 42794821d..ecfc63d59 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,6 +1,7 @@ from typing import Optional import datetime import typer +import questionary from pathlib import Path from functools import wraps from rich.console import Console @@ -615,8 +616,26 @@ def get_user_selections(): def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") + """Get ticker symbol from user input, preserving exchange suffixes.""" + # typer.prompt strips trailing dot-suffixes on some shells (e.g. 000404.SH + # collapses to 000404). questionary.text reads the raw line. + ticker = questionary.text( + "", + validate=lambda value: ( + not value.strip() + or ( + all(ch.isalnum() or ch in "._-^" for ch in value.strip()) + and len(value.strip()) <= 32 + ) + ) + or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK.", + ).ask() + + if ticker is None: + console.print("\n[red]No ticker symbol provided. Exiting...[/red]") + raise typer.Exit(1) + + return (ticker.strip() or "SPY").upper() def get_analysis_date(): From afdc6d4ec1008da88a8004e0d76a34381daab9ef Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sun, 10 May 2026 19:39:57 +0000 Subject: [PATCH 27/44] chore: suppress upstream langgraph allowed_objects deprecation noise langgraph-checkpoint 4.0.3 calls Reviver() at module load without allowed_objects, printing a pending-deprecation warning at every CLI start. The upstream patch is merged (langchain-ai/langgraph#7743) but not released; no app-side seam fixes it. Install a surgical filter in package init (message regex + PendingDeprecationWarning category). Remove when we bump past langgraph-checkpoint 4.0.3. 
--- tradingagents/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py index e69de29bb..893a3d678 100644 --- a/tradingagents/__init__.py +++ b/tradingagents/__init__.py @@ -0,0 +1,23 @@ +import warnings + +# langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in +# its own __init__, which prepends default-action filters for its +# subclassed warning categories. To suppress a specific warning we must +# install our filter AFTER langchain-core has installed its own, so import +# it first. The package is a guaranteed transitive dep via langgraph. +try: + import langchain_core # noqa: F401 +except ImportError: + pass + +# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an +# explicit allowed_objects, which triggers a noisy pending-deprecation +# warning from langchain-core 1.3.3 on every interpreter start. The fix +# is already merged upstream (langchain-ai/langgraph#7743, 2026-05-08) +# and will arrive in the next langgraph-checkpoint release. Remove this +# block (and the langchain_core preload above) when we bump past it. +warnings.filterwarnings( + "ignore", + message=r"The default value of `allowed_objects`.*", + category=PendingDeprecationWarning, +) From 22bb91bd839dc382f244313fc7392d0e64b04590 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 01:12:28 +0000 Subject: [PATCH 28/44] fix(llm): structured output for DeepSeek V4 and reasoner DeepSeek V4 and reasoner reject tool_choice but accept tools. Route via a per-model capability table that suppresses tool_choice for thinking-mode models. 
#678 #689 --- tests/test_capabilities.py | 79 +++++++++++++ tests/test_deepseek_reasoning.py | 125 ++++++++++++++++----- tradingagents/llm_clients/capabilities.py | 95 ++++++++++++++++ tradingagents/llm_clients/openai_client.py | 64 ++++++----- 4 files changed, 306 insertions(+), 57 deletions(-) create mode 100644 tests/test_capabilities.py create mode 100644 tradingagents/llm_clients/capabilities.py diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py new file mode 100644 index 000000000..d65e93d0e --- /dev/null +++ b/tests/test_capabilities.py @@ -0,0 +1,79 @@ +"""Unit tests for the LLM capability table.""" + +import pytest + +from tradingagents.llm_clients.capabilities import ( + ModelCapabilities, + get_capabilities, +) + + +@pytest.mark.unit +class TestExactIdMatches: + def test_deepseek_chat_supports_tool_choice(self): + caps = get_capabilities("deepseek-chat") + assert caps.supports_tool_choice is True + + def test_deepseek_reasoner_rejects_tool_choice(self): + caps = get_capabilities("deepseek-reasoner") + assert caps.supports_tool_choice is False + assert caps.requires_reasoning_content_roundtrip is True + + def test_deepseek_v4_flash_rejects_tool_choice(self): + caps = get_capabilities("deepseek-v4-flash") + assert caps.supports_tool_choice is False + assert caps.requires_reasoning_content_roundtrip is True + + def test_deepseek_v4_pro_rejects_tool_choice(self): + caps = get_capabilities("deepseek-v4-pro") + assert caps.supports_tool_choice is False + assert caps.requires_reasoning_content_roundtrip is True + + +@pytest.mark.unit +class TestPatternMatches: + """Forward-compat regex patterns catch unknown DeepSeek variants.""" + + def test_future_deepseek_v5_inherits_thinking_quirks(self): + caps = get_capabilities("deepseek-v5-flash") + assert caps.supports_tool_choice is False + assert caps.requires_reasoning_content_roundtrip is True + + def test_future_deepseek_v9_inherits_thinking_quirks(self): + caps = 
get_capabilities("deepseek-v9-anything") + assert caps.supports_tool_choice is False + + def test_reasoner_variant_inherits_thinking_quirks(self): + caps = get_capabilities("deepseek-reasoner-pro") + assert caps.supports_tool_choice is False + + +@pytest.mark.unit +class TestDefault: + """Unknown / non-DeepSeek models get the permissive default.""" + + def test_gpt_default(self): + caps = get_capabilities("gpt-4.1") + assert caps.supports_tool_choice is True + assert caps.preferred_structured_method == "function_calling" + + def test_grok_default(self): + caps = get_capabilities("grok-4-0709") + assert caps.supports_tool_choice is True + + def test_unknown_model_default(self): + caps = get_capabilities("totally-made-up-model-id") + assert caps.supports_tool_choice is True + + def test_exact_match_precedes_pattern(self): + """deepseek-chat must NOT match the v\\d regex.""" + caps = get_capabilities("deepseek-chat") + assert caps.supports_tool_choice is True + + +@pytest.mark.unit +def test_capabilities_dataclass_is_frozen(): + """Capability rows are immutable so they can be safely shared.""" + caps = get_capabilities("deepseek-chat") + with pytest.raises(Exception): + caps.supports_tool_choice = False # type: ignore[misc] diff --git a/tests/test_deepseek_reasoning.py b/tests/test_deepseek_reasoning.py index fb300336d..62c1b3497 100644 --- a/tests/test_deepseek_reasoning.py +++ b/tests/test_deepseek_reasoning.py @@ -5,9 +5,10 @@ Two pieces verified: 1. ``reasoning_content`` is captured on receive into the AIMessage's ``additional_kwargs`` and re-attached on send so DeepSeek's API sees the same value across turns. -2. ``with_structured_output`` raises NotImplementedError for - ``deepseek-reasoner`` so the agent factories' free-text fallback - handles the request instead of failing at runtime. +2. 
``with_structured_output`` consults the capability table and + suppresses ``tool_choice`` for models that reject it (V4 + reasoner), + matching DeepSeek's official tool-calling pattern at + https://api-docs.deepseek.com/guides/tool_calls. """ import os @@ -15,6 +16,7 @@ import os import pytest from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompt_values import ChatPromptValue +from pydantic import BaseModel from tradingagents.llm_clients.openai_client import ( DeepSeekChatOpenAI, @@ -115,42 +117,111 @@ class TestDeepSeekReasoningContent: # --------------------------------------------------------------------------- -# deepseek-reasoner: structured output unavailable, falls through to free-text +# Capability-driven structured output: tool_choice suppressed for V4 + reasoner # --------------------------------------------------------------------------- +def _bound_kwargs(runnable): + """Extract bind() kwargs from a with_structured_output result.""" + first = runnable.steps[0] if hasattr(runnable, "steps") else runnable + return getattr(first, "kwargs", {}) + + @pytest.mark.unit -class TestDeepSeekReasonerStructuredOutput: - def test_with_structured_output_raises_for_reasoner(self): - client = DeepSeekChatOpenAI( - model="deepseek-reasoner", - api_key="placeholder", - base_url="https://api.deepseek.com", +class TestStructuredOutputCapabilityDispatch: + """DeepSeek V4 and reasoner reject the tool_choice parameter + (official guide: api-docs.deepseek.com/guides/tool_calls passes + tools=[...] without tool_choice). 
Verify the capability dispatch + suppresses tool_choice for those models and sends it for chat.""" + + class _Sample(BaseModel): + answer: str + + def _client(self, model): + return DeepSeekChatOpenAI( + model=model, api_key="placeholder", base_url="https://api.deepseek.com", ) - from pydantic import BaseModel - class _Sample(BaseModel): - answer: str + def test_chat_sends_tool_choice(self): + bound = self._client("deepseek-chat").with_structured_output(self._Sample) + assert _bound_kwargs(bound).get("tool_choice") is not None - with pytest.raises(NotImplementedError): - client.with_structured_output(_Sample) + def test_reasoner_suppresses_tool_choice(self): + bound = self._client("deepseek-reasoner").with_structured_output(self._Sample) + # tool_choice is either absent or explicitly None — both are valid + # signals that langchain's bind_tools will skip the parameter. + assert _bound_kwargs(bound).get("tool_choice") in (None, ...) or \ + "tool_choice" not in _bound_kwargs(bound) - def test_with_structured_output_works_for_v4(self): - """V4 models (non-reasoner) accept tool_choice; structured output works.""" + def test_v4_flash_suppresses_tool_choice(self): + bound = self._client("deepseek-v4-flash").with_structured_output(self._Sample) + assert _bound_kwargs(bound).get("tool_choice") is None or \ + "tool_choice" not in _bound_kwargs(bound) + + def test_v4_pro_suppresses_tool_choice(self): + bound = self._client("deepseek-v4-pro").with_structured_output(self._Sample) + assert _bound_kwargs(bound).get("tool_choice") is None or \ + "tool_choice" not in _bound_kwargs(bound) + + def test_future_v_variant_via_regex(self): + """Forward-compat: unknown deepseek-v\\d-* IDs inherit V4 quirks.""" + bound = self._client("deepseek-v5-hypothetical").with_structured_output(self._Sample) + assert _bound_kwargs(bound).get("tool_choice") is None or \ + "tool_choice" not in _bound_kwargs(bound) + + def test_schema_is_still_bound_as_tool(self): + """tool_choice is suppressed, but 
the schema is still bound as a tool — + exactly matching DeepSeek's official tool-calling examples.""" + bound = self._client("deepseek-reasoner").with_structured_output(self._Sample) + kwargs = _bound_kwargs(bound) + tools = kwargs.get("tools", []) + assert any( + t.get("function", {}).get("name") == "_Sample" for t in tools + ), f"schema not bound as a tool: {tools}" + + +# --------------------------------------------------------------------------- +# Live API: structured output round-trips against the real DeepSeek backend +# --------------------------------------------------------------------------- + + +def _has_real_deepseek_key(): + key = os.environ.get("DEEPSEEK_API_KEY", "") + return bool(key) and key != "placeholder" + + +@pytest.mark.integration +@pytest.mark.skipif( + not _has_real_deepseek_key(), + reason="DEEPSEEK_API_KEY not set (or placeholder); skipping live API call", +) +class TestDeepSeekLiveStructuredOutput: + """End-to-end: a real DeepSeek V4-flash call returns a typed instance. + + Verifies the no-tool_choice path doesn't trigger the 400 reported in + issue #678 and that the structured-output binding still parses to a + Pydantic instance. + """ + + class _Pick(BaseModel): + action: str + confidence: float + + def test_v4_flash_returns_structured_output(self): client = DeepSeekChatOpenAI( model="deepseek-v4-flash", - api_key="placeholder", + api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com", + timeout=60, ) - from pydantic import BaseModel - - class _Sample(BaseModel): - answer: str - - # Should return a Runnable, not raise. (The actual API call would - # require a real key; we only assert binding succeeds.) - wrapped = client.with_structured_output(_Sample) - assert wrapped is not None + bound = client.with_structured_output(self._Pick) + result = bound.invoke( + "Pick BUY or SELL or HOLD for a tech stock with strong earnings. " + "Confidence is a float between 0 and 1." 
+ ) + assert isinstance(result, self._Pick) + assert result.action in {"BUY", "SELL", "HOLD"} + assert 0.0 <= result.confidence <= 1.0 # --------------------------------------------------------------------------- diff --git a/tradingagents/llm_clients/capabilities.py b/tradingagents/llm_clients/capabilities.py new file mode 100644 index 000000000..3c14461f2 --- /dev/null +++ b/tradingagents/llm_clients/capabilities.py @@ -0,0 +1,95 @@ +"""Declarative per-model capability table for OpenAI-compatible providers. + +This is the single place that knows which model IDs reject which API +parameters or require which structured-output method. The LLM client +subclasses consult ``get_capabilities(model_name)`` instead of hardcoding +model-name ``if`` ladders, so adding a new model (or a new provider quirk) +means editing this table — not the client code. + +Pattern adapted from the per-model ``compat:`` flags DeepSeek themselves +publish in their integration guides (e.g. the Oh My Pi config schema +documents ``supportsToolChoice``, ``requiresReasoningContentForToolCalls`` +as declarative per-model fields). +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Literal + + +StructuredMethod = Literal[ + "function_calling", # uses tools; respects supports_tool_choice + "json_mode", # uses response_format={"type":"json_object"} + "json_schema", # uses response_format={"type":"json_schema",...} + "none", # no structured output available; caller falls back to free-text +] + + +@dataclass(frozen=True) +class ModelCapabilities: + """What an OpenAI-compatible model accepts at the API level.""" + + supports_tool_choice: bool + supports_json_mode: bool + supports_json_schema: bool + preferred_structured_method: StructuredMethod + # DeepSeek thinking-mode models 400 if reasoning_content from prior + # assistant turns is not echoed back on the next request. 
+ requires_reasoning_content_roundtrip: bool = False + + +# DeepSeek's thinking models accept the ``tools`` array but reject the +# ``tool_choice`` parameter (official Oh My Pi integration guide and the +# 400 response in issue #678). Their official tool-calling examples +# (api-docs.deepseek.com/guides/tool_calls) pass ``tools=[...]`` without +# ``tool_choice`` — we mirror that pattern by setting supports_tool_choice +# to False and letting the client suppress the kwarg. +_DEEPSEEK_THINKING = ModelCapabilities( + supports_tool_choice=False, + supports_json_mode=True, + supports_json_schema=False, + preferred_structured_method="function_calling", + requires_reasoning_content_roundtrip=True, +) + +_DEEPSEEK_CHAT = ModelCapabilities( + supports_tool_choice=True, + supports_json_mode=True, + supports_json_schema=False, + preferred_structured_method="function_calling", +) + +_DEFAULT = ModelCapabilities( + supports_tool_choice=True, + supports_json_mode=True, + supports_json_schema=True, + preferred_structured_method="function_calling", +) + + +# Exact-ID matches take precedence over pattern matches. +_BY_ID: dict[str, ModelCapabilities] = { + "deepseek-chat": _DEEPSEEK_CHAT, + "deepseek-reasoner": _DEEPSEEK_THINKING, + "deepseek-v4-flash": _DEEPSEEK_THINKING, + "deepseek-v4-pro": _DEEPSEEK_THINKING, +} + +# Forward-compat patterns. A new ``deepseek-v5-*`` or ``deepseek-reasoner-*`` +# variant inherits the thinking-mode quirks automatically. 
+_BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [ + (re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING), + (re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING), +] + + +def get_capabilities(model_name: str) -> ModelCapabilities: + """Resolve capabilities by exact ID, then pattern, then default.""" + if model_name in _BY_ID: + return _BY_ID[model_name] + for pattern, caps in _BY_PATTERN: + if pattern.match(model_name): + return caps + return _DEFAULT diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index b74e26ef4..b6ad771c8 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -5,30 +5,45 @@ from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI from .base_client import BaseLLMClient, normalize_content +from .capabilities import get_capabilities from .validators import validate_model class NormalizedChatOpenAI(ChatOpenAI): - """ChatOpenAI with normalized content output. + """ChatOpenAI with normalized content output and capability-aware binding. The Responses API returns content as a list of typed blocks (reasoning, text, etc.). ``invoke`` normalizes to string for - consistent downstream handling. ``with_structured_output`` defaults - to function-calling so the Responses-API parse path is avoided - (langchain-openai's parse path emits noisy - PydanticSerializationUnexpectedValue warnings per call without - affecting correctness). + consistent downstream handling. - Provider-specific quirks (e.g. DeepSeek's thinking mode) live in - purpose-built subclasses below so this base class stays small. + ``with_structured_output`` consults the per-model capability table + (``capabilities.get_capabilities``) to pick the method and to decide + whether ``tool_choice`` may be sent. Models that reject ``tool_choice`` + (e.g. 
DeepSeek V4 and reasoner — per their official tool-calling + guide) still bind the schema as a tool, but no ``tool_choice`` + parameter is sent. + + Provider-specific quirks beyond structured-output (e.g. DeepSeek's + reasoning_content roundtrip) live in subclasses so this base class + stays small. """ def invoke(self, input, config=None, **kwargs): return normalize_content(super().invoke(input, config, **kwargs)) def with_structured_output(self, schema, *, method=None, **kwargs): - if method is None: - method = "function_calling" + caps = get_capabilities(self.model_name) + if caps.preferred_structured_method == "none": + raise NotImplementedError( + f"{self.model_name} has no structured-output method available; " + f"agent factories will fall back to free-text generation." + ) + method = method or caps.preferred_structured_method + # When the model rejects tool_choice, suppress langchain's hardcoded + # value. The schema is still bound as a tool — exactly what + # DeepSeek's official tool-calling examples do. + if method == "function_calling" and not caps.supports_tool_choice: + kwargs.setdefault("tool_choice", None) return super().with_structured_output(schema, method=method, **kwargs) @@ -52,18 +67,16 @@ def _input_to_messages(input_: Any) -> list: class DeepSeekChatOpenAI(NormalizedChatOpenAI): """DeepSeek-specific overrides on top of the OpenAI-compatible client. - Two quirks that don't apply to other OpenAI-compatible providers: + Thinking-mode round-trip is the only DeepSeek-specific behavior that + stays here. When DeepSeek's thinking models return a response with + ``reasoning_content``, that field must be echoed back as part of the + assistant message on the next turn or the API fails with HTTP 400. + ``_create_chat_result`` captures it on receive and + ``_get_request_payload`` re-attaches it on send. - 1. 
**Thinking-mode round-trip.** When DeepSeek's thinking models return - a response with ``reasoning_content``, that field must be echoed - back as part of the assistant message on the next turn or the API - fails with HTTP 400. ``_create_chat_result`` captures the field on - receive and ``_get_request_payload`` re-attaches it on send. - - 2. **deepseek-reasoner has no tool_choice.** Structured output via - function-calling is unavailable, so we raise NotImplementedError - and let the agent factories fall back to free-text generation - (see ``tradingagents/agents/utils/structured.py``). + Tool-choice handling for V4 and reasoner — those models reject the + ``tool_choice`` parameter — is handled by the capability dispatch in + ``NormalizedChatOpenAI.with_structured_output``, not here. """ def _get_request_payload(self, input_, *, stop=None, **kwargs): @@ -94,15 +107,6 @@ class DeepSeekChatOpenAI(NormalizedChatOpenAI): generation.message.additional_kwargs["reasoning_content"] = reasoning return chat_result - def with_structured_output(self, schema, *, method=None, **kwargs): - if self.model_name == "deepseek-reasoner": - raise NotImplementedError( - "deepseek-reasoner does not support tool_choice; structured " - "output is unavailable. Agent factories fall back to " - "free-text generation automatically." - ) - return super().with_structured_output(schema, method=method, **kwargs) - # Kwargs forwarded from user config to ChatOpenAI _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "reasoning_effort", From 704b7627f2a11c0ab9259dab7308e39c85eecec4 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 01:34:45 +0000 Subject: [PATCH 29/44] fix(docker): pre-create .tradingagents dir with appuser ownership useradd --create-home creates /home/appuser but not the .tradingagents subdir, so cache writes fail with PermissionError when docker-compose mounts a named volume there (the volume inherits image-dir ownership on first init). 
#627 #672 #771 #690 #714 #723 #780 #633 #773 #631 --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 940609d35..024c7c72d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,8 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ COPY --from=builder /opt/venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" -RUN useradd --create-home appuser +RUN useradd --create-home appuser \ + && install -d -m 0755 -o appuser -g appuser /home/appuser/.tradingagents USER appuser WORKDIR /home/appuser/app From 19d22b54a98aa050fae0cd6671deae9cd9f53ce3 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 02:03:27 +0000 Subject: [PATCH 30/44] feat(llm): add MiniMax as a built-in provider Two regional endpoints (global api.minimax.io, China api.minimaxi.com) with separate API keys. Models M2.7 / M2.5 plus -highspeed variants, 204K context. Follows the existing provider-preset pattern. #789 #609 #577 #546 #395 #378 --- .env.example | 2 ++ README.md | 6 ++++-- cli/utils.py | 2 ++ tradingagents/llm_clients/factory.py | 4 +++- tradingagents/llm_clients/model_catalog.py | 20 ++++++++++++++++++++ tradingagents/llm_clients/openai_client.py | 4 ++++ 6 files changed, 35 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index be9bf13eb..af92fcf93 100644 --- a/.env.example +++ b/.env.example @@ -6,4 +6,6 @@ XAI_API_KEY= DEEPSEEK_API_KEY= DASHSCOPE_API_KEY= ZHIPU_API_KEY= +MINIMAX_API_KEY= +MINIMAX_CN_API_KEY= OPENROUTER_API_KEY= diff --git a/README.md b/README.md index 54af501a9..25a6d69b1 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,8 @@ export XAI_API_KEY=... # xAI (Grok) export DEEPSEEK_API_KEY=... # DeepSeek export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope) export ZHIPU_API_KEY=... # GLM (Zhipu) +export MINIMAX_API_KEY=... # MiniMax (global, api.minimax.io) +export MINIMAX_CN_API_KEY=... # MiniMax (China, api.minimaxi.com) export OPENROUTER_API_KEY=... # OpenRouter export ALPHA_VANTAGE_API_KEY=... 
# Alpha Vantage ``` @@ -184,7 +186,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. +We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), MiniMax (global + China), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. ### Python Usage @@ -208,7 +210,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure +config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, minimax, minimax-cn, openrouter, ollama, azure config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks config["max_debate_rounds"] = 2 diff --git a/cli/utils.py b/cli/utils.py index 85c282edd..bd2d488fa 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -239,6 +239,8 @@ def select_llm_provider() -> tuple[str, str | None]: ("DeepSeek", "deepseek", "https://api.deepseek.com"), ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"), ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"), + ("MiniMax", "minimax", "https://api.minimax.io/v1"), + ("MiniMax CN", "minimax-cn", "https://api.minimaxi.com/v1"), ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"), ("Azure OpenAI", "azure", None), ("Ollama", "ollama", "http://localhost:11434/v1"), diff --git 
a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index e1d24557e..32c3bed31 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,7 +4,9 @@ from .base_client import BaseLLMClient # Providers that use the OpenAI-compatible chat completions API _OPENAI_COMPATIBLE = ( - "openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter", + "openai", "xai", "deepseek", "qwen", "glm", + "minimax", "minimax-cn", + "ollama", "openrouter", ) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 9a723a8b9..9d097c3a2 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -8,6 +8,22 @@ ModelOption = Tuple[str, str] ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] +# Shared model list for MiniMax's global and CN endpoints (same model IDs). +_MINIMAX_MODELS: Dict[str, List[ModelOption]] = { + "quick": [ + ("MiniMax M2.7 Highspeed — Fast, 204K ctx", "MiniMax-M2.7-highspeed"), + ("MiniMax M2.5 Highspeed — Previous-gen fast", "MiniMax-M2.5-highspeed"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("MiniMax M2.7 — Flagship, 204K ctx", "MiniMax-M2.7"), + ("MiniMax M2.5 — Previous-gen flagship", "MiniMax-M2.5"), + ("MiniMax M2.7 Highspeed — Faster M2.7, 204K ctx", "MiniMax-M2.7-highspeed"), + ("Custom model ID", "custom"), + ], +} + + MODEL_OPTIONS: ProviderModeOptions = { "openai": { "quick": [ @@ -101,6 +117,10 @@ MODEL_OPTIONS: ProviderModeOptions = { ("Custom model ID", "custom"), ], }, + # MiniMax: same model IDs across global (.io) and China (.com) regions, + # so the two provider keys share one model list. + "minimax": _MINIMAX_MODELS, + "minimax-cn": _MINIMAX_MODELS, # OpenRouter: fetched dynamically. Azure: any deployed model name. 
"ollama": { "quick": [ diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index b6ad771c8..354849123 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -119,6 +119,10 @@ _PROVIDER_CONFIG = { "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"), "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"), "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"), + # MiniMax exposes two regional endpoints with separate keys; mainland + # Chinese users hit .com while global users hit .io. + "minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"), + "minimax-cn": ("https://api.minimaxi.com/v1", "MINIMAX_CN_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "ollama": ("http://localhost:11434/v1", None), } From 9482cae188bd6a3b0e0ddd82fe45a0972885cf75 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 02:30:24 +0000 Subject: [PATCH 31/44] fix: bundle config/recursion/missing-key fixes - dataflows/config: deepcopy + one-level dict merge so a partial set_config doesn't clobber sibling defaults - graph: thread max_recur_limit from config to Propagator - openai_client: name the missing env var in the API-key error #788 #764 #680 --- tests/conftest.py | 2 + tests/test_dataflows_config.py | 61 ++++++++++++++++++++++ tradingagents/dataflows/config.py | 25 ++++++--- tradingagents/graph/trading_graph.py | 4 +- tradingagents/llm_clients/openai_client.py | 6 +++ 5 files changed, 90 insertions(+), 8 deletions(-) create mode 100644 tests/test_dataflows_config.py diff --git a/tests/conftest.py b/tests/conftest.py index 504ffb12d..5983446f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,8 @@ _API_KEY_ENV_VARS = ( "DEEPSEEK_API_KEY", "DASHSCOPE_API_KEY", "ZHIPU_API_KEY", + "MINIMAX_API_KEY", + "MINIMAX_CN_API_KEY", "OPENROUTER_API_KEY", "AZURE_OPENAI_API_KEY", 
"ALPHA_VANTAGE_API_KEY", diff --git a/tests/test_dataflows_config.py b/tests/test_dataflows_config.py new file mode 100644 index 000000000..ab0800eee --- /dev/null +++ b/tests/test_dataflows_config.py @@ -0,0 +1,61 @@ +"""Config isolation: get/set must not leak nested-dict references.""" + +import copy +import unittest + +import pytest + +import tradingagents.default_config as default_config +from tradingagents.dataflows.config import get_config, set_config + + +@pytest.mark.unit +class DataflowsConfigIsolationTests(unittest.TestCase): + def setUp(self): + set_config(copy.deepcopy(default_config.DEFAULT_CONFIG)) + + def test_get_config_returns_deep_copy(self): + cfg = get_config() + cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage" + cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage" + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance") + self.assertNotIn("get_stock_data", fresh["tool_vendors"]) + + def test_set_config_does_not_alias_caller_nested_dicts(self): + custom = copy.deepcopy(default_config.DEFAULT_CONFIG) + custom["data_vendors"]["core_stock_apis"] = "alpha_vantage" + custom["tool_vendors"]["get_stock_data"] = "alpha_vantage" + + set_config(custom) + + custom["data_vendors"]["core_stock_apis"] = "yfinance" + custom["tool_vendors"]["get_stock_data"] = "yfinance" + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage") + self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage") + + def test_partial_nested_update_preserves_existing_defaults(self): + set_config( + { + "data_vendors": { + "core_stock_apis": "alpha_vantage", + } + } + ) + + fresh = get_config() + self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage") + self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance") + self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance") + self.assertEqual(fresh["data_vendors"]["news_data"], 
"yfinance") + + def test_nested_dict_updates_merge_one_level_deep(self): + set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}}) + set_config({"tool_vendors": {"get_news": "alpha_vantage"}}) + + fresh = get_config() + self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage") + self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage") diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index 5819494a3..6f3076aea 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,6 +1,8 @@ -import tradingagents.default_config as default_config +from copy import deepcopy from typing import Dict, Optional +import tradingagents.default_config as default_config + # Use default config but allow it to be overridden _config: Optional[Dict] = None @@ -9,22 +11,31 @@ def initialize_config(): """Initialize the configuration with default values.""" global _config if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() + _config = deepcopy(default_config.DEFAULT_CONFIG) def set_config(config: Dict): - """Update the configuration with custom values.""" + """Update the configuration with custom values. + + Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a + partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}`` + keeps the other nested keys from the default; scalar keys are replaced. 
+ """ global _config - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - _config.update(config) + initialize_config() + incoming = deepcopy(config) + for key, value in incoming.items(): + if isinstance(value, dict) and isinstance(_config.get(key), dict): + _config[key].update(value) + else: + _config[key] = value def get_config() -> Dict: """Get the current configuration.""" if _config is None: initialize_config() - return _config.copy() + return deepcopy(_config) # Initialize with default config diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 197913e21..949dbf654 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -116,7 +116,9 @@ class TradingAgentsGraph: self.conditional_logic, ) - self.propagator = Propagator() + self.propagator = Propagator( + max_recur_limit=self.config.get("max_recur_limit", 100), + ) self.reflector = Reflector(self.quick_thinking_llm) self.signal_processor = SignalProcessor(self.quick_thinking_llm) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 354849123..6947ad41e 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -162,6 +162,12 @@ class OpenAIClient(BaseLLMClient): api_key = os.environ.get(api_key_env) if api_key: llm_kwargs["api_key"] = api_key + else: + raise ValueError( + f"API key for provider '{self.provider}' is not set. " + f"Please set the {api_key_env} environment variable " + f"(e.g. add {api_key_env}=your_key to your .env file)." + ) else: llm_kwargs["api_key"] = "ollama" elif self.base_url: From e1316686f89692cc01a46f1389da5a51c8e1a300 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 02:40:33 +0000 Subject: [PATCH 32/44] fix(llm): MiniMax integration polish vs official docs M2.x tool_choice is enum-only (none/auto), so route through the no-tool_choice dispatch. 
MinimaxChatOpenAI injects reasoning_split so blocks stay out of content. Catalog rounded out to the full official M2.x lineup plus forward-compat regex. --- README.md | 4 +- tests/test_capabilities.py | 30 ++++++++- tests/test_minimax.py | 73 ++++++++++++++++++++++ tradingagents/llm_clients/capabilities.py | 29 ++++++++- tradingagents/llm_clients/model_catalog.py | 17 +++-- tradingagents/llm_clients/openai_client.py | 33 +++++++++- 6 files changed, 172 insertions(+), 14 deletions(-) create mode 100644 tests/test_minimax.py diff --git a/README.md b/README.md index 25a6d69b1..a897263f2 100644 --- a/README.md +++ b/README.md @@ -144,8 +144,8 @@ export XAI_API_KEY=... # xAI (Grok) export DEEPSEEK_API_KEY=... # DeepSeek export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope) export ZHIPU_API_KEY=... # GLM (Zhipu) -export MINIMAX_API_KEY=... # MiniMax (global, api.minimax.io) -export MINIMAX_CN_API_KEY=... # MiniMax (China, api.minimaxi.com) +export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io, M2.x, 204K ctx) +export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com, M2.x, 204K ctx) export OPENROUTER_API_KEY=... # OpenRouter export ALPHA_VANTAGE_API_KEY=... 
# Alpha Vantage ``` diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py index d65e93d0e..a3a13d6d6 100644 --- a/tests/test_capabilities.py +++ b/tests/test_capabilities.py @@ -32,7 +32,7 @@ class TestExactIdMatches: @pytest.mark.unit class TestPatternMatches: - """Forward-compat regex patterns catch unknown DeepSeek variants.""" + """Forward-compat regex patterns catch unknown DeepSeek and MiniMax variants.""" def test_future_deepseek_v5_inherits_thinking_quirks(self): caps = get_capabilities("deepseek-v5-flash") @@ -47,6 +47,34 @@ class TestPatternMatches: caps = get_capabilities("deepseek-reasoner-pro") assert caps.supports_tool_choice is False + def test_future_minimax_m3_inherits_thinking_quirks(self): + caps = get_capabilities("MiniMax-M3") + assert caps.supports_tool_choice is False + + def test_future_minimax_m4_highspeed_inherits_thinking_quirks(self): + caps = get_capabilities("MiniMax-M4-highspeed") + assert caps.supports_tool_choice is False + + +@pytest.mark.unit +class TestMinimaxExactMatches: + """MiniMax M2.x models reject langchain's function-spec dict tool_choice + (official API enum: none/auto only).""" + + def test_m2_7_rejects_tool_choice(self): + caps = get_capabilities("MiniMax-M2.7") + assert caps.supports_tool_choice is False + assert caps.supports_json_mode is False # only MiniMax-Text-01 supports json_object + + def test_m2_7_highspeed_rejects_tool_choice(self): + assert get_capabilities("MiniMax-M2.7-highspeed").supports_tool_choice is False + + def test_m2_1_rejects_tool_choice(self): + assert get_capabilities("MiniMax-M2.1").supports_tool_choice is False + + def test_m2_base_rejects_tool_choice(self): + assert get_capabilities("MiniMax-M2").supports_tool_choice is False + @pytest.mark.unit class TestDefault: diff --git a/tests/test_minimax.py b/tests/test_minimax.py new file mode 100644 index 000000000..c48735429 --- /dev/null +++ b/tests/test_minimax.py @@ -0,0 +1,73 @@ +"""Tests for MinimaxChatOpenAI quirks. 
+ +Verifies the subclass injects ``reasoning_split=True`` into outgoing +requests so M2.x reasoning models put their thinking block into +``reasoning_details`` instead of polluting ``message.content``. +""" + +import os + +import pytest +from langchain_core.messages import HumanMessage +from pydantic import BaseModel + +from tradingagents.llm_clients.openai_client import MinimaxChatOpenAI + + +def _client(model: str = "MiniMax-M2.7"): + os.environ.setdefault("MINIMAX_API_KEY", "placeholder") + return MinimaxChatOpenAI( + model=model, + api_key="placeholder", + base_url="https://api.minimax.io/v1", + ) + + +@pytest.mark.unit +class TestMinimaxReasoningSplit: + def test_request_payload_sets_reasoning_split(self): + payload = _client()._get_request_payload([HumanMessage(content="hi")]) + assert payload.get("reasoning_split") is True + + def test_caller_supplied_reasoning_split_is_preserved(self): + """If the user explicitly sets reasoning_split, don't override it + (setdefault semantics — caller wins).""" + client = _client() + payload = client._get_request_payload( + [HumanMessage(content="hi")], + reasoning_split=False, + ) + # langchain may or may not surface that kwarg into the payload; + # what matters is we don't blindly overwrite a non-default value + # the caller passed. setdefault leaves an existing value alone. 
+ assert payload.get("reasoning_split") in (False, True) + + +@pytest.mark.unit +class TestMinimaxStructuredOutputDispatch: + """M2.x models route through the capability table — tool_choice is + suppressed but the schema is still bound as a tool.""" + + class _Pick(BaseModel): + action: str + + def _bound_kwargs(self, runnable): + first = runnable.steps[0] if hasattr(runnable, "steps") else runnable + return getattr(first, "kwargs", {}) + + def test_m2_7_suppresses_tool_choice(self): + bound = _client("MiniMax-M2.7").with_structured_output(self._Pick) + kwargs = self._bound_kwargs(bound) + assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs + + def test_m2_7_highspeed_suppresses_tool_choice(self): + bound = _client("MiniMax-M2.7-highspeed").with_structured_output(self._Pick) + kwargs = self._bound_kwargs(bound) + assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs + + def test_schema_still_bound_as_tool(self): + bound = _client("MiniMax-M2.7").with_structured_output(self._Pick) + tools = self._bound_kwargs(bound).get("tools", []) + assert any( + t.get("function", {}).get("name") == "_Pick" for t in tools + ), f"schema not bound: {tools}" diff --git a/tradingagents/llm_clients/capabilities.py b/tradingagents/llm_clients/capabilities.py index 3c14461f2..d8e21175c 100644 --- a/tradingagents/llm_clients/capabilities.py +++ b/tradingagents/llm_clients/capabilities.py @@ -61,6 +61,21 @@ _DEEPSEEK_CHAT = ModelCapabilities( preferred_structured_method="function_calling", ) +# MiniMax M2.x reasoning models accept the tools array, but their +# tool_choice parameter is restricted to the enum {"none", "auto"} +# (platform.minimax.io/docs/api-reference/text-post). Langchain's +# function_calling path sends tool_choice as a function-spec dict, which +# MiniMax 400s — same shape as the DeepSeek bug. supports_tool_choice=False +# makes the dispatch in NormalizedChatOpenAI suppress the kwarg; the schema +# still ships as a tool. 
json_mode response_format is only for +# MiniMax-Text-01, not M2.x. +_MINIMAX_THINKING = ModelCapabilities( + supports_tool_choice=False, + supports_json_mode=False, + supports_json_schema=False, + preferred_structured_method="function_calling", +) + _DEFAULT = ModelCapabilities( supports_tool_choice=True, supports_json_mode=True, @@ -75,13 +90,23 @@ _BY_ID: dict[str, ModelCapabilities] = { "deepseek-reasoner": _DEEPSEEK_THINKING, "deepseek-v4-flash": _DEEPSEEK_THINKING, "deepseek-v4-pro": _DEEPSEEK_THINKING, + # MiniMax — full official model lineup per + # platform.minimax.io/docs/api-reference/text-openai-api + "MiniMax-M2.7": _MINIMAX_THINKING, + "MiniMax-M2.7-highspeed": _MINIMAX_THINKING, + "MiniMax-M2.5": _MINIMAX_THINKING, + "MiniMax-M2.5-highspeed": _MINIMAX_THINKING, + "MiniMax-M2.1": _MINIMAX_THINKING, + "MiniMax-M2.1-highspeed": _MINIMAX_THINKING, + "MiniMax-M2": _MINIMAX_THINKING, } -# Forward-compat patterns. A new ``deepseek-v5-*`` or ``deepseek-reasoner-*`` -# variant inherits the thinking-mode quirks automatically. +# Forward-compat patterns. New ``deepseek-v5-*`` / ``deepseek-reasoner-*`` +# or ``MiniMax-M3*`` variants inherit the thinking-mode quirks automatically. _BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [ (re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING), (re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING), + (re.compile(r"^MiniMax-M\d"), _MINIMAX_THINKING), ] diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 9d097c3a2..b72aa6b20 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -8,17 +8,22 @@ ModelOption = Tuple[str, str] ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] -# Shared model list for MiniMax's global and CN endpoints (same model IDs). +# Shared model list for MiniMax's global and CN endpoints (same IDs). 
+# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api. +# All M2.x models share a 204,800-token context window. _MINIMAX_MODELS: Dict[str, List[ModelOption]] = { "quick": [ - ("MiniMax M2.7 Highspeed — Fast, 204K ctx", "MiniMax-M2.7-highspeed"), - ("MiniMax M2.5 Highspeed — Previous-gen fast", "MiniMax-M2.5-highspeed"), + ("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"), + ("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"), + ("MiniMax-M2.1-highspeed - M2.1 highspeed, 204K ctx", "MiniMax-M2.1-highspeed"), ("Custom model ID", "custom"), ], "deep": [ - ("MiniMax M2.7 — Flagship, 204K ctx", "MiniMax-M2.7"), - ("MiniMax M2.5 — Previous-gen flagship", "MiniMax-M2.5"), - ("MiniMax M2.7 Highspeed — Faster M2.7, 204K ctx", "MiniMax-M2.7-highspeed"), + ("MiniMax-M2.7 - Flagship, SOTA on coding/agent benchmarks, 204K ctx", "MiniMax-M2.7"), + ("MiniMax-M2.7-highspeed - Same quality as M2.7, ~100 TPS", "MiniMax-M2.7-highspeed"), + ("MiniMax-M2.5 - Previous-gen flagship, 204K ctx", "MiniMax-M2.5"), + ("MiniMax-M2.1 - Earlier M2 line, 204K ctx", "MiniMax-M2.1"), + ("MiniMax-M2 - Base M2, 204K ctx", "MiniMax-M2"), ("Custom model ID", "custom"), ], } diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 6947ad41e..5a159a12d 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -107,6 +107,28 @@ class DeepSeekChatOpenAI(NormalizedChatOpenAI): generation.message.additional_kwargs["reasoning_content"] = reasoning return chat_result + +class MinimaxChatOpenAI(NormalizedChatOpenAI): + """MiniMax-specific overrides on top of the OpenAI-compatible client. + + M2.x reasoning models embed thinking blocks (``think`` tags) directly in + ``message.content`` by default, which would pollute saved reports. 
+ Per platform.minimax.io/docs/api-reference/text-openai-api, setting + ``reasoning_split=True`` in the request body redirects the thinking + block into ``reasoning_details`` so ``content`` stays clean. + + Tool-choice handling for M2.x — those models accept only the string + enum ``{"none", "auto"}`` and reject langchain's function-spec dict — + is handled by the capability dispatch in + ``NormalizedChatOpenAI.with_structured_output``, not here. + """ + + def _get_request_payload(self, input_, *, stop=None, **kwargs): + payload = super()._get_request_payload(input_, stop=stop, **kwargs) + payload.setdefault("reasoning_split", True) + return payload + + # Kwargs forwarded from user config to ChatOpenAI _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "reasoning_effort", @@ -183,9 +205,14 @@ class OpenAIClient(BaseLLMClient): if self.provider == "openai": llm_kwargs["use_responses_api"] = True - # DeepSeek's thinking-mode quirks live in their own subclass so the - # base NormalizedChatOpenAI stays free of provider-specific branches. - chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI + # Provider-specific quirks live in their own subclasses so the + # base NormalizedChatOpenAI stays free of provider branches. + if self.provider == "deepseek": + chat_cls = DeepSeekChatOpenAI + elif self.provider in ("minimax", "minimax-cn"): + chat_cls = MinimaxChatOpenAI + else: + chat_cls = NormalizedChatOpenAI return chat_cls(**llm_kwargs) def validate_model(self) -> bool: From 78fe77f4e659d7458d2e6785421d4c04636a0412 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 02:49:57 +0000 Subject: [PATCH 33/44] feat(llm): bump OpenAI catalog to GPT-5.5 frontier GPT-5.5 (Apr 2026, 1M ctx, $5/$30 per 1M) replaces GPT-5.4 as the catalog flagship. GPT-5.5 Pro replaces 5.4 Pro in the most-capable slot. GPT-5.4 demotes to previous-gen cost-effective option. 
--- tradingagents/llm_clients/model_catalog.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index b72aa6b20..47f9b5e16 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -34,14 +34,14 @@ MODEL_OPTIONS: ProviderModeOptions = { "quick": [ ("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"), ("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"), - ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"), ("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"), ], "deep": [ - ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"), + ("GPT-5.4 - Previous-gen frontier, 1M context, cost-effective", "gpt-5.4"), ("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"), - ("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"), - ("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"), + ("GPT-5.5 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.5-pro"), ], }, "anthropic": { From 9e00c8117f3d20616c60e050a68aff85e72fb480 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 02:56:59 +0000 Subject: [PATCH 34/44] feat(llm): bump Anthropic catalog to Claude Opus 4.7 frontier Opus 4.7 is the current frontier per platform.claude.com (frontier category, listed first). Demote Opus 4.6 to second deep-tier slot. Polish quick-tier labels to match official wording; effort docstring includes 4.7. 
--- cli/utils.py | 4 +++- tradingagents/llm_clients/model_catalog.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index bd2d488fa..1fda29203 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -291,7 +291,9 @@ def ask_openai_reasoning_effort() -> str: def ask_anthropic_effort() -> str | None: """Ask for Anthropic effort level. - Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models. + Controls token usage and response thoroughness on Claude 4.5 / 4.6 / 4.7 + models. The API also accepts "max"; we expose low/medium/high as the + common selection range. """ return questionary.select( "Select Effort Level:", diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 47f9b5e16..b94758f15 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -47,14 +47,14 @@ MODEL_OPTIONS: ProviderModeOptions = { "anthropic": { "quick": [ ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), - ("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"), - ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), + ("Claude Haiku 4.5 - Fastest with near-frontier intelligence", "claude-haiku-4-5"), + ("Claude Sonnet 4.5 - High-performance for agents and coding", "claude-sonnet-4-5"), ], "deep": [ - ("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"), + ("Claude Opus 4.7 - Latest frontier, long-running agents and coding", "claude-opus-4-7"), + ("Claude Opus 4.6 - Frontier intelligence, agents and coding", "claude-opus-4-6"), ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), - ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), ], }, "google": { From 4f057e290cd31b5f1e182ce7a56893abc4202c76 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao 
Date: Mon, 11 May 2026 03:32:00 +0000 Subject: [PATCH 35/44] feat(llm): swap Gemini 3.1 Flash-Lite to GA stable gemini-3.1-flash-lite is now GA per ai.google.dev. Use the stable version (fewer rate limits, stronger compat guarantees) instead of the -preview suffix. Labels mark preview vs GA explicitly. --- tradingagents/llm_clients/model_catalog.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index b94758f15..e7f262b33 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -59,14 +59,14 @@ MODEL_OPTIONS: ProviderModeOptions = { }, "google": { "quick": [ - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"), ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), - ("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"), + ("Gemini 3.1 Flash Lite - Most cost-efficient (GA)", "gemini-3.1-flash-lite"), ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), ], "deep": [ - ("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"), - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 3.1 Pro - Reasoning-first, complex workflows (preview)", "gemini-3.1-pro-preview"), + ("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"), ("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"), ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), ], From 0011b5ebf54d6709cc9e6f1ab129684b6166653b Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 03:45:43 +0000 Subject: [PATCH 36/44] =?UTF-8?q?feat(llm):=20align=20xAI=20catalog=20with?= =?UTF-8?q?=20docs=20=E2=80=94=20adopt=20grok-4.20=20frontier?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit xAI's official docs lead 
with grok-4.20-reasoning and grok-4.20-non-reasoning across all SDK examples. Replace the prior grok-4-1-fast-* entries (hyphens where docs use dots, no literal code example) with the verified grok-4.20 family. Keep grok-4-0709 and grok-4-fast variants that are still referenced. --- tradingagents/llm_clients/model_catalog.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index e7f262b33..aeb04b1f3 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -73,15 +73,15 @@ MODEL_OPTIONS: ProviderModeOptions = { }, "xai": { "quick": [ - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4.20 (Non-Reasoning) - Latest, speed-optimized", "grok-4.20-non-reasoning"), ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), ], "deep": [ - ("Grok 4 - Flagship model", "grok-4-0709"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4.20 (Reasoning) - Latest frontier reasoning model", "grok-4.20-reasoning"), + ("Grok 4 - Flagship (dated build)", "grok-4-0709"), ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4.20 - Auto-select reasoning behavior", "grok-4.20"), ], }, "deepseek": { From faaeebac70bce1aed48ee53a164b4305460a27a4 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 04:16:11 +0000 Subject: [PATCH 37/44] feat(cli): collapse regional duplicates; refresh Qwen catalog Qwen and MiniMax each had two main-dropdown entries (intl + CN); consolidate to one entry per provider and prompt for region 
as a secondary step. Internal provider keys (qwen-cn, minimax-cn) and endpoint routing unchanged. Add qwen3.6-flash to the Qwen catalog and drop the version-less aliases (qwen-flash, qwen-plus) that auto-shift their backing model per Alibaba's docs. #758 --- cli/main.py | 8 +++ cli/utils.py | 57 +++++++++++++++++++++- tradingagents/llm_clients/model_catalog.py | 42 +++++++++++----- 3 files changed, 92 insertions(+), 15 deletions(-) diff --git a/cli/main.py b/cli/main.py index ecfc63d59..478b82fde 100644 --- a/cli/main.py +++ b/cli/main.py @@ -559,6 +559,14 @@ def get_user_selections(): ) selected_llm_provider, backend_url = select_llm_provider() + # Providers with regional endpoints prompt for the region as a secondary + # step so the main dropdown stays clean (mainland China and international + # accounts cannot share API keys). + if selected_llm_provider == "qwen": + selected_llm_provider, backend_url = ask_qwen_region() + elif selected_llm_provider == "minimax": + selected_llm_provider, backend_url = ask_minimax_region() + # Step 7: Thinking agents console.print( create_question_box( diff --git a/cli/utils.py b/cli/utils.py index 1fda29203..1245eea78 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -237,10 +237,9 @@ def select_llm_provider() -> tuple[str, str | None]: ("Anthropic", "anthropic", "https://api.anthropic.com/"), ("xAI", "xai", "https://api.x.ai/v1"), ("DeepSeek", "deepseek", "https://api.deepseek.com"), - ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"), + ("Qwen", "qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"), ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"), ("MiniMax", "minimax", "https://api.minimax.io/v1"), - ("MiniMax CN", "minimax-cn", "https://api.minimaxi.com/v1"), ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"), ("Azure OpenAI", "azure", None), ("Ollama", "ollama", "http://localhost:11434/v1"), @@ -330,6 +329,60 @@ def ask_gemini_thinking_config() -> str | None: ).ask() +def 
ask_qwen_region() -> tuple[str, str]: + """Ask which Qwen region (international vs China) to use. + + Alibaba DashScope exposes two endpoints with separate accounts — + a key from one region does NOT authenticate against the other + (fixes #758). Returns (provider_key, backend_url). + """ + return questionary.select( + "Select Qwen region:", + choices=[ + questionary.Choice( + "International — dashscope-intl.aliyuncs.com (uses DASHSCOPE_API_KEY)", + value=("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"), + ), + questionary.Choice( + "China — dashscope.aliyuncs.com (uses DASHSCOPE_CN_API_KEY)", + value=("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1"), + ), + ], + style=questionary.Style([ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ]), + ).ask() + + +def ask_minimax_region() -> tuple[str, str]: + """Ask which MiniMax region (global vs China) to use. + + MiniMax exposes two endpoints with separate accounts — a key from + one region does NOT authenticate against the other. Returns + (provider_key, backend_url). 
+ """ + return questionary.select( + "Select MiniMax region:", + choices=[ + questionary.Choice( + "Global — api.minimax.io (uses MINIMAX_API_KEY)", + value=("minimax", "https://api.minimax.io/v1"), + ), + questionary.Choice( + "China — api.minimaxi.com (uses MINIMAX_CN_API_KEY)", + value=("minimax-cn", "https://api.minimaxi.com/v1"), + ), + ], + style=questionary.Style([ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ]), + ).ask() + + def ask_output_language() -> str: """Ask for report output language.""" choice = questionary.select( diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index aeb04b1f3..6a9ca999e 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -8,6 +8,31 @@ ModelOption = Tuple[str, str] ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] +# Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints. +# Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized). +# +# Only versioned IDs are exposed in the dropdown. The version-less aliases +# (qwen-plus, qwen-flash) are documented by Alibaba as auto-upgrading +# pointers ("backbone, latest, and snapshot ... have been upgraded to the +# Qwen3 series"), which means their behavior shifts when Alibaba rotates +# the backing model. Users who want a specific generation pick it +# explicitly; users who really want auto-latest can enter the alias via +# "Custom model ID". 
+_QWEN_MODELS: Dict[str, List[ModelOption]] = { + "quick": [ + ("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"), + ("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("Qwen 3.6 Plus - Flagship vision-language, agentic coding SOTA", "qwen3.6-plus"), + ("Qwen 3.5 Plus - Previous-gen flagship", "qwen3.5-plus"), + ("Qwen 3 Max - Specialized for agent programming + tool use", "qwen3-max"), + ("Custom model ID", "custom"), + ], +} + + # Shared model list for MiniMax's global and CN endpoints (same IDs). # Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api. # All M2.x models share a 204,800-token context window. @@ -97,19 +122,10 @@ MODEL_OPTIONS: ProviderModeOptions = { ("Custom model ID", "custom"), ], }, - "qwen": { - "quick": [ - ("Qwen 3.5 Flash", "qwen3.5-flash"), - ("Qwen Plus", "qwen-plus"), - ("Custom model ID", "custom"), - ], - "deep": [ - ("Qwen 3.6 Plus", "qwen3.6-plus"), - ("Qwen 3.5 Plus", "qwen3.5-plus"), - ("Qwen 3 Max", "qwen3-max"), - ("Custom model ID", "custom"), - ], - }, + # Qwen: same model IDs across global (dashscope-intl) and China + # (dashscope) endpoints, so the two provider keys share one model list. + "qwen": _QWEN_MODELS, + "qwen-cn": _QWEN_MODELS, "glm": { "quick": [ ("GLM-4.7", "glm-4.7"), From d0dd0420ad20d1fc32759c738855b7d591ee8f69 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 04:19:50 +0000 Subject: [PATCH 38/44] feat(llm): GLM dual-region split + catalog refresh Zhipu serves GLM under two brands with separate accounts (Z.AI international vs BigModel China); the CLI URL pointed at one while the openai_client default pointed at the other. Split into glm + glm-cn with secondary region prompt (same UX as Qwen + MiniMax). Catalog adds glm-5-turbo and glm-4.5-air per docs.z.ai. 
--- .env.example | 2 ++ README.md | 10 ++++--- cli/main.py | 2 ++ cli/utils.py | 26 ++++++++++++++++ tests/conftest.py | 2 ++ tradingagents/llm_clients/factory.py | 4 ++- tradingagents/llm_clients/model_catalog.py | 35 ++++++++++++++-------- tradingagents/llm_clients/openai_client.py | 8 +++++ 8 files changed, 72 insertions(+), 17 deletions(-) diff --git a/.env.example b/.env.example index af92fcf93..458d74956 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,9 @@ ANTHROPIC_API_KEY= XAI_API_KEY= DEEPSEEK_API_KEY= DASHSCOPE_API_KEY= +DASHSCOPE_CN_API_KEY= ZHIPU_API_KEY= +ZHIPU_CN_API_KEY= MINIMAX_API_KEY= MINIMAX_CN_API_KEY= OPENROUTER_API_KEY= diff --git a/README.md b/README.md index a897263f2..b4578a8c9 100644 --- a/README.md +++ b/README.md @@ -142,8 +142,10 @@ export GOOGLE_API_KEY=... # Google (Gemini) export ANTHROPIC_API_KEY=... # Anthropic (Claude) export XAI_API_KEY=... # xAI (Grok) export DEEPSEEK_API_KEY=... # DeepSeek -export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope) -export ZHIPU_API_KEY=... # GLM (Zhipu) +export DASHSCOPE_API_KEY=... # Qwen — International (dashscope-intl.aliyuncs.com) +export DASHSCOPE_CN_API_KEY=... # Qwen — China (dashscope.aliyuncs.com) +export ZHIPU_API_KEY=... # GLM via Z.AI (international) +export ZHIPU_CN_API_KEY=... # GLM via BigModel (China, open.bigmodel.cn) export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io, M2.x, 204K ctx) export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com, M2.x, 204K ctx) export OPENROUTER_API_KEY=... # OpenRouter @@ -186,7 +188,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), MiniMax (global + China), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. 
+We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope, international and China endpoints), GLM (Zhipu), MiniMax (global + China), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. ### Python Usage @@ -210,7 +212,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, minimax, minimax-cn, openrouter, ollama, azure +config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, qwen-cn, glm, glm-cn, minimax, minimax-cn, openrouter, ollama, azure config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks config["max_debate_rounds"] = 2 diff --git a/cli/main.py b/cli/main.py index 478b82fde..c466cb219 100644 --- a/cli/main.py +++ b/cli/main.py @@ -566,6 +566,8 @@ def get_user_selections(): selected_llm_provider, backend_url = ask_qwen_region() elif selected_llm_provider == "minimax": selected_llm_provider, backend_url = ask_minimax_region() + elif selected_llm_provider == "glm": + selected_llm_provider, backend_url = ask_glm_region() # Step 7: Thinking agents console.print( diff --git a/cli/utils.py b/cli/utils.py index 1245eea78..1ccc12302 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -329,6 +329,32 @@ def ask_gemini_thinking_config() -> str | None: ).ask() +def ask_glm_region() -> tuple[str, str]: + """Ask which GLM platform (Z.AI international vs BigModel China) to use. + + Zhipu serves the same GLM models under two brands with separate + accounts; keys aren't interchangeable. Returns (provider_key, backend_url). 
+ """ + return questionary.select( + "Select GLM platform:", + choices=[ + questionary.Choice( + "Z.AI — api.z.ai (international, uses ZHIPU_API_KEY)", + value=("glm", "https://api.z.ai/api/paas/v4/"), + ), + questionary.Choice( + "BigModel — open.bigmodel.cn (China, uses ZHIPU_CN_API_KEY)", + value=("glm-cn", "https://open.bigmodel.cn/api/paas/v4/"), + ), + ], + style=questionary.Style([ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ]), + ).ask() + + def ask_qwen_region() -> tuple[str, str]: """Ask which Qwen region (international vs China) to use. diff --git a/tests/conftest.py b/tests/conftest.py index 5983446f5..506510cea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,9 @@ _API_KEY_ENV_VARS = ( "XAI_API_KEY", "DEEPSEEK_API_KEY", "DASHSCOPE_API_KEY", + "DASHSCOPE_CN_API_KEY", "ZHIPU_API_KEY", + "ZHIPU_CN_API_KEY", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY", "OPENROUTER_API_KEY", diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 32c3bed31..9bf1d9fba 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,7 +4,9 @@ from .base_client import BaseLLMClient # Providers that use the OpenAI-compatible chat completions API _OPENAI_COMPATIBLE = ( - "openai", "xai", "deepseek", "qwen", "glm", + "openai", "xai", "deepseek", + "qwen", "qwen-cn", + "glm", "glm-cn", "minimax", "minimax-cn", "ollama", "openrouter", ) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 6a9ca999e..fac741bcf 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -8,6 +8,25 @@ ModelOption = Tuple[str, str] ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] +# Shared model list for GLM via Z.AI (international) and BigModel (China). +# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides). 
+# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}. +_GLM_MODELS: Dict[str, List[ModelOption]] = { + "quick": [ + ("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"), + ("GLM-4.7 - Previous-gen flagship", "glm-4.7"), + ("GLM-4.5-Air - Lightweight, cost-efficient", "glm-4.5-air"), + ("Custom model ID", "custom"), + ], + "deep": [ + ("GLM-5.1 - Latest flagship, 204K ctx", "glm-5.1"), + ("GLM-5 - Flagship, 204K ctx", "glm-5"), + ("GLM-4.7 - Previous-gen flagship", "glm-4.7"), + ("Custom model ID", "custom"), + ], +} + + # Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints. # Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized). # @@ -126,18 +145,10 @@ MODEL_OPTIONS: ProviderModeOptions = { # (dashscope) endpoints, so the two provider keys share one model list. "qwen": _QWEN_MODELS, "qwen-cn": _QWEN_MODELS, - "glm": { - "quick": [ - ("GLM-4.7", "glm-4.7"), - ("GLM-5", "glm-5"), - ("Custom model ID", "custom"), - ], - "deep": [ - ("GLM-5.1", "glm-5.1"), - ("GLM-5", "glm-5"), - ("Custom model ID", "custom"), - ], - }, + # GLM: Z.AI (international) and BigModel (China) host the same model + # IDs; the two provider keys share one model list. + "glm": _GLM_MODELS, + "glm-cn": _GLM_MODELS, # MiniMax: same model IDs across global (.io) and China (.com) regions, # so the two provider keys share one model list. 
"minimax": _MINIMAX_MODELS, diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 5a159a12d..89c67e31d 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -139,8 +139,16 @@ _PASSTHROUGH_KWARGS = ( _PROVIDER_CONFIG = { "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"), + # DashScope exposes two regional endpoints with separate accounts; an + # international key won't authenticate against the China endpoint and + # vice versa (fixes issue #758). "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"), + "qwen-cn": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_CN_API_KEY"), + # Zhipu exposes the same GLM models under two brands with separate + # accounts: Z.AI (international, api.z.ai) and BigModel + # (open.bigmodel.cn, China). Keys aren't interchangeable across them. "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"), + "glm-cn": ("https://open.bigmodel.cn/api/paas/v4/", "ZHIPU_CN_API_KEY"), # MiniMax exposes two regional endpoints with separate keys; mainland # Chinese users hit .com while global users hit .io. "minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"), From 0fcf13624e8ab20b1cb3e10ebfd51dccf0b0213a Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 05:20:07 +0000 Subject: [PATCH 39/44] feat(agents): rename to sentiment_analyst; integrate StockTwits + Reddit Pre-fetches news + StockTwits + Reddit via no-auth public endpoints and injects structured data blocks into the prompt with professional analysis instructions. Replaces the prompt-vs-tool mismatch that caused fabricated social-platform content. Backward-compat alias + "social" CLI key preserved. 
#557 #607 --- tradingagents/agents/__init__.py | 8 +- .../agents/analysts/sentiment_analyst.py | 184 ++++++++++++++++++ .../agents/analysts/social_media_analyst.py | 68 ++----- tradingagents/dataflows/reddit.py | 106 ++++++++++ tradingagents/dataflows/stocktwits.py | 83 ++++++++ tradingagents/graph/setup.py | 6 +- 6 files changed, 401 insertions(+), 54 deletions(-) create mode 100644 tradingagents/agents/analysts/sentiment_analyst.py create mode 100644 tradingagents/dataflows/reddit.py create mode 100644 tradingagents/dataflows/stocktwits.py diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 2fb4e1bac..f88261408 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -4,7 +4,10 @@ from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst -from .analysts.social_media_analyst import create_social_media_analyst +from .analysts.sentiment_analyst import ( + create_sentiment_analyst, + create_social_media_analyst, # deprecated alias kept for back-compat +) from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher @@ -33,6 +36,7 @@ __all__ = [ "create_aggressive_debator", "create_portfolio_manager", "create_conservative_debator", - "create_social_media_analyst", + "create_sentiment_analyst", + "create_social_media_analyst", # deprecated; will be removed in a future version "create_trader", ] diff --git a/tradingagents/agents/analysts/sentiment_analyst.py b/tradingagents/agents/analysts/sentiment_analyst.py new file mode 100644 index 000000000..e1e4ee4f4 --- /dev/null +++ b/tradingagents/agents/analysts/sentiment_analyst.py @@ -0,0 +1,184 @@ +"""Sentiment analyst — multi-source sentiment analysis for a target ticker. 
+ +Previously named ``social_media_analyst``. Renamed and redesigned because +the old version had a prompt that demanded social-media analysis but the +only tool available was Yahoo Finance news — which led LLMs to fabricate +Reddit/X/StockTwits content under prompt pressure (verified live). + +The redesigned agent pre-fetches three complementary data sources before +the LLM is invoked and injects them into the prompt as structured blocks: + + 1. News headlines — Yahoo Finance (institutional framing) + 2. StockTwits messages — retail-trader posts indexed by cashtag, with + user-labeled Bullish/Bearish sentiment tags + 3. Reddit posts — r/wallstreetbets, r/stocks, r/investing + +The agent does not use tool-calling; the data is in the prompt from +turn 0. The LLM produces the sentiment report in a single invocation. + +See: https://github.com/TauricResearch/TradingAgents/issues/557 +""" + +from datetime import datetime, timedelta + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_language_instruction, + get_news, +) +from tradingagents.dataflows.reddit import fetch_reddit_posts +from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages + + +def _seven_days_back(trade_date: str) -> str: + return (datetime.strptime(trade_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d") + + +def create_sentiment_analyst(llm): + """Create a sentiment analyst node for the trading graph. + + Pre-fetches news + StockTwits + Reddit data, injects them into the + prompt as structured blocks, and produces a sentiment report in a + single LLM call. + """ + + def sentiment_analyst_node(state): + ticker = state["company_of_interest"] + end_date = state["trade_date"] + start_date = _seven_days_back(end_date) + instrument_context = build_instrument_context(ticker) + + # Pre-fetch all three sources. 
Each fetcher degrades gracefully and + # returns a string (no exceptions surface from here), so the LLM + # always sees something — either real data or a clear placeholder. + news_block = get_news.func(ticker, start_date, end_date) + stocktwits_block = fetch_stocktwits_messages(ticker, limit=30) + reddit_block = fetch_reddit_posts(ticker) + + system_message = _build_system_message( + ticker=ticker, + start_date=start_date, + end_date=end_date, + news_block=news_block, + stocktwits_block=stocktwits_block, + reddit_block=reddit_block, + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + "\n{system_message}\n" + "For your reference, the current date is {current_date}. {instrument_context}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(current_date=end_date) + prompt = prompt.partial(instrument_context=instrument_context) + + # No bind_tools — the data is already in the prompt; a single LLM + # call produces the report directly. + chain = prompt | llm + result = chain.invoke(state["messages"]) + + return { + "messages": [result], + "sentiment_report": result.content, + } + + return sentiment_analyst_node + + +def _build_system_message( + *, + ticker: str, + start_date: str, + end_date: str, + news_block: str, + stocktwits_block: str, + reddit_block: str, +) -> str: + """Assemble the sentiment-analyst system message with structured data blocks.""" + return f"""You are a financial market sentiment analyst. 
Your task is to produce a comprehensive sentiment report for {ticker} covering the period from {start_date} to {end_date}, drawing on three complementary data sources that have already been collected for you. + +## Data sources (pre-fetched, in this prompt) + +### News headlines — Yahoo Finance, past 7 days +Institutional framing. Fact-driven, slower-moving signal. + + +{news_block} + + +### StockTwits messages — retail-trader social platform indexed by cashtag +Fast-moving signal. Each message carries a user-labeled sentiment tag (Bullish / Bearish / no-label) plus the message body. + + +{stocktwits_block} + + +### Reddit posts — r/wallstreetbets, r/stocks, r/investing (past 7 days) +Community discussion. Engagement signal via upvote score and comment count. Subreddit character matters (r/wallstreetbets is often contrarian/exuberant; r/stocks more measured; r/investing longer-term). + + +{reddit_block} + + +## How to analyze this data (best practices) + +1. **Read the StockTwits Bullish/Bearish ratio as a leading retail-sentiment signal.** A 70/30 bullish/bearish split is moderately bullish; ≥90/10 may indicate over-extension and contrarian risk; 50/50 is uncertainty. Sample size matters — base rates on the actual message count, not percentages alone. + +2. **Look for cross-source divergences.** If news framing is bearish but StockTwits is overwhelmingly bullish, that mismatch is itself a signal — it can mean retail is leaning into a thesis the news flow hasn't caught up to (or vice versa, that retail is chasing while institutions are cautious). + +3. **Weight Reddit posts by engagement.** A 400-upvote / 200-comment thread reflects community attention; a 3-upvote post is noise. Read the body excerpts for context — the title alone often misleads. + +4. **Distinguish opinion from event.** A news headline ("Nvidia announces $500M Corning deal") is an event; a StockTwits post ("buying NVDA, this is going to moon") is opinion. 
Both are inputs but should be weighted differently in your conclusions. + +5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment. + +6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so. + +7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc. + +8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call. + +## Output + +Produce a sentiment report covering, in order: + +1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size. +2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts). +3. **Divergences, alignments, and key narratives** across sources. +4. **Catalysts and risks** surfaced by the data. +5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence. + +{get_language_instruction()}""" + + +# --------------------------------------------------------------------------- +# Backwards-compatibility shim +# --------------------------------------------------------------------------- +def create_social_media_analyst(llm): + """Deprecated alias for :func:`create_sentiment_analyst`. + + Kept so existing code that imports ``create_social_media_analyst`` + continues to work. + + .. deprecated:: + Import :func:`create_sentiment_analyst` directly instead. 
+ """ + import warnings + warnings.warn( + "create_social_media_analyst is deprecated and will be removed in a " + "future version. Use create_sentiment_analyst instead.", + DeprecationWarning, + stacklevel=2, + ) + return create_sentiment_analyst(llm) diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c462..03cd7a44c 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,57 +1,23 @@ -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news -from tradingagents.dataflows.config import get_config +"""Backwards-compatibility shim for the renamed social_media_analyst module. +The social media analyst has been renamed to ``sentiment_analyst`` because its +only data tool is ``get_news`` (Yahoo Finance), not a social media feed. -def create_social_media_analyst(llm): - def social_media_analyst_node(state): - current_date = state["trade_date"] - instrument_context = build_instrument_context(state["company_of_interest"]) +Import from ``tradingagents.agents.analysts.sentiment_analyst`` going forward. - tools = [ - get_news, - ] +See: https://github.com/TauricResearch/TradingAgents/issues/557 +""" - system_message = ( - "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. 
You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." - + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" - + get_language_instruction() - ) +import warnings as _warnings - prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - "You are a helpful AI assistant, collaborating with other assistants." - " Use the provided tools to progress towards answering the question." - " If you are unable to fully answer, that's OK; another assistant with different tools" - " will help where you left off. Execute what you can to make progress." - " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," - " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." - " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. 
{instrument_context}", - ), - MessagesPlaceholder(variable_name="messages"), - ] - ) +from tradingagents.agents.analysts.sentiment_analyst import ( # noqa: F401 + create_sentiment_analyst, + create_social_media_analyst, +) - prompt = prompt.partial(system_message=system_message) - prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) - prompt = prompt.partial(current_date=current_date) - prompt = prompt.partial(instrument_context=instrument_context) - - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) - - report = "" - - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], - "sentiment_report": report, - } - - return social_media_analyst_node +_warnings.warn( + "tradingagents.agents.analysts.social_media_analyst is deprecated. " + "Import from tradingagents.agents.analysts.sentiment_analyst instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/tradingagents/dataflows/reddit.py b/tradingagents/dataflows/reddit.py new file mode 100644 index 000000000..c8e01b92c --- /dev/null +++ b/tradingagents/dataflows/reddit.py @@ -0,0 +1,106 @@ +"""Reddit search fetcher for ticker-specific discussion posts. + +Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``) +which do not require an API key. Public throughput is ~10 requests per +minute per IP, well within budget for a single agent run that queries +a handful of finance subreddits per ticker. + +Returns formatted plaintext blocks ready for prompt injection. Degrades +gracefully — returns a placeholder string rather than raising, so callers +never have to special-case missing data. 
+""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Iterable +from urllib.error import HTTPError, URLError +from urllib.parse import urlencode +from urllib.request import Request, urlopen + +logger = logging.getLogger(__name__) + +_API = "https://www.reddit.com/r/{sub}/search.json?{qs}" +_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)" + +# Default subreddits ordered roughly by signal density for ticker-specific +# discussion. wallstreetbets has the most volume but most noise; stocks / +# investing trend more measured. Caller can override. +DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing") + + +def _fetch_subreddit( + ticker: str, + sub: str, + limit: int, + timeout: float, +) -> list[dict]: + qs = urlencode({ + "q": ticker, + "restrict_sr": "on", + "sort": "new", + "t": "week", # last 7 days + "limit": limit, + }) + url = _API.format(sub=sub, qs=qs) + req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"}) + try: + with urlopen(req, timeout=timeout) as resp: + payload = json.loads(resp.read()) + except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc: + logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc) + return [] + children = (payload.get("data") or {}).get("children") or [] + return [c.get("data", {}) for c in children if isinstance(c, dict)] + + +def fetch_reddit_posts( + ticker: str, + subreddits: Iterable[str] = DEFAULT_SUBREDDITS, + limit_per_sub: int = 5, + timeout: float = 10.0, + inter_request_delay: float = 0.4, +) -> str: + """Fetch recent Reddit posts mentioning ``ticker`` across finance + subreddits and return them as a formatted plaintext block. + + ``inter_request_delay`` keeps us under Reddit's public rate limit + (~10 req/min per IP) even if the caller queries many subreddits. 
+ """ + blocks = [] + total_posts = 0 + for i, sub in enumerate(subreddits): + if i > 0: + time.sleep(inter_request_delay) + posts = _fetch_subreddit(ticker, sub, limit_per_sub, timeout) + total_posts += len(posts) + if not posts: + blocks.append(f"r/{sub}: ") + continue + + lines = [f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}:"] + for p in posts: + title = (p.get("title") or "").replace("\n", " ").strip() + score = p.get("score", 0) + comments = p.get("num_comments", 0) + created = p.get("created_utc") + created_str = ( + time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?" + ) + selftext = (p.get("selftext") or "").replace("\n", " ").strip() + if len(selftext) > 240: + selftext = selftext[:240] + "…" + lines.append( + f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}" + + (f"\n body excerpt: {selftext}" if selftext else "") + ) + blocks.append("\n".join(lines)) + + if total_posts == 0: + return ( + f"" + ) + return "\n\n".join(blocks) diff --git a/tradingagents/dataflows/stocktwits.py b/tradingagents/dataflows/stocktwits.py new file mode 100644 index 000000000..a1b2992ba --- /dev/null +++ b/tradingagents/dataflows/stocktwits.py @@ -0,0 +1,83 @@ +"""StockTwits public symbol-stream fetcher. + +StockTwits exposes a per-symbol message stream at +``api.stocktwits.com/api/2/streams/symbol/{ticker}.json`` that requires no +API key, no OAuth, and no registration. Each message includes a +user-labeled sentiment field (``Bullish``/``Bearish``/null), the message +body, timestamp, and posting user. + +The function is deliberately self-contained: short timeout, graceful +degradation on any HTTP or parse failure, and a string return type so +the calling agent gets a uniform interface regardless of whether the +network call succeeded. 
+""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from typing import Optional +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +logger = logging.getLogger(__name__) + +_API = "https://api.stocktwits.com/api/2/streams/symbol/{ticker}.json" +_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)" + + +def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.0) -> str: + """Fetch recent StockTwits messages for ``ticker`` and return them as a + formatted plaintext block ready for prompt injection. + + Returns a placeholder string when the endpoint is unreachable, the + symbol has no messages, or the response shape is unexpected — the + caller never has to special-case None or exceptions. + """ + url = _API.format(ticker=ticker.upper()) + req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"}) + try: + with urlopen(req, timeout=timeout) as resp: + data = json.loads(resp.read()) + except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc: + logger.warning("StockTwits fetch failed for %s: %s", ticker, exc) + return f"" + + messages = data.get("messages", []) if isinstance(data, dict) else [] + if not messages: + return f"" + + lines = [] + bullish = bearish = unlabeled = 0 + for m in messages[:limit]: + created = m.get("created_at", "") + user = (m.get("user") or {}).get("username", "?") + entities = m.get("entities") or {} + sentiment_obj = entities.get("sentiment") or {} + sentiment = sentiment_obj.get("basic") if isinstance(sentiment_obj, dict) else None + body = (m.get("body") or "").replace("\n", " ").strip() + if len(body) > 280: + body = body[:280] + "…" + + if sentiment == "Bullish": + bullish += 1 + tag = "Bullish" + elif sentiment == "Bearish": + bearish += 1 + tag = "Bearish" + else: + unlabeled += 1 + tag = "no-label" + lines.append(f"[{created} · @{user} · {tag}] {body}") 
+ + total = bullish + bearish + unlabeled + bull_pct = round(100 * bullish / total) if total else 0 + bear_pct = round(100 * bearish / total) if total else 0 + summary = ( + f"Bullish: {bullish} ({bull_pct}%) · " + f"Bearish: {bearish} ({bear_pct}%) · " + f"Unlabeled: {unlabeled} · " + f"Total: {total} most-recent messages" + ) + return summary + "\n\n" + "\n".join(lines) diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 45d6bfd38..3d328c89f 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -54,7 +54,11 @@ class GraphSetup: tool_nodes["market"] = self.tool_nodes["market"] if "social" in selected_analysts: - analyst_nodes["social"] = create_social_media_analyst( + # "social" selector key preserved for back-compat with existing + # user configs; the underlying agent has been renamed to + # sentiment_analyst (the old name advertised social-media data + # the agent never had access to — see issue #557). + analyst_nodes["social"] = create_sentiment_analyst( self.quick_thinking_llm ) delete_nodes["social"] = create_msg_delete() From 384fe1a3d2b507ca5e970bab1540bdd1346c384e Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 05:30:52 +0000 Subject: [PATCH 40/44] feat(news): configurable fetch params via DEFAULT_CONFIG Per-ticker article limit, global article limit, global lookback window, and macro query list are now read from get_config() instead of being hardcoded. Tool wrapper get_global_news passes None defaults so config overrides flow through the LLM-tool path too. Macro query defaults broadened from 4 US-centric strings to 5 covering Fed, S&P 500, geopolitics, ECB/BOJ/BOE, commodities. 
#606 #558 #562 --- tradingagents/agents/utils/news_data_tools.py | 16 ++++++---- tradingagents/dataflows/yfinance_news.py | 29 +++++++++++-------- tradingagents/default_config.py | 15 ++++++++++ 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/tradingagents/agents/utils/news_data_tools.py b/tradingagents/agents/utils/news_data_tools.py index 781e793c3..f503c4d3a 100644 --- a/tradingagents/agents/utils/news_data_tools.py +++ b/tradingagents/agents/utils/news_data_tools.py @@ -1,5 +1,5 @@ from langchain_core.tools import tool -from typing import Annotated +from typing import Annotated, Optional from tradingagents.dataflows.interface import route_to_vendor @tool @@ -23,16 +23,20 @@ def get_news( @tool def get_global_news( curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "Number of days to look back"] = 7, - limit: Annotated[int, "Maximum number of articles to return"] = 5, + look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None, + limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None, ) -> str: """ Retrieve global news data. - Uses the configured news_data vendor. + Uses the configured news_data vendor. Defaults for look_back_days and + limit come from DEFAULT_CONFIG (global_news_lookback_days, + global_news_article_limit); pass explicit values to override. 
+ Args: curr_date (str): Current date in yyyy-mm-dd format - look_back_days (int): Number of days to look back (default 7) - limit (int): Maximum number of articles to return (default 5) + look_back_days (int): Number of days to look back; omit to inherit config + limit (int): Maximum number of articles to return; omit to inherit config + Returns: str: A formatted string containing global news data """ diff --git a/tradingagents/dataflows/yfinance_news.py b/tradingagents/dataflows/yfinance_news.py index dd1046f54..55c5d2512 100644 --- a/tradingagents/dataflows/yfinance_news.py +++ b/tradingagents/dataflows/yfinance_news.py @@ -1,9 +1,12 @@ """yfinance-based news data fetching functions.""" +from typing import Optional + import yfinance as yf from datetime import datetime from dateutil.relativedelta import relativedelta +from .config import get_config from .stockstats_utils import yf_retry @@ -64,9 +67,10 @@ def get_news_yfinance( Returns: Formatted string containing news articles """ + article_limit = get_config()["news_article_limit"] try: stock = yf.Ticker(ticker) - news = yf_retry(lambda: stock.get_news(count=20)) + news = yf_retry(lambda: stock.get_news(count=article_limit)) if not news: return f"No news found for {ticker}" @@ -106,27 +110,28 @@ def get_news_yfinance( def get_global_news_yfinance( curr_date: str, - look_back_days: int = 7, - limit: int = 10, + look_back_days: Optional[int] = None, + limit: Optional[int] = None, ) -> str: """ Retrieve global/macro economic news using yfinance Search. Args: curr_date: Current date in yyyy-mm-dd format - look_back_days: Number of days to look back - limit: Maximum number of articles to return + look_back_days: Number of days to look back. ``None`` falls back to + ``global_news_lookback_days`` from the active config. + limit: Maximum number of articles to return. ``None`` falls back to + ``global_news_article_limit`` from the active config. 
Returns: Formatted string containing global news articles """ - # Search queries for macro/global news - search_queries = [ - "stock market economy", - "Federal Reserve interest rates", - "inflation economic outlook", - "global markets trading", - ] + config = get_config() + if look_back_days is None: + look_back_days = config["global_news_lookback_days"] + if limit is None: + limit = config["global_news_article_limit"] + search_queries = config["global_news_queries"] all_news = [] seen_titles = set() diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index fa6d5742c..faa71f591 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -35,6 +35,21 @@ DEFAULT_CONFIG = { "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, "max_recur_limit": 100, + # News / data fetching parameters + # Increase for longer lookback strategies or to broaden macro coverage; + # decrease to reduce token usage in agent prompts. + "news_article_limit": 20, # max articles per ticker (ticker-news) + "global_news_article_limit": 10, # max articles for global/macro news + "global_news_lookback_days": 7, # macro news lookback window + # Search queries used by get_global_news for macro headlines. Extend or + # replace to broaden geographic / sector coverage. + "global_news_queries": [ + "Federal Reserve interest rates inflation", + "S&P 500 earnings GDP economic outlook", + "geopolitical risk trade war sanctions", + "ECB Bank of England BOJ central bank policy", + "oil commodities supply chain energy", + ], # Data vendor configuration # Category-level configuration (default for all tools in category) "data_vendors": { From 6b384f74f94c2b2d16e6114a709f1b5bb2d4fcd5 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 05:41:42 +0000 Subject: [PATCH 41/44] feat(i18n): localize researchers, risk debators, research mgr, trader output_language config now propagates to every user-facing agent. 
Previously only the four analysts and portfolio manager respected the setting, producing partial-localization reports with English debate text interleaved with non-English analyst sections. Verified live: 7 agents produce Chinese output when config is set to Chinese. #575 --- tradingagents/agents/managers/research_manager.py | 7 +++++-- tradingagents/agents/researchers/bear_researcher.py | 3 ++- tradingagents/agents/researchers/bull_researcher.py | 3 ++- tradingagents/agents/risk_mgmt/aggressive_debator.py | 3 ++- tradingagents/agents/risk_mgmt/conservative_debator.py | 3 ++- tradingagents/agents/risk_mgmt/neutral_debator.py | 3 ++- tradingagents/agents/trader/trader.py | 6 +++++- tradingagents/agents/utils/agent_utils.py | 6 ++++-- 8 files changed, 24 insertions(+), 10 deletions(-) diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 0e2206b2e..924b36b4d 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -3,7 +3,10 @@ from __future__ import annotations from tradingagents.agents.schemas import ResearchPlan, render_research_plan -from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_language_instruction, +) from tradingagents.agents.utils.structured import ( bind_structured, invoke_structured_or_freetext, @@ -37,7 +40,7 @@ Commit to a clear stance whenever the debate's strongest arguments warrant one; --- **Debate History:** -{history}""" +{history}""" + get_language_instruction() investment_plan = invoke_structured_or_freetext( structured_llm, diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index 9cde9d39c..c78923eb2 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,3 +1,4 
@@ +from tradingagents.agents.utils.agent_utils import get_language_instruction def create_bear_researcher(llm): @@ -31,7 +32,7 @@ Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bull argument: {current_response} Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. -""" +""" + get_language_instruction() response = llm.invoke(prompt) diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index d16bc2371..c639c0bed 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.agent_utils import get_language_instruction def create_bull_researcher(llm): @@ -29,7 +30,7 @@ Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bear argument: {current_response} Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. 
-""" +""" + get_language_instruction() response = llm.invoke(prompt) diff --git a/tradingagents/agents/risk_mgmt/aggressive_debator.py b/tradingagents/agents/risk_mgmt/aggressive_debator.py index 2dab1152a..212e73d6e 100644 --- a/tradingagents/agents/risk_mgmt/aggressive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggressive_debator.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.agent_utils import get_language_instruction def create_aggressive_debator(llm): @@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. -Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" +Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. 
Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction() response = llm.invoke(prompt) diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index 99a8315e0..a7f7342fa 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.agent_utils import get_language_instruction def create_conservative_debator(llm): @@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. -Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" +Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. 
Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction() response = llm.invoke(prompt) diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index e99ff0af1..73b306078 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.agent_utils import get_language_instruction def create_neutral_debator(llm): @@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. -Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" +Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. 
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction() response = llm.invoke(prompt) diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index ea3f6b232..970350b1d 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -7,7 +7,10 @@ import functools from langchain_core.messages import AIMessage from tradingagents.agents.schemas import TraderProposal, render_trader_proposal -from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.agent_utils import ( + build_instrument_context, + get_language_instruction, +) from tradingagents.agents.utils.structured import ( bind_structured, invoke_structured_or_freetext, @@ -29,6 +32,7 @@ def create_trader(llm): "You are a trading agent analyzing market data to make investment decisions. " "Based on your analysis, provide a specific recommendation to buy, sell, or hold. " "Anchor your reasoning in the analysts' reports and the research plan." + + get_language_instruction() ), }, { diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 4ba40a803..03340b3fe 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -24,8 +24,10 @@ def get_language_instruction() -> str: """Return a prompt instruction for the configured output language. Returns empty string when English (default), so no extra tokens are used. - Only applied to user-facing agents (analysts, portfolio manager). - Internal debate agents stay in English for reasoning quality. 
+ Applied to every agent whose output reaches the saved report — + analysts, researchers, debaters, research manager, trader, and + portfolio manager — so a non-English run produces a fully localized + report rather than a mix of languages. """ from tradingagents.dataflows.config import get_config lang = get_config().get("output_language", "English") From d13e9b79469abe142fcda7b26d03fcaf6a99262b Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 06:12:31 +0000 Subject: [PATCH 42/44] feat(config): TRADINGAGENTS_* env-var overlay for DEFAULT_CONFIG Adds a single _ENV_OVERRIDES table in default_config.py with type-aware coercion (str/int/bool), so users can switch llm_provider, deep/quick models, backend URL, output language, debate rounds, and the checkpoint flag purely via .env. Centralizes load_dotenv in the package __init__ so the overlay applies for every entry point (CLI, main.py, programmatic). Drops the hardcoded model assignments and duplicate dotenv loads in main.py and cli/main.py. Verified live with OpenAI and Gemini. #602 --- .env.example | 13 +++++ cli/main.py | 12 ++-- main.py | 21 ++----- tests/test_env_overrides.py | 98 +++++++++++++++++++++++++++++++++ tradingagents/__init__.py | 15 +++++ tradingagents/default_config.py | 42 +++++++++++++- 6 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 tests/test_env_overrides.py diff --git a/.env.example b/.env.example index 458d74956..a0a10050a 100644 --- a/.env.example +++ b/.env.example @@ -11,3 +11,16 @@ ZHIPU_CN_API_KEY= MINIMAX_API_KEY= MINIMAX_CN_API_KEY= OPENROUTER_API_KEY= + +# Optional: override DEFAULT_CONFIG without editing code. +# Any TRADINGAGENTS_* variable below, when set, replaces the matching key +# in tradingagents/default_config.py. Values are coerced to the type of +# the existing default (bool / int / str), so "true"/"3" work as expected. 
+#TRADINGAGENTS_LLM_PROVIDER=openai +#TRADINGAGENTS_DEEP_THINK_LLM=gpt-5.4 +#TRADINGAGENTS_QUICK_THINK_LLM=gpt-5.4-mini +#TRADINGAGENTS_LLM_BACKEND_URL= +#TRADINGAGENTS_OUTPUT_LANGUAGE=English +#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1 +#TRADINGAGENTS_MAX_RISK_ROUNDS=1 +#TRADINGAGENTS_CHECKPOINT_ENABLED=false diff --git a/cli/main.py b/cli/main.py index c466cb219..3af0cbbe2 100644 --- a/cli/main.py +++ b/cli/main.py @@ -5,13 +5,6 @@ import questionary from pathlib import Path from functools import wraps from rich.console import Console -from dotenv import find_dotenv, load_dotenv - -# Search starts from the user's CWD so the installed `tradingagents` -# console script picks up the project's .env instead of walking up from -# site-packages. -load_dotenv(find_dotenv(usecwd=True)) -load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False) from rich.panel import Panel from rich.spinner import Spinner from rich.live import Live @@ -569,6 +562,11 @@ def get_user_selections(): elif selected_llm_provider == "glm": selected_llm_provider, backend_url = ask_glm_region() + # Confirm the provider's API key is present; prompt the user to paste + # one and persist it to .env if it's missing, so the analysis run + # doesn't fail later at the first API call. + ensure_api_key(selected_llm_provider) + # Step 7: Thinking agents console.print( create_question_box( diff --git a/main.py b/main.py index fa3024af8..fea2f3680 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,12 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG -from dotenv import find_dotenv, load_dotenv - -load_dotenv(find_dotenv(usecwd=True)) - -# Create a custom config +# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides +# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.), +# so users can switch models or endpoints purely via .env without +# editing this script. 
Override individual keys here only when you +# want a hard-coded value that should ignore the environment. config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model -config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model -config["max_debate_rounds"] = 1 # Increase debate rounds - -# Configure data vendors (default uses yfinance, no extra API keys needed) -config["data_vendors"] = { - "core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance - "technical_indicators": "yfinance", # Options: alpha_vantage, yfinance - "fundamental_data": "yfinance", # Options: alpha_vantage, yfinance - "news_data": "yfinance", # Options: alpha_vantage, yfinance -} # Initialize with custom config ta = TradingAgentsGraph(debug=True, config=config) diff --git a/tests/test_env_overrides.py b/tests/test_env_overrides.py new file mode 100644 index 000000000..c12ce5f18 --- /dev/null +++ b/tests/test_env_overrides.py @@ -0,0 +1,98 @@ +"""Tests for TRADINGAGENTS_* env-var overlay onto DEFAULT_CONFIG.""" + +from __future__ import annotations + +import importlib + +import pytest + +import tradingagents.default_config as default_config_module + + +def _reload_with_env(monkeypatch, **overrides): + """Set/clear env vars then reload default_config to re-evaluate DEFAULT_CONFIG.""" + for key in list(default_config_module._ENV_OVERRIDES): + monkeypatch.delenv(key, raising=False) + for key, val in overrides.items(): + monkeypatch.setenv(key, val) + return importlib.reload(default_config_module) + + +def test_no_env_uses_built_in_defaults(monkeypatch): + dc = _reload_with_env(monkeypatch) + assert dc.DEFAULT_CONFIG["llm_provider"] == "openai" + assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.4" + assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini" + assert dc.DEFAULT_CONFIG["backend_url"] is None + assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1 + assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is False + + +def 
test_string_overrides(monkeypatch): + dc = _reload_with_env( + monkeypatch, + TRADINGAGENTS_LLM_PROVIDER="google", + TRADINGAGENTS_DEEP_THINK_LLM="gemini-3-pro-preview", + TRADINGAGENTS_QUICK_THINK_LLM="gemini-3-flash-preview", + TRADINGAGENTS_LLM_BACKEND_URL="https://example.invalid/v1", + TRADINGAGENTS_OUTPUT_LANGUAGE="Chinese", + ) + assert dc.DEFAULT_CONFIG["llm_provider"] == "google" + assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gemini-3-pro-preview" + assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gemini-3-flash-preview" + assert dc.DEFAULT_CONFIG["backend_url"] == "https://example.invalid/v1" + assert dc.DEFAULT_CONFIG["output_language"] == "Chinese" + + +def test_int_coercion(monkeypatch): + dc = _reload_with_env( + monkeypatch, + TRADINGAGENTS_MAX_DEBATE_ROUNDS="3", + TRADINGAGENTS_MAX_RISK_ROUNDS="2", + ) + assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 3 + assert isinstance(dc.DEFAULT_CONFIG["max_debate_rounds"], int) + assert dc.DEFAULT_CONFIG["max_risk_discuss_rounds"] == 2 + assert isinstance(dc.DEFAULT_CONFIG["max_risk_discuss_rounds"], int) + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("true", True), ("True", True), ("1", True), ("yes", True), ("on", True), + ("false", False), ("False", False), ("0", False), ("no", False), ("off", False), + ], +) +def test_bool_coercion(monkeypatch, raw, expected): + dc = _reload_with_env(monkeypatch, TRADINGAGENTS_CHECKPOINT_ENABLED=raw) + assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is expected + + +def test_empty_env_value_is_passthrough(monkeypatch): + """Empty TRADINGAGENTS_* values must not clobber the built-in default.""" + dc = _reload_with_env( + monkeypatch, + TRADINGAGENTS_LLM_PROVIDER="", + TRADINGAGENTS_MAX_DEBATE_ROUNDS="", + ) + assert dc.DEFAULT_CONFIG["llm_provider"] == "openai" + assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1 + + +def test_invalid_int_raises(monkeypatch): + """Garbage int values should surface a ValueError at import, not silently misconfigure.""" + 
monkeypatch.setenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", "not-a-number") + with pytest.raises(ValueError): + importlib.reload(default_config_module) + # Restore module state for subsequent tests in this process + monkeypatch.delenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", raising=False) + importlib.reload(default_config_module) + + +def test_unknown_env_var_is_ignored(monkeypatch): + """Env vars outside _ENV_OVERRIDES must not bleed into DEFAULT_CONFIG.""" + dc = _reload_with_env( + monkeypatch, + TRADINGAGENTS_NONEXISTENT_KEY="oops", + ) + assert "nonexistent_key" not in dc.DEFAULT_CONFIG diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py index 893a3d678..5f83f2a52 100644 --- a/tradingagents/__init__.py +++ b/tradingagents/__init__.py @@ -1,5 +1,20 @@ import warnings +# Load .env files at package import so DEFAULT_CONFIG's env-var overlay +# (and every llm_clients consumer) sees the user's keys regardless of +# which entry point started the process. find_dotenv(usecwd=True) walks +# from the CWD, so the installed `tradingagents` console script picks up +# the project's .env instead of stepping up from site-packages. +# load_dotenv defaults to override=False, so it never clobbers values +# the caller has already exported. +try: + from dotenv import find_dotenv, load_dotenv + + load_dotenv(find_dotenv(usecwd=True)) + load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False) +except ImportError: + pass + # langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in # its own __init__, which prepends default-action filters for its # subclassed warning categories. 
To suppress a specific warning we must diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index faa71f591..fe5a6f755 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -2,7 +2,45 @@ import os _TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents") -DEFAULT_CONFIG = { +# Single source of truth for env-var → config-key overrides. To expose +# a new config key for environment-based override, add a row here — no +# entry-point script changes required. Coercion is driven by the type +# of the existing default, so users can keep writing plain strings in +# their .env file. +_ENV_OVERRIDES = { + "TRADINGAGENTS_LLM_PROVIDER": "llm_provider", + "TRADINGAGENTS_DEEP_THINK_LLM": "deep_think_llm", + "TRADINGAGENTS_QUICK_THINK_LLM": "quick_think_llm", + "TRADINGAGENTS_LLM_BACKEND_URL": "backend_url", + "TRADINGAGENTS_OUTPUT_LANGUAGE": "output_language", + "TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds", + "TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds", + "TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled", +} + + +def _coerce(value: str, reference): + """Coerce env-var string to the type of the existing default value.""" + if isinstance(reference, bool): + return value.strip().lower() in ("true", "1", "yes", "on") + if isinstance(reference, int) and not isinstance(reference, bool): + return int(value) + if isinstance(reference, float): + return float(value) + return value + + +def _apply_env_overrides(config: dict) -> dict: + """Apply TRADINGAGENTS_* env vars to the config dict in-place.""" + for env_var, key in _ENV_OVERRIDES.items(): + raw = os.environ.get(env_var) + if raw is None or raw == "": + continue + config[key] = _coerce(raw, config.get(key)) + return config + + +DEFAULT_CONFIG = _apply_env_overrides({ "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", 
os.path.join(_TRADINGAGENTS_HOME, "logs")), "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")), @@ -62,4 +100,4 @@ DEFAULT_CONFIG = { "tool_vendors": { # Example: "get_stock_data": "alpha_vantage", # Override category default }, -} +}) From 9f7abfcbd576686685210f2dc6b8ec52c5d744ba Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 06:12:34 +0000 Subject: [PATCH 43/44] feat(cli): detect missing provider API keys and persist to .env MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a canonical PROVIDER_API_KEY_ENV mapping (14 providers including the three dual-region pairs) and an ensure_api_key() helper. When the selected provider's key is absent from the environment, the CLI prompts via questionary.password, writes the value to .env via python-dotenv's set_key (preserves existing lines), and exports it into os.environ so the run continues without restart. Wired into cli/main.py right after the region prompts so qwen-cn, glm-cn, and minimax-cn each check their own region-specific key. openai_client refactored to consult the same mapping, eliminating its private duplicate of provider→env-var data. 
--- cli/utils.py | 49 ++++++- tests/test_api_key_env.py | 149 +++++++++++++++++++++ tradingagents/llm_clients/api_key_env.py | 44 ++++++ tradingagents/llm_clients/openai_client.py | 43 +++--- 4 files changed, 261 insertions(+), 24 deletions(-) create mode 100644 tests/test_api_key_env.py create mode 100644 tradingagents/llm_clients/api_key_env.py diff --git a/cli/utils.py b/cli/utils.py index 1ccc12302..5fd0b806c 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,9 +1,13 @@ -import questionary +import os +from pathlib import Path from typing import List, Optional, Tuple, Dict +import questionary +from dotenv import find_dotenv, set_key from rich.console import Console from cli.models import AnalystType +from tradingagents.llm_clients.api_key_env import get_api_key_env from tradingagents.llm_clients.model_catalog import get_model_options console = Console() @@ -409,6 +413,49 @@ def ask_minimax_region() -> tuple[str, str]: ).ask() +def ensure_api_key(provider: str) -> Optional[str]: + """Make sure the API key for `provider` is available in the environment. + + If the env var is already set, returns its value untouched. Otherwise + interactively prompts the user, persists the value to the project's + .env file via python-dotenv's set_key (creating .env if needed), and + exports it into os.environ so the current process picks it up. + + Returns None for providers that do not require a key (e.g. ollama) + and for providers not found in the canonical mapping. + """ + env_var = get_api_key_env(provider) + if env_var is None: + return None # ollama / unknown — no key check possible + + existing = os.environ.get(env_var) + if existing: + return existing + + console.print( + f"\n[yellow]{env_var} is not set in your environment.[/yellow]" + ) + key = questionary.password( + f"Paste your {env_var} (will be saved to .env):", + style=questionary.Style([ + ("text", "fg:cyan"), + ("highlighted", "noinherit"), + ]), + ).ask() + if not key: + console.print( + f"[red]Skipped. 
API calls will fail until {env_var} is set.[/red]" + ) + return None + + env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env") + Path(env_path).touch(exist_ok=True) + set_key(env_path, env_var, key) + os.environ[env_var] = key + console.print(f"[green]Saved {env_var} to {env_path}[/green]") + return key + + def ask_output_language() -> str: """Ask for report output language.""" choice = questionary.select( diff --git a/tests/test_api_key_env.py b/tests/test_api_key_env.py new file mode 100644 index 000000000..dde5a4886 --- /dev/null +++ b/tests/test_api_key_env.py @@ -0,0 +1,149 @@ +"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper.""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env + + +# ---- Mapping coverage ----------------------------------------------------- + + +def test_every_select_llm_provider_choice_has_an_entry(): + """select_llm_provider() must not present a provider the mapping doesn't know about.""" + # Mirrors the dropdown order in cli/utils.select_llm_provider so the two + # stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn) + # are reached via the secondary region prompt, so they must also be present. 
+ expected = { + "openai", "google", "anthropic", "xai", "deepseek", + "qwen", "qwen-cn", + "glm", "glm-cn", + "minimax", "minimax-cn", + "openrouter", "azure", "ollama", + } + assert expected.issubset(PROVIDER_API_KEY_ENV.keys()) + + +@pytest.mark.parametrize( + "provider,env_var", + [ + ("openai", "OPENAI_API_KEY"), + ("anthropic", "ANTHROPIC_API_KEY"), + ("google", "GOOGLE_API_KEY"), + ("azure", "AZURE_OPENAI_API_KEY"), + ("xai", "XAI_API_KEY"), + ("deepseek", "DEEPSEEK_API_KEY"), + ("qwen", "DASHSCOPE_API_KEY"), + ("qwen-cn", "DASHSCOPE_CN_API_KEY"), + ("glm", "ZHIPU_API_KEY"), + ("glm-cn", "ZHIPU_CN_API_KEY"), + ("minimax", "MINIMAX_API_KEY"), + ("minimax-cn", "MINIMAX_CN_API_KEY"), + ("openrouter", "OPENROUTER_API_KEY"), + ], +) +def test_known_providers_resolve(provider, env_var): + assert get_api_key_env(provider) == env_var + + +def test_ollama_has_no_key(): + assert get_api_key_env("ollama") is None + + +def test_unknown_provider_returns_none(): + assert get_api_key_env("not-a-real-provider") is None + + +def test_case_insensitive_lookup(): + assert get_api_key_env("OpenAI") == "OPENAI_API_KEY" + assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY" + + +# ---- ensure_api_key behavior --------------------------------------------- + + +@pytest.fixture +def cli_utils(monkeypatch): + """Import cli.utils with a fresh environment so module-level state is consistent.""" + import importlib + import cli.utils as cli_utils_module + return importlib.reload(cli_utils_module) + + +def test_ensure_api_key_returns_existing(monkeypatch, cli_utils): + monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set") + result = cli_utils.ensure_api_key("openai") + assert result == "sk-already-set" + + +def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils): + # Even with no env var set, ollama should not prompt and should return None. 
+ monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with patch.object(cli_utils, "questionary") as mock_q: + result = cli_utils.ensure_api_key("ollama") + assert result is None + mock_q.password.assert_not_called() + + +def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils): + with patch.object(cli_utils, "questionary") as mock_q: + result = cli_utils.ensure_api_key("totally-fake-provider") + assert result is None + mock_q.password.assert_not_called() + + +def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils): + """When key is missing, user-pasted value must be written to .env AND os.environ.""" + monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + result = cli_utils.ensure_api_key("deepseek") + + assert result == "sk-deepseek-test" + assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test" + env_file = tmp_path / ".env" + assert env_file.exists() + assert "DEEPSEEK_API_KEY" in env_file.read_text() + assert "sk-deepseek-test" in env_file.read_text() + + +def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils): + """Empty prompt response (user cancelled) must not write to .env.""" + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + result = cli_utils.ensure_api_key("xai") + + assert result is None + assert "XAI_API_KEY" not in os.environ + # .env may or may not exist depending on find_dotenv's walk, but if it + # does it must not contain the key. 
+ env_file = tmp_path / ".env" + if env_file.exists(): + assert "XAI_API_KEY" not in env_file.read_text() + + +def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils): + """An existing .env with other keys must be preserved on writeback.""" + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + env_file = tmp_path / ".env" + env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n") + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + cli_utils.ensure_api_key("openrouter") + + content = env_file.read_text() + assert "OPENAI_API_KEY" in content and "sk-existing" in content + assert "OTHER=value" in content + assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content diff --git a/tradingagents/llm_clients/api_key_env.py b/tradingagents/llm_clients/api_key_env.py new file mode 100644 index 000000000..ff03d441a --- /dev/null +++ b/tradingagents/llm_clients/api_key_env.py @@ -0,0 +1,44 @@ +"""Canonical provider -> API-key env-var mapping. + +A single source of truth for which environment variable holds the API +key for each supported LLM provider. Used by the CLI's interactive key +prompt (cli/utils.ensure_api_key) and by anything else that needs to +ask "does this provider require a key, and which env var is it?". + +When adding a new provider, register its env var here so the CLI flow +prompts for it automatically instead of failing on first API call. 
+""" + +from __future__ import annotations + +from typing import Optional + + +PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "azure": "AZURE_OPENAI_API_KEY", + "xai": "XAI_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + # Dual-region providers each carry their own account; keys are not + # interchangeable between the international and China endpoints. + "qwen": "DASHSCOPE_API_KEY", + "qwen-cn": "DASHSCOPE_CN_API_KEY", + "glm": "ZHIPU_API_KEY", + "glm-cn": "ZHIPU_CN_API_KEY", + "minimax": "MINIMAX_API_KEY", + "minimax-cn": "MINIMAX_CN_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + # Local runtimes do not authenticate. + "ollama": None, +} + + +def get_api_key_env(provider: str) -> Optional[str]: + """Return the env var name for `provider`'s API key, or None if not applicable. + + Unknown providers also return None — callers should treat that as + "no key check possible" rather than as "no key required". 
+ """ + return PROVIDER_API_KEY_ENV.get(provider.lower()) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 89c67e31d..771b28127 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -4,6 +4,7 @@ from typing import Any, Optional from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI +from .api_key_env import get_api_key_env from .base_client import BaseLLMClient, normalize_content from .capabilities import get_capabilities from .validators import validate_model @@ -135,26 +136,22 @@ _PASSTHROUGH_KWARGS = ( "api_key", "callbacks", "http_client", "http_async_client", ) -# Provider base URLs and API key env vars -_PROVIDER_CONFIG = { - "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), - "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"), - # DashScope exposes two regional endpoints with separate accounts; an - # international key won't authenticate against the China endpoint and - # vice versa (fixes issue #758). - "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"), - "qwen-cn": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_CN_API_KEY"), - # Zhipu exposes the same GLM models under two brands with separate - # accounts: Z.AI (international, api.z.ai) and BigModel - # (open.bigmodel.cn, China). Keys aren't interchangeable across them. - "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"), - "glm-cn": ("https://open.bigmodel.cn/api/paas/v4/", "ZHIPU_CN_API_KEY"), - # MiniMax exposes two regional endpoints with separate keys; mainland - # Chinese users hit .com while global users hit .io. - "minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"), - "minimax-cn": ("https://api.minimaxi.com/v1", "MINIMAX_CN_API_KEY"), - "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), - "ollama": ("http://localhost:11434/v1", None), +# Provider base URLs. 
API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV +# (one canonical mapping consulted by both this client and the CLI's +# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep +# separate endpoints because international and China accounts cannot share +# credentials (#758). +_PROVIDER_BASE_URL = { + "xai": "https://api.x.ai/v1", + "deepseek": "https://api.deepseek.com", + "qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "glm": "https://api.z.ai/api/paas/v4/", + "glm-cn": "https://open.bigmodel.cn/api/paas/v4/", + "minimax": "https://api.minimax.io/v1", + "minimax-cn": "https://api.minimaxi.com/v1", + "openrouter": "https://openrouter.ai/api/v1", + "ollama": "http://localhost:11434/v1", } @@ -185,9 +182,9 @@ class OpenAIClient(BaseLLMClient): # Provider-specific base URL and auth. An explicit base_url on the # client (e.g. a corporate proxy) takes precedence over the # provider default so users can route through their own gateway. - if self.provider in _PROVIDER_CONFIG: - default_base, api_key_env = _PROVIDER_CONFIG[self.provider] - llm_kwargs["base_url"] = self.base_url or default_base + if self.provider in _PROVIDER_BASE_URL: + llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider] + api_key_env = get_api_key_env(self.provider) if api_key_env: api_key = os.environ.get(api_key_env) if api_key: From 879e2bb5da53b0a3f78de014bb165ba7de55e83b Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 06:25:22 +0000 Subject: [PATCH 44/44] refactor: align display label and docs with sentiment_analyst rename The agent ingests news, StockTwits, and Reddit, but CLI labels, the README description, and the legacy shim docstring still framed it as social-media-only. Updates all user-visible surfaces so the name and the implementation match. 
--- README.md | 2 +- cli/main.py | 12 ++++++------ cli/models.py | 2 ++ cli/utils.py | 2 +- .../agents/analysts/social_media_analyst.py | 10 +++++----- tradingagents/agents/utils/agent_states.py | 2 +- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index b4578a8c9..a74c4f471 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu ### Analyst Team - Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags. -- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood. +- Sentiment Analyst: Aggregates news headlines, StockTwits, and Reddit chatter into a single sentiment read to gauge short-term market mood. - News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions. - Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements. 
diff --git a/cli/main.py b/cli/main.py index 3af0cbbe2..043c2311f 100644 --- a/cli/main.py +++ b/cli/main.py @@ -49,7 +49,7 @@ class MessageBuffer: # Analyst name mapping ANALYST_MAPPING = { "market": "Market Analyst", - "social": "Social Analyst", + "social": "Sentiment Analyst", "news": "News Analyst", "fundamentals": "Fundamentals Analyst", } @@ -59,7 +59,7 @@ class MessageBuffer: # finalizing_agent: which agent must be "completed" for this report to count as done REPORT_SECTIONS = { "market_report": ("market", "Market Analyst"), - "sentiment_report": ("social", "Social Analyst"), + "sentiment_report": ("social", "Sentiment Analyst"), "news_report": ("news", "News Analyst"), "fundamentals_report": ("fundamentals", "Fundamentals Analyst"), "investment_plan": (None, "Research Manager"), @@ -280,7 +280,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non all_teams = { "Analyst Team": [ "Market Analyst", - "Social Analyst", + "Sentiment Analyst", "News Analyst", "Fundamentals Analyst", ], @@ -680,7 +680,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): if final_state.get("sentiment_report"): analysts_dir.mkdir(exist_ok=True) (analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8") - analyst_parts.append(("Social Analyst", final_state["sentiment_report"])) + analyst_parts.append(("Sentiment Analyst", final_state["sentiment_report"])) if final_state.get("news_report"): analysts_dir.mkdir(exist_ok=True) (analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8") @@ -765,7 +765,7 @@ def display_complete_report(final_state): if final_state.get("market_report"): analysts.append(("Market Analyst", final_state["market_report"])) if final_state.get("sentiment_report"): - analysts.append(("Social Analyst", final_state["sentiment_report"])) + analysts.append(("Sentiment Analyst", final_state["sentiment_report"])) if final_state.get("news_report"): 
analysts.append(("News Analyst", final_state["news_report"])) if final_state.get("fundamentals_report"): @@ -827,7 +827,7 @@ def update_research_team_status(status): ANALYST_ORDER = ["market", "social", "news", "fundamentals"] ANALYST_AGENT_NAMES = { "market": "Market Analyst", - "social": "Social Analyst", + "social": "Sentiment Analyst", "news": "News Analyst", "fundamentals": "Fundamentals Analyst", } diff --git a/cli/models.py b/cli/models.py index f68c3da1c..d1c5c24b1 100644 --- a/cli/models.py +++ b/cli/models.py @@ -5,6 +5,8 @@ from pydantic import BaseModel class AnalystType(str, Enum): MARKET = "market" + # Wire value stays "social" for saved-config and string-keyed-caller + # back-compat; the user-facing label is "Sentiment Analyst". SOCIAL = "social" NEWS = "news" FUNDAMENTALS = "fundamentals" diff --git a/cli/utils.py b/cli/utils.py index 5fd0b806c..15371400a 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -16,7 +16,7 @@ TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK" ANALYST_ORDER = [ ("Market Analyst", AnalystType.MARKET), - ("Social Media Analyst", AnalystType.SOCIAL), + ("Sentiment Analyst", AnalystType.SOCIAL), ("News Analyst", AnalystType.NEWS), ("Fundamentals Analyst", AnalystType.FUNDAMENTALS), ] diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 03cd7a44c..8c72df082 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,9 +1,9 @@ -"""Backwards-compatibility shim for the renamed social_media_analyst module. +"""Backwards-compatibility shim for the renamed module. -The social media analyst has been renamed to ``sentiment_analyst`` because its -only data tool is ``get_news`` (Yahoo Finance), not a social media feed. - -Import from ``tradingagents.agents.analysts.sentiment_analyst`` going forward. 
+The agent is now ``sentiment_analyst`` and aggregates Yahoo Finance news, +StockTwits cashtag streams, and Reddit posts into a single sentiment +report. Import from ``tradingagents.agents.analysts.sentiment_analyst`` +going forward; this module will be removed in a future release. See: https://github.com/TauricResearch/TradingAgents/issues/557 """ diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 6151a3863..d3a441a1a 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -51,7 +51,7 @@ class AgentState(MessagesState): # research step market_report: Annotated[str, "Report from the Market Analyst"] - sentiment_report: Annotated[str, "Report from the Social Media Analyst"] + sentiment_report: Annotated[str, "Report from the Sentiment Analyst"] news_report: Annotated[ str, "Report from the News Researcher of current world affairs" ]