Merge remote-tracking branch 'upstream/main' into analyst-phase1-observability

# Conflicts:
#	tradingagents/default_config.py
#	tradingagents/graph/setup.py
This commit is contained in:
CadeYu
2026-05-11 16:44:00 +08:00
71 changed files with 5225 additions and 760 deletions

View File

@@ -1,2 +1,38 @@
import os
os.environ.setdefault("PYTHONUTF8", "1")
import warnings
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
# (and every llm_clients consumer) sees the user's keys regardless of
# which entry point started the process. find_dotenv(usecwd=True) walks
# from the CWD, so the installed `tradingagents` console script picks up
# the project's .env instead of stepping up from site-packages.
# load_dotenv defaults to override=False, so it never clobbers values
# the caller has already exported.
try:
from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv(usecwd=True))
load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False)
except ImportError:
pass
# langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in
# its own __init__, which prepends default-action filters for its
# subclassed warning categories. To suppress a specific warning we must
# install our filter AFTER langchain-core has installed its own, so import
# it first. The package is a guaranteed transitive dep via langgraph.
try:
import langchain_core # noqa: F401
except ImportError:
pass
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
# explicit allowed_objects, which triggers a noisy pending-deprecation
# warning from langchain-core 1.3.3 on every interpreter start. The fix
# is already merged upstream (langchain-ai/langgraph#7743, 2026-05-08)
# and will arrive in the next langgraph-checkpoint release. Remove this
# block (and the langchain_core preload above) when we bump past it.
warnings.filterwarnings(
"ignore",
message=r"The default value of `allowed_objects`.*",
category=PendingDeprecationWarning,
)

View File

@@ -1,11 +1,13 @@
from .utils.agent_utils import create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .analysts.sentiment_analyst import (
create_sentiment_analyst,
create_social_media_analyst, # deprecated alias kept for back-compat
)
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
@@ -20,7 +22,6 @@ from .managers.portfolio_manager import create_portfolio_manager
from .trader.trader import create_trader
__all__ = [
"FinancialSituationMemory",
"AgentState",
"create_msg_delete",
"InvestDebateState",
@@ -35,6 +36,7 @@ __all__ = [
"create_aggressive_debator",
"create_portfolio_manager",
"create_conservative_debator",
"create_social_media_analyst",
"create_sentiment_analyst",
"create_social_media_analyst", # deprecated; will be removed in a future version
"create_trader",
]

View File

@@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_balance_sheet,

View File

@@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_indicators,

View File

@@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_global_news,

View File

@@ -0,0 +1,184 @@
"""Sentiment analyst — multi-source sentiment analysis for a target ticker.
Previously named ``social_media_analyst``. Renamed and redesigned because
the old version had a prompt that demanded social-media analysis but the
only tool available was Yahoo Finance news — which led LLMs to fabricate
Reddit/X/StockTwits content under prompt pressure (verified live).
The redesigned agent pre-fetches three complementary data sources before
the LLM is invoked and injects them into the prompt as structured blocks:
1. News headlines — Yahoo Finance (institutional framing)
2. StockTwits messages — retail-trader posts indexed by cashtag, with
user-labeled Bullish/Bearish sentiment tags
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
The agent does not use tool-calling; the data is in the prompt from
turn 0. The LLM produces the sentiment report in a single invocation.
See: https://github.com/TauricResearch/TradingAgents/issues/557
"""
from datetime import datetime, timedelta
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
get_news,
)
from tradingagents.dataflows.reddit import fetch_reddit_posts
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
def _seven_days_back(trade_date: str) -> str:
return (datetime.strptime(trade_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
def create_sentiment_analyst(llm):
"""Create a sentiment analyst node for the trading graph.
Pre-fetches news + StockTwits + Reddit data, injects them into the
prompt as structured blocks, and produces a sentiment report in a
single LLM call.
"""
def sentiment_analyst_node(state):
ticker = state["company_of_interest"]
end_date = state["trade_date"]
start_date = _seven_days_back(end_date)
instrument_context = build_instrument_context(ticker)
# Pre-fetch all three sources. Each fetcher degrades gracefully and
# returns a string (no exceptions surface from here), so the LLM
# always sees something — either real data or a clear placeholder.
news_block = get_news.func(ticker, start_date, end_date)
stocktwits_block = fetch_stocktwits_messages(ticker, limit=30)
reddit_block = fetch_reddit_posts(ticker)
system_message = _build_system_message(
ticker=ticker,
start_date=start_date,
end_date=end_date,
news_block=news_block,
stocktwits_block=stocktwits_block,
reddit_block=reddit_block,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
"\n{system_message}\n"
"For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(current_date=end_date)
prompt = prompt.partial(instrument_context=instrument_context)
# No bind_tools — the data is already in the prompt; a single LLM
# call produces the report directly.
chain = prompt | llm
result = chain.invoke(state["messages"])
return {
"messages": [result],
"sentiment_report": result.content,
}
return sentiment_analyst_node
def _build_system_message(
*,
ticker: str,
start_date: str,
end_date: str,
news_block: str,
stocktwits_block: str,
reddit_block: str,
) -> str:
"""Assemble the sentiment-analyst system message with structured data blocks."""
return f"""You are a financial market sentiment analyst. Your task is to produce a comprehensive sentiment report for {ticker} covering the period from {start_date} to {end_date}, drawing on three complementary data sources that have already been collected for you.
## Data sources (pre-fetched, in this prompt)
### News headlines — Yahoo Finance, past 7 days
Institutional framing. Fact-driven, slower-moving signal.
<start_of_news>
{news_block}
<end_of_news>
### StockTwits messages — retail-trader social platform indexed by cashtag
Fast-moving signal. Each message carries a user-labeled sentiment tag (Bullish / Bearish / no-label) plus the message body.
<start_of_stocktwits>
{stocktwits_block}
<end_of_stocktwits>
### Reddit posts — r/wallstreetbets, r/stocks, r/investing (past 7 days)
Community discussion. Engagement signal via upvote score and comment count. Subreddit character matters (r/wallstreetbets is often contrarian/exuberant; r/stocks more measured; r/investing longer-term).
<start_of_reddit>
{reddit_block}
<end_of_reddit>
## How to analyze this data (best practices)
1. **Read the StockTwits Bullish/Bearish ratio as a leading retail-sentiment signal.** A 70/30 bullish/bearish split is moderately bullish; ≥90/10 may indicate over-extension and contrarian risk; 50/50 is uncertainty. Sample size matters — base rates on the actual message count, not percentages alone.
2. **Look for cross-source divergences.** If news framing is bearish but StockTwits is overwhelmingly bullish, that mismatch is itself a signal — it can mean retail is leaning into a thesis the news flow hasn't caught up to (or vice versa, that retail is chasing while institutions are cautious).
3. **Weight Reddit posts by engagement.** A 400-upvote / 200-comment thread reflects community attention; a 3-upvote post is noise. Read the body excerpts for context — the title alone often misleads.
4. **Distinguish opinion from event.** A news headline ("Nvidia announces $500M Corning deal") is an event; a StockTwits post ("buying NVDA, this is going to moon") is opinion. Both are inputs but should be weighted differently in your conclusions.
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
## Output
Produce a sentiment report covering, in order:
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
3. **Divergences, alignments, and key narratives** across sources.
4. **Catalysts and risks** surfaced by the data.
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
{get_language_instruction()}"""
# ---------------------------------------------------------------------------
# Backwards-compatibility shim
# ---------------------------------------------------------------------------
def create_social_media_analyst(llm):
"""Deprecated alias for :func:`create_sentiment_analyst`.
Kept so existing code that imports ``create_social_media_analyst``
continues to work.
.. deprecated::
Import :func:`create_sentiment_analyst` directly instead.
"""
import warnings
warnings.warn(
"create_social_media_analyst is deprecated and will be removed in a "
"future version. Use create_sentiment_analyst instead.",
DeprecationWarning,
stacklevel=2,
)
return create_sentiment_analyst(llm)

View File

@@ -1,59 +1,23 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.dataflows.config import get_config
"""Backwards-compatibility shim for the renamed module.
The agent is now ``sentiment_analyst`` and aggregates Yahoo Finance news,
StockTwits cashtag streams, and Reddit posts into a single sentiment
report. Import from ``tradingagents.agents.analysts.sentiment_analyst``
going forward; this module will be removed in a future release.
def create_social_media_analyst(llm):
def social_media_analyst_node(state):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
See: https://github.com/TauricResearch/TradingAgents/issues/557
"""
tools = [
get_news,
]
import warnings as _warnings
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
)
from tradingagents.agents.analysts.sentiment_analyst import ( # noqa: F401
create_sentiment_analyst,
create_social_media_analyst,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"sentiment_report": report,
}
return social_media_analyst_node
_warnings.warn(
"tradingagents.agents.analysts.social_media_analyst is deprecated. "
"Import from tradingagents.agents.analysts.sentiment_analyst instead.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -1,25 +1,43 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
"""Portfolio Manager: synthesises the risk-analyst debate into the final decision.
Uses LangChain's ``with_structured_output`` so the LLM produces a typed
``PortfolioDecision`` directly, in a single call. The result is rendered
back to markdown for storage in ``final_trade_decision`` so memory log,
CLI display, and saved reports continue to consume the same shape they do
today. When a provider does not expose structured output, the agent falls
back gracefully to free-text generation.
"""
from __future__ import annotations
from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
)
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
)
def create_portfolio_manager(llm, memory):
def create_portfolio_manager(llm):
structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager")
def portfolio_manager_node(state) -> dict:
instrument_context = build_instrument_context(state["company_of_interest"])
history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"]
market_research_report = state["market_report"]
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"]
research_plan = state["investment_plan"]
trader_plan = state["trader_investment_plan"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
past_context = state.get("past_context", "")
lessons_line = (
f"- Lessons from prior decisions and outcomes:\n{past_context}\n"
if past_context
else ""
)
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
@@ -35,16 +53,9 @@ def create_portfolio_manager(llm, memory):
- **Sell**: Exit position or avoid entry
**Context:**
- Trader's proposed plan: **{trader_plan}**
- Lessons from past decisions: **{past_memory_str}**
**Required Output Structure:**
1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell.
2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon.
3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections.
---
- Research Manager's investment plan: **{research_plan}**
- Trader's transaction proposal: **{trader_plan}**
{lessons_line}
**Risk Analysts Debate History:**
{history}
@@ -52,10 +63,16 @@ def create_portfolio_manager(llm, memory):
Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
response = llm.invoke(prompt)
final_trade_decision = invoke_structured_or_freetext(
structured_llm,
llm,
prompt,
render_pm_decision,
"Portfolio Manager",
)
new_risk_debate_state = {
"judge_decision": response.content,
"judge_decision": final_trade_decision,
"history": risk_debate_state["history"],
"aggressive_history": risk_debate_state["aggressive_history"],
"conservative_history": risk_debate_state["conservative_history"],
@@ -69,7 +86,7 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{
return {
"risk_debate_state": new_risk_debate_state,
"final_trade_decision": response.content,
"final_trade_decision": final_trade_decision,
}
return portfolio_manager_node

View File

@@ -1,60 +1,67 @@
import time
import json
"""Research Manager: turns the bull/bear debate into a structured investment plan for the trader."""
from tradingagents.agents.utils.agent_utils import build_instrument_context
from __future__ import annotations
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
)
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
)
def create_research_manager(llm, memory):
def create_research_manager(llm):
structured_llm = bind_structured(llm, ResearchPlan, "Research Manager")
def research_manager_node(state) -> dict:
instrument_context = build_instrument_context(state["company_of_interest"])
history = state["investment_debate_state"].get("history", "")
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
investment_debate_state = state["investment_debate_state"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
prompt = f"""As the Research Manager and debate facilitator, your role is to critically evaluate this round of debate and deliver a clear, actionable investment plan for the trader.
{instrument_context}
Here is the debate:
Debate History:
{history}"""
response = llm.invoke(prompt)
---
**Rating Scale** (use exactly one):
- **Buy**: Strong conviction in the bull thesis; recommend taking or growing the position
- **Overweight**: Constructive view; recommend gradually increasing exposure
- **Hold**: Balanced view; recommend maintaining the current position
- **Underweight**: Cautious view; recommend trimming exposure
- **Sell**: Strong conviction in the bear thesis; recommend exiting or avoiding the position
Commit to a clear stance whenever the debate's strongest arguments warrant one; reserve Hold for situations where the evidence on both sides is genuinely balanced.
---
**Debate History:**
{history}""" + get_language_instruction()
investment_plan = invoke_structured_or_freetext(
structured_llm,
llm,
prompt,
render_research_plan,
"Research Manager",
)
new_investment_debate_state = {
"judge_decision": response.content,
"judge_decision": investment_plan,
"history": investment_debate_state.get("history", ""),
"bear_history": investment_debate_state.get("bear_history", ""),
"bull_history": investment_debate_state.get("bull_history", ""),
"current_response": response.content,
"current_response": investment_plan,
"count": investment_debate_state["count"],
}
return {
"investment_debate_state": new_investment_debate_state,
"investment_plan": response.content,
"investment_plan": investment_plan,
}
return research_manager_node

View File

@@ -1,9 +1,7 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_bear_researcher(llm, memory):
def create_bear_researcher(llm):
def bear_node(state) -> dict:
investment_debate_state = state["investment_debate_state"]
history = investment_debate_state.get("history", "")
@@ -15,13 +13,6 @@ def create_bear_researcher(llm, memory):
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
@@ -40,9 +31,8 @@ Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
"""
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.
""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,9 +1,7 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_bull_researcher(llm, memory):
def create_bull_researcher(llm):
def bull_node(state) -> dict:
investment_debate_state = state["investment_debate_state"]
history = investment_debate_state.get("history", "")
@@ -15,13 +13,6 @@ def create_bull_researcher(llm, memory):
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
@@ -38,9 +29,8 @@ Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
"""
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,5 +1,4 @@
import time
import json
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_aggressive_debator(llm):
@@ -30,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,6 +1,4 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_conservative_debator(llm):
@@ -31,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,5 +1,4 @@
import time
import json
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_neutral_debator(llm):
@@ -30,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -0,0 +1,228 @@
"""Pydantic schemas used by agents that produce structured output.
The framework's primary artifact is still prose: each agent's natural-language
reasoning is what users read in the saved markdown reports and what the
downstream agents read as context. Structured output is layered onto the
three decision-making agents (Research Manager, Trader, Portfolio Manager)
so that:
- Their outputs follow consistent section headers across runs and providers
- Each provider's native structured-output mode is used (json_schema for
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic)
- Schema field descriptions become the model's output instructions, freeing
the prompt body to focus on context and the rating-scale guidance
- A render helper turns the parsed Pydantic instance back into the same
markdown shape the rest of the system already consumes, so display,
memory log, and saved reports keep working unchanged
"""
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Shared rating types
# ---------------------------------------------------------------------------
class PortfolioRating(str, Enum):
"""5-tier rating used by the Research Manager and Portfolio Manager."""
BUY = "Buy"
OVERWEIGHT = "Overweight"
HOLD = "Hold"
UNDERWEIGHT = "Underweight"
SELL = "Sell"
class TraderAction(str, Enum):
"""3-tier transaction direction used by the Trader.
The Trader's job is to translate the Research Manager's investment plan
into a concrete transaction proposal: should the desk execute a Buy, a
Sell, or sit on Hold this round. Position sizing and the nuanced
Overweight / Underweight calls happen later at the Portfolio Manager.
"""
BUY = "Buy"
HOLD = "Hold"
SELL = "Sell"
# ---------------------------------------------------------------------------
# Research Manager
# ---------------------------------------------------------------------------
class ResearchPlan(BaseModel):
"""Structured investment plan produced by the Research Manager.
Hand-off to the Trader: the recommendation pins the directional view,
the rationale captures which side of the bull/bear debate carried the
argument, and the strategic actions translate that into concrete
instructions the trader can execute against.
"""
recommendation: PortfolioRating = Field(
description=(
"The investment recommendation. Exactly one of Buy / Overweight / "
"Hold / Underweight / Sell. Reserve Hold for situations where the "
"evidence on both sides is genuinely balanced; otherwise commit to "
"the side with the stronger arguments."
),
)
rationale: str = Field(
description=(
"Conversational summary of the key points from both sides of the "
"debate, ending with which arguments led to the recommendation. "
"Speak naturally, as if to a teammate."
),
)
strategic_actions: str = Field(
description=(
"Concrete steps for the trader to implement the recommendation, "
"including position sizing guidance consistent with the rating."
),
)
def render_research_plan(plan: ResearchPlan) -> str:
"""Render a ResearchPlan to markdown for storage and the trader's prompt context."""
return "\n".join([
f"**Recommendation**: {plan.recommendation.value}",
"",
f"**Rationale**: {plan.rationale}",
"",
f"**Strategic Actions**: {plan.strategic_actions}",
])
# ---------------------------------------------------------------------------
# Trader
# ---------------------------------------------------------------------------
class TraderProposal(BaseModel):
"""Structured transaction proposal produced by the Trader.
The trader reads the Research Manager's investment plan and the analyst
reports, then turns them into a concrete transaction: what action to
take, the reasoning that justifies it, and the practical levels for
entry, stop-loss, and sizing.
"""
action: TraderAction = Field(
description="The transaction direction. Exactly one of Buy / Hold / Sell.",
)
reasoning: str = Field(
description=(
"The case for this action, anchored in the analysts' reports and "
"the research plan. Two to four sentences."
),
)
entry_price: Optional[float] = Field(
default=None,
description="Optional entry price target in the instrument's quote currency.",
)
stop_loss: Optional[float] = Field(
default=None,
description="Optional stop-loss price in the instrument's quote currency.",
)
position_sizing: Optional[str] = Field(
default=None,
description="Optional sizing guidance, e.g. '5% of portfolio'.",
)
def render_trader_proposal(proposal: TraderProposal) -> str:
"""Render a TraderProposal to markdown.
The trailing ``FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`` line is
preserved for backward compatibility with the analyst stop-signal text
and any external code that greps for it.
"""
parts = [
f"**Action**: {proposal.action.value}",
"",
f"**Reasoning**: {proposal.reasoning}",
]
if proposal.entry_price is not None:
parts.extend(["", f"**Entry Price**: {proposal.entry_price}"])
if proposal.stop_loss is not None:
parts.extend(["", f"**Stop Loss**: {proposal.stop_loss}"])
if proposal.position_sizing:
parts.extend(["", f"**Position Sizing**: {proposal.position_sizing}"])
parts.extend([
"",
f"FINAL TRANSACTION PROPOSAL: **{proposal.action.value.upper()}**",
])
return "\n".join(parts)
# ---------------------------------------------------------------------------
# Portfolio Manager
# ---------------------------------------------------------------------------
class PortfolioDecision(BaseModel):
"""Structured output produced by the Portfolio Manager.
The model fills every field as part of its primary LLM call; no separate
extraction pass is required. Field descriptions double as the model's
output instructions, so the prompt body only needs to convey context and
the rating-scale guidance.
"""
rating: PortfolioRating = Field(
description=(
"The final position rating. Exactly one of Buy / Overweight / Hold / "
"Underweight / Sell, picked based on the analysts' debate."
),
)
executive_summary: str = Field(
description=(
"A concise action plan covering entry strategy, position sizing, "
"key risk levels, and time horizon. Two to four sentences."
),
)
investment_thesis: str = Field(
description=(
"Detailed reasoning anchored in specific evidence from the analysts' "
"debate. If prior lessons are referenced in the prompt context, "
"incorporate them; otherwise rely solely on the current analysis."
),
)
price_target: Optional[float] = Field(
default=None,
description="Optional target price in the instrument's quote currency.",
)
time_horizon: Optional[str] = Field(
default=None,
description="Optional recommended holding period, e.g. '3-6 months'.",
)
def render_pm_decision(decision: PortfolioDecision) -> str:
"""Render a PortfolioDecision back to the markdown shape the rest of the system expects.
Memory log, CLI display, and saved report files all read this markdown,
so the rendered output preserves the exact section headers (``**Rating**``,
``**Executive Summary**``, ``**Investment Thesis**``) that downstream
parsers and the report writers already handle.
"""
parts = [
f"**Rating**: {decision.rating.value}",
"",
f"**Executive Summary**: {decision.executive_summary}",
"",
f"**Investment Thesis**: {decision.investment_thesis}",
]
if decision.price_target is not None:
parts.extend(["", f"**Price Target**: {decision.price_target}"])
if decision.time_horizon:
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
return "\n".join(parts)

View File

@@ -1,48 +1,64 @@
"""Trader: turns the Research Manager's investment plan into a concrete transaction proposal."""
from __future__ import annotations
import functools
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context
from langchain_core.messages import AIMessage
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
)
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
)
def create_trader(llm, memory):
def create_trader(llm):
structured_llm = bind_structured(llm, TraderProposal, "Trader")
def trader_node(state, name):
company_name = state["company_of_interest"]
instrument_context = build_instrument_context(company_name)
investment_plan = state["investment_plan"]
market_research_report = state["market_report"]
sentiment_report = state["sentiment_report"]
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
if past_memories:
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
else:
past_memory_str = "No past memories found."
context = {
"role": "user",
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
}
messages = [
{
"role": "system",
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
"content": (
"You are a trading agent analyzing market data to make investment decisions. "
"Based on your analysis, provide a specific recommendation to buy, sell, or hold. "
"Anchor your reasoning in the analysts' reports and the research plan."
+ get_language_instruction()
),
},
{
"role": "user",
"content": (
f"Based on a comprehensive analysis by a team of analysts, here is an investment "
f"plan tailored for {company_name}. {instrument_context} This plan incorporates "
f"insights from current technical market trends, macroeconomic indicators, and "
f"social media sentiment. Use this plan as a foundation for evaluating your next "
f"trading decision.\n\nProposed Investment Plan: {investment_plan}\n\n"
f"Leverage these insights to make an informed and strategic decision."
),
},
context,
]
result = llm.invoke(messages)
trader_plan = invoke_structured_or_freetext(
structured_llm,
llm,
messages,
render_trader_proposal,
"Trader",
)
return {
"messages": [result],
"trader_investment_plan": result.content,
"messages": [AIMessage(content=trader_plan)],
"trader_investment_plan": trader_plan,
"sender": name,
}

View File

@@ -1,10 +1,6 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import MessagesState
# Researcher team state
@@ -55,7 +51,7 @@ class AgentState(MessagesState):
# research step
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
sentiment_report: Annotated[str, "Report from the Sentiment Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
]
@@ -74,3 +70,4 @@ class AgentState(MessagesState):
RiskDebateState, "Current state of the debate on evaluating risk"
]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
past_context: Annotated[str, "Memory log context injected at run start (same-ticker decisions + cross-ticker lessons)"]

View File

@@ -24,8 +24,10 @@ def get_language_instruction() -> str:
"""Return a prompt instruction for the configured output language.
Returns empty string when English (default), so no extra tokens are used.
Only applied to user-facing agents (analysts, portfolio manager).
Internal debate agents stay in English for reasoning quality.
Applied to every agent whose output reaches the saved report —
analysts, researchers, debaters, research manager, trader, and
portfolio manager — so a non-English run produces a fully localized
report rather than a mix of languages.
"""
from tradingagents.dataflows.config import get_config
lang = get_config().get("output_language", "English")

View File

@@ -1,144 +1,300 @@
"""Financial situation memory using BM25 for lexical similarity matching.
"""Append-only markdown decision log for TradingAgents."""
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
"""
from rank_bm25 import BM25Okapi
from typing import List, Tuple
from typing import List, Optional
from pathlib import Path
import re
from tradingagents.agents.utils.rating import parse_rating
class FinancialSituationMemory:
"""Memory system for storing and retrieving financial situations using BM25."""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
class TradingMemoryLog:
"""Append-only markdown log of trading decisions and reflections."""
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
# HTML comment: cannot appear in LLM prose output, safe as a hard delimiter
_SEPARATOR = "\n\n<!-- ENTRY_END -->\n\n"
# Precompiled patterns — avoids re-compilation on every load_entries() call
_DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL)
_REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL)
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
def __init__(self, config: dict = None):
cfg = config or {}
self._log_path = None
path = cfg.get("memory_log_path")
if path:
self._log_path = Path(path).expanduser()
self._log_path.parent.mkdir(parents=True, exist_ok=True)
# Optional cap on resolved entries. None disables rotation.
self._max_entries = cfg.get("memory_log_max_entries")
Simple whitespace + punctuation tokenization with lowercasing.
"""
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r'\b\w+\b', text.lower())
return tokens
# --- Write path (Phase A) ---
def _rebuild_index(self):
"""Rebuild the BM25 index after adding documents."""
if self.documents:
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
self.bm25 = BM25Okapi(tokenized_docs)
else:
self.bm25 = None
def store_decision(
self,
ticker: str,
trade_date: str,
final_trade_decision: str,
) -> None:
"""Append pending entry at end of propagate(). No LLM call."""
if not self._log_path:
return
# Idempotency guard: fast raw-text scan instead of full parse
if self._log_path.exists():
raw = self._log_path.read_text(encoding="utf-8")
for line in raw.splitlines():
if line.startswith(f"[{trade_date} | {ticker} |") and line.endswith("| pending]"):
return
rating = parse_rating(final_trade_decision)
tag = f"[{trade_date} | {ticker} | {rating} | pending]"
entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}"
with open(self._log_path, "a", encoding="utf-8") as f:
f.write(entry)
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
"""Add financial situations and their corresponding advice.
# --- Read path (Phase A) ---
Args:
situations_and_advice: List of tuples (situation, recommendation)
"""
for situation, recommendation in situations_and_advice:
self.documents.append(situation)
self.recommendations.append(recommendation)
# Rebuild BM25 index with new documents
self._rebuild_index()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
Args:
current_situation: The current financial situation to match against
n_matches: Number of top matches to return
Returns:
List of dicts with matched_situation, recommendation, and similarity_score
"""
if not self.documents or self.bm25 is None:
def load_entries(self) -> List[dict]:
"""Parse all entries from log. Returns list of dicts."""
if not self._log_path or not self._log_path.exists():
return []
text = self._log_path.read_text(encoding="utf-8")
raw_entries = [e.strip() for e in text.split(self._SEPARATOR) if e.strip()]
entries = []
for raw in raw_entries:
parsed = self._parse_entry(raw)
if parsed:
entries.append(parsed)
return entries
# Tokenize query
query_tokens = self._tokenize(current_situation)
def get_pending_entries(self) -> List[dict]:
"""Return entries with outcome:pending (for Phase B)."""
return [e for e in self.load_entries() if e.get("pending")]
# Get BM25 scores for all documents
scores = self.bm25.get_scores(query_tokens)
def get_past_context(self, ticker: str, n_same: int = 5, n_cross: int = 3) -> str:
"""Return formatted past context string for agent prompt injection."""
entries = [e for e in self.load_entries() if not e.get("pending")]
if not entries:
return ""
# Get top-n indices sorted by score (descending)
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
same, cross = [], []
for e in reversed(entries):
if len(same) >= n_same and len(cross) >= n_cross:
break
if e["ticker"] == ticker and len(same) < n_same:
same.append(e)
elif e["ticker"] != ticker and len(cross) < n_cross:
cross.append(e)
# Build results
results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
if not same and not cross:
return ""
for idx in top_indices:
# Normalize score to 0-1 range for consistency
normalized_score = scores[idx] / max_score if max_score > 0 else 0
results.append({
"matched_situation": self.documents[idx],
"recommendation": self.recommendations[idx],
"similarity_score": normalized_score,
})
parts = []
if same:
parts.append(f"Past analyses of {ticker} (most recent first):")
parts.extend(self._format_full(e) for e in same)
if cross:
parts.append("Recent cross-ticker lessons:")
parts.extend(self._format_reflection_only(e) for e in cross)
return "\n\n".join(parts)
return results
# --- Update path (Phase B) ---
def clear(self):
"""Clear all stored memories."""
self.documents = []
self.recommendations = []
self.bm25 = None
def update_with_outcome(
self,
ticker: str,
trade_date: str,
raw_return: float,
alpha_return: float,
holding_days: int,
reflection: str,
) -> None:
"""Replace pending tag and append REFLECTION section using atomic write.
Finds the first pending entry matching (trade_date, ticker), updates
its tag with return figures, and appends a REFLECTION section. Uses
a temp-file + os.replace() so a crash mid-write never corrupts the log.
"""
if not self._log_path or not self._log_path.exists():
return
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory("test_memory")
text = self._log_path.read_text(encoding="utf-8")
blocks = text.split(self._SEPARATOR)
# Example data
example_data = [
(
"High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
),
(
"Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
),
(
"Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
),
(
"Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
),
]
pending_prefix = f"[{trade_date} | {ticker} |"
raw_pct = f"{raw_return:+.1%}"
alpha_pct = f"{alpha_return:+.1%}"
# Add the example situations and recommendations
matcher.add_situations(example_data)
updated = False
new_blocks = []
for block in blocks:
stripped = block.strip()
if not stripped:
new_blocks.append(block)
continue
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
lines = stripped.splitlines()
tag_line = lines[0].strip()
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
if (
not updated
and tag_line.startswith(pending_prefix)
and tag_line.endswith("| pending]")
):
# Parse rating from the existing pending tag
fields = [f.strip() for f in tag_line[1:-1].split("|")]
rating = fields[2]
new_tag = (
f"[{trade_date} | {ticker} | {rating}"
f" | {raw_pct} | {alpha_pct} | {holding_days}d]"
)
rest = "\n".join(lines[1:])
new_blocks.append(
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{reflection}"
)
updated = True
else:
new_blocks.append(block)
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
if not updated:
return
except Exception as e:
print(f"Error during recommendation: {str(e)}")
new_blocks = self._apply_rotation(new_blocks)
new_text = self._SEPARATOR.join(new_blocks)
tmp_path = self._log_path.with_suffix(".tmp")
tmp_path.write_text(new_text, encoding="utf-8")
tmp_path.replace(self._log_path)
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
"""Apply multiple outcome updates in a single read + atomic write.
Each element of updates must have keys: ticker, trade_date,
raw_return, alpha_return, holding_days, reflection.
"""
if not self._log_path or not self._log_path.exists() or not updates:
return
text = self._log_path.read_text(encoding="utf-8")
blocks = text.split(self._SEPARATOR)
# Build lookup keyed by (trade_date, ticker) for O(1) dispatch
update_map = {(u["trade_date"], u["ticker"]): u for u in updates}
new_blocks = []
for block in blocks:
stripped = block.strip()
if not stripped:
new_blocks.append(block)
continue
lines = stripped.splitlines()
tag_line = lines[0].strip()
matched = False
for (trade_date, ticker), upd in list(update_map.items()):
pending_prefix = f"[{trade_date} | {ticker} |"
if tag_line.startswith(pending_prefix) and tag_line.endswith("| pending]"):
fields = [f.strip() for f in tag_line[1:-1].split("|")]
rating = fields[2]
raw_pct = f"{upd['raw_return']:+.1%}"
alpha_pct = f"{upd['alpha_return']:+.1%}"
new_tag = (
f"[{trade_date} | {ticker} | {rating}"
f" | {raw_pct} | {alpha_pct} | {upd['holding_days']}d]"
)
rest = "\n".join(lines[1:])
new_blocks.append(
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{upd['reflection']}"
)
del update_map[(trade_date, ticker)]
matched = True
break
if not matched:
new_blocks.append(block)
new_blocks = self._apply_rotation(new_blocks)
new_text = self._SEPARATOR.join(new_blocks)
tmp_path = self._log_path.with_suffix(".tmp")
tmp_path.write_text(new_text, encoding="utf-8")
tmp_path.replace(self._log_path)
# --- Helpers ---
def _apply_rotation(self, blocks: List[str]) -> List[str]:
"""Drop oldest resolved blocks when their count exceeds max_entries.
Pending blocks are always kept (they represent unprocessed work).
Returns ``blocks`` unchanged when rotation is disabled or under cap.
"""
if not self._max_entries or self._max_entries <= 0:
return blocks
# Tag each block with (kept, is_resolved) by parsing tag-line markers.
decisions = []
for block in blocks:
stripped = block.strip()
if not stripped:
decisions.append((block, False))
continue
tag_line = stripped.splitlines()[0].strip()
is_resolved = (
tag_line.startswith("[")
and tag_line.endswith("]")
and not tag_line.endswith("| pending]")
)
decisions.append((block, is_resolved))
resolved_count = sum(1 for _, r in decisions if r)
if resolved_count <= self._max_entries:
return blocks
to_drop = resolved_count - self._max_entries
kept: List[str] = []
for block, is_resolved in decisions:
if is_resolved and to_drop > 0:
to_drop -= 1
continue
kept.append(block)
return kept
def _parse_entry(self, raw: str) -> Optional[dict]:
lines = raw.strip().splitlines()
if not lines:
return None
tag_line = lines[0].strip()
if not (tag_line.startswith("[") and tag_line.endswith("]")):
return None
fields = [f.strip() for f in tag_line[1:-1].split("|")]
if len(fields) < 4:
return None
entry = {
"date": fields[0],
"ticker": fields[1],
"rating": fields[2],
"pending": fields[3] == "pending",
"raw": fields[3] if fields[3] != "pending" else None,
"alpha": fields[4] if len(fields) > 4 else None,
"holding": fields[5] if len(fields) > 5 else None,
}
body = "\n".join(lines[1:]).strip()
decision_match = self._DECISION_RE.search(body)
reflection_match = self._REFLECTION_RE.search(body)
entry["decision"] = decision_match.group(1).strip() if decision_match else ""
entry["reflection"] = reflection_match.group(1).strip() if reflection_match else ""
return entry
def _format_full(self, e: dict) -> str:
raw = e["raw"] or "n/a"
alpha = e["alpha"] or "n/a"
holding = e["holding"] or "n/a"
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {raw} | {alpha} | {holding}]"
parts = [tag, f"DECISION:\n{e['decision']}"]
if e["reflection"]:
parts.append(f"REFLECTION:\n{e['reflection']}")
return "\n\n".join(parts)
def _format_reflection_only(self, e: dict) -> str:
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {e['raw'] or 'n/a'}]"
if e["reflection"]:
return f"{tag}\n{e['reflection']}"
text = e["decision"][:300]
suffix = "..." if len(e["decision"]) > 300 else ""
return f"{tag}\n{text}{suffix}"

View File

@@ -1,5 +1,5 @@
from langchain_core.tools import tool
from typing import Annotated
from typing import Annotated, Optional
from tradingagents.dataflows.interface import route_to_vendor
@tool
@@ -23,16 +23,20 @@ def get_news(
@tool
def get_global_news(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "Number of days to look back"] = 7,
limit: Annotated[int, "Maximum number of articles to return"] = 5,
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
) -> str:
"""
Retrieve global news data.
Uses the configured news_data vendor.
Uses the configured news_data vendor. Defaults for look_back_days and
limit come from DEFAULT_CONFIG (global_news_lookback_days,
global_news_article_limit); pass explicit values to override.
Args:
curr_date (str): Current date in yyyy-mm-dd format
look_back_days (int): Number of days to look back (default 7)
limit (int): Maximum number of articles to return (default 5)
look_back_days (int): Number of days to look back; omit to inherit config
limit (int): Maximum number of articles to return; omit to inherit config
Returns:
str: A formatted string containing global news data
"""

View File

@@ -0,0 +1,50 @@
"""Shared 5-tier rating vocabulary and a deterministic heuristic parser.
The same five-tier scale (Buy, Overweight, Hold, Underweight, Sell) is used by:
- The Research Manager (investment plan recommendation)
- The Portfolio Manager (final position decision)
- The signal processor (rating extracted for downstream consumers)
- The memory log (rating tag stored alongside each decision entry)
Centralising it here avoids drift between those call sites.
"""
from __future__ import annotations
import re
from typing import Tuple
# Canonical, ordered 5-tier scale (most bullish to most bearish).
RATINGS_5_TIER: Tuple[str, ...] = (
"Buy", "Overweight", "Hold", "Underweight", "Sell",
)
_RATING_SET = {r.lower() for r in RATINGS_5_TIER}
# Matches "Rating: X" / "rating - X" / "Rating: **X**" — tolerates markdown
# bold wrappers and either a colon or hyphen separator.
_RATING_LABEL_RE = re.compile(r"rating.*?[:\-][\s*]*(\w+)", re.IGNORECASE)
def parse_rating(text: str, default: str = "Hold") -> str:
"""Heuristically extract a 5-tier rating from prose text.
Two-pass strategy:
1. Look for an explicit "Rating: X" label (tolerant of markdown bold).
2. Fall back to the first 5-tier rating word found anywhere in the text.
Returns a Title-cased rating string, or ``default`` if no rating word appears.
"""
for line in text.splitlines():
m = _RATING_LABEL_RE.search(line)
if m and m.group(1).lower() in _RATING_SET:
return m.group(1).capitalize()
for line in text.splitlines():
for word in line.lower().split():
clean = word.strip("*:.,")
if clean in _RATING_SET:
return clean.capitalize()
return default

View File

@@ -0,0 +1,73 @@
"""Shared helpers for invoking an agent with structured output and a graceful fallback.
The Portfolio Manager, Trader, and Research Manager all follow the same
canonical pattern:
1. At agent creation, wrap the LLM with ``with_structured_output(Schema)``
so the model returns a typed Pydantic instance. If the provider does
not support structured output (rare; mostly older Ollama models), the
wrap is skipped and the agent uses free-text generation instead.
2. At invocation, run the structured call and render the result back to
markdown. If the structured call itself fails for any reason
(malformed JSON from a weak model, transient provider issue), fall
back to a plain ``llm.invoke`` so the pipeline never blocks.
Centralising the pattern here keeps the agent factories small and ensures
all three agents log the same warnings when fallback fires.
"""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional, TypeVar
from pydantic import BaseModel
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
Logs a warning when the binding fails so the user understands the agent
will use free-text generation for every call instead of one-shot fallback.
"""
try:
return llm.with_structured_output(schema)
except (NotImplementedError, AttributeError) as exc:
logger.warning(
"%s: provider does not support with_structured_output (%s); "
"falling back to free-text generation",
agent_name, exc,
)
return None
def invoke_structured_or_freetext(
structured_llm: Optional[Any],
plain_llm: Any,
prompt: Any,
render: Callable[[T], str],
agent_name: str,
) -> str:
"""Run the structured call and render to markdown; fall back to free-text on any failure.
``prompt`` is whatever the underlying LLM accepts (a string for chat
invocations, a list of message dicts for chat models that take that
shape). The same value is forwarded to the free-text path so the
fallback sees the same input the structured call did.
"""
if structured_llm is not None:
try:
result = structured_llm.invoke(prompt)
return render(result)
except Exception as exc:
logger.warning(
"%s: structured-output invocation failed (%s); retrying once as free text",
agent_name, exc,
)
response = plain_llm.invoke(prompt)
return response.content

View File

@@ -22,7 +22,7 @@ def get_indicators(
"""
# LLMs sometimes pass multiple indicators as a comma-separated string;
# split and process each individually.
indicators = [i.strip() for i in indicator.split(",") if i.strip()]
indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()]
results = []
for ind in indicators:
try:

View File

@@ -1,6 +1,8 @@
import tradingagents.default_config as default_config
from copy import deepcopy
from typing import Dict, Optional
import tradingagents.default_config as default_config
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
@@ -9,22 +11,31 @@ def initialize_config():
"""Initialize the configuration with default values."""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config = deepcopy(default_config.DEFAULT_CONFIG)
def set_config(config: Dict):
"""Update the configuration with custom values."""
"""Update the configuration with custom values.
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}``
keeps the other nested keys from the default; scalar keys are replaced.
"""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
initialize_config()
incoming = deepcopy(config)
for key, value in incoming.items():
if isinstance(value, dict) and isinstance(_config.get(key), dict):
_config[key].update(value)
else:
_config[key] = value
def get_config() -> Dict:
"""Get the current configuration."""
if _config is None:
initialize_config()
return _config.copy()
return deepcopy(_config)
# Initialize with default config

View File

@@ -0,0 +1,106 @@
"""Reddit search fetcher for ticker-specific discussion posts.
Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``)
which do not require an API key. Public throughput is ~10 requests per
minute per IP, well within budget for a single agent run that queries
a handful of finance subreddits per ticker.
Returns formatted plaintext blocks ready for prompt injection. Degrades
gracefully — returns a placeholder string rather than raising, so callers
never have to special-case missing data.
"""
from __future__ import annotations
import json
import logging
import time
from typing import Iterable
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode
from urllib.request import Request, urlopen
logger = logging.getLogger(__name__)
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
# Default subreddits ordered roughly by signal density for ticker-specific
# discussion. wallstreetbets has the most volume but most noise; stocks /
# investing trend more measured. Caller can override.
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
def _fetch_subreddit(
ticker: str,
sub: str,
limit: int,
timeout: float,
) -> list[dict]:
qs = urlencode({
"q": ticker,
"restrict_sr": "on",
"sort": "new",
"t": "week", # last 7 days
"limit": limit,
})
url = _API.format(sub=sub, qs=qs)
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
try:
with urlopen(req, timeout=timeout) as resp:
payload = json.loads(resp.read())
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc)
return []
children = (payload.get("data") or {}).get("children") or []
return [c.get("data", {}) for c in children if isinstance(c, dict)]
def fetch_reddit_posts(
ticker: str,
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
limit_per_sub: int = 5,
timeout: float = 10.0,
inter_request_delay: float = 0.4,
) -> str:
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
subreddits and return them as a formatted plaintext block.
``inter_request_delay`` keeps us under Reddit's public rate limit
(~10 req/min per IP) even if the caller queries many subreddits.
"""
blocks = []
total_posts = 0
for i, sub in enumerate(subreddits):
if i > 0:
time.sleep(inter_request_delay)
posts = _fetch_subreddit(ticker, sub, limit_per_sub, timeout)
total_posts += len(posts)
if not posts:
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
continue
lines = [f"r/{sub}{len(posts)} recent posts mentioning {ticker.upper()}:"]
for p in posts:
title = (p.get("title") or "").replace("\n", " ").strip()
score = p.get("score", 0)
comments = p.get("num_comments", 0)
created = p.get("created_utc")
created_str = (
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
)
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
if len(selftext) > 240:
selftext = selftext[:240] + ""
lines.append(
f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}"
+ (f"\n body excerpt: {selftext}" if selftext else "")
)
blocks.append("\n".join(lines))
if total_posts == 0:
return (
f"<no Reddit posts found mentioning {ticker.upper()} across "
f"{', '.join(f'r/{s}' for s in subreddits)} in the past 7 days>"
)
return "\n\n".join(blocks)

View File

@@ -8,6 +8,7 @@ from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
from .utils import safe_ticker_component
logger = logging.getLogger(__name__)
@@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
subsequent calls the cache is reused. Rows after curr_date are
filtered out so backtests never see future prices.
"""
# Reject ticker values that would escape the cache directory when
# interpolated into the cache filename (e.g. ``../../tmp/x``).
safe_symbol = safe_ticker_component(symbol)
config = get_config()
curr_date_dt = pd.to_datetime(curr_date)
@@ -63,11 +68,11 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
data = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8")
else:
data = yf_retry(lambda: yf.download(
symbol,
@@ -78,7 +83,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
auto_adjust=True,
))
data = data.reset_index()
data.to_csv(data_file, index=False)
data.to_csv(data_file, index=False, encoding="utf-8")
data = _clean_dataframe(data)

View File

@@ -0,0 +1,83 @@
"""StockTwits public symbol-stream fetcher.
StockTwits exposes a per-symbol message stream at
``api.stocktwits.com/api/2/streams/symbol/{ticker}.json`` that requires no
API key, no OAuth, and no registration. Each message includes a
user-labeled sentiment field (``Bullish``/``Bearish``/null), the message
body, timestamp, and posting user.
The function is deliberately self-contained: short timeout, graceful
degradation on any HTTP or parse failure, and a string return type so
the calling agent gets a uniform interface regardless of whether the
network call succeeded.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
logger = logging.getLogger(__name__)
_API = "https://api.stocktwits.com/api/2/streams/symbol/{ticker}.json"
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.0) -> str:
"""Fetch recent StockTwits messages for ``ticker`` and return them as a
formatted plaintext block ready for prompt injection.
Returns a placeholder string when the endpoint is unreachable, the
symbol has no messages, or the response shape is unexpected — the
caller never has to special-case None or exceptions.
"""
url = _API.format(ticker=ticker.upper())
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
try:
with urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read())
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
return f"<stocktwits unavailable: {type(exc).__name__}>"
messages = data.get("messages", []) if isinstance(data, dict) else []
if not messages:
return f"<no StockTwits messages found for ${ticker.upper()}>"
lines = []
bullish = bearish = unlabeled = 0
for m in messages[:limit]:
created = m.get("created_at", "")
user = (m.get("user") or {}).get("username", "?")
entities = m.get("entities") or {}
sentiment_obj = entities.get("sentiment") or {}
sentiment = sentiment_obj.get("basic") if isinstance(sentiment_obj, dict) else None
body = (m.get("body") or "").replace("\n", " ").strip()
if len(body) > 280:
body = body[:280] + ""
if sentiment == "Bullish":
bullish += 1
tag = "Bullish"
elif sentiment == "Bearish":
bearish += 1
tag = "Bearish"
else:
unlabeled += 1
tag = "no-label"
lines.append(f"[{created} · @{user} · {tag}] {body}")
total = bullish + bearish + unlabeled
bull_pct = round(100 * bullish / total) if total else 0
bear_pct = round(100 * bearish / total) if total else 0
summary = (
f"Bullish: {bullish} ({bull_pct}%) · "
f"Bearish: {bearish} ({bear_pct}%) · "
f"Unlabeled: {unlabeled} · "
f"Total: {total} most-recent messages"
)
return summary + "\n\n" + "\n".join(lines)

View File

@@ -1,4 +1,5 @@
import os
import re
import json
import pandas as pd
from datetime import date, timedelta, datetime
@@ -6,9 +7,43 @@ from typing import Annotated
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
# Tickers can contain letters, digits, dot, dash, underscore, and caret
# (for index symbols like ^GSPC). Anything else is rejected so the value
# never escapes a containing directory when interpolated into a path.
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
"""Validate ``value`` is safe to interpolate into a filesystem path.
Tickers come from user CLI input or from LLM tool calls, both of which
can be influenced by attacker-controlled content (e.g. prompt injection
embedded in fetched news). Without validation, a value like
``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
escapes the configured cache, checkpoint, or results directory.
Returns ``value`` unchanged when it matches the allowed pattern; raises
``ValueError`` otherwise.
"""
if not isinstance(value, str) or not value:
raise ValueError(f"ticker must be a non-empty string, got {value!r}")
if len(value) > max_len:
raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")
if not _TICKER_PATH_RE.fullmatch(value):
raise ValueError(
f"ticker contains characters not allowed in a filesystem path: {value!r}"
)
# The regex above allows '.', so values like '.', '..', '...' would pass,
# and as a path component they traverse the parent directory. Reject any
# value that's only dots.
if set(value) == {"."}:
raise ValueError(f"ticker cannot consist solely of dots: {value!r}")
return value
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path:
data.to_csv(save_path)
data.to_csv(save_path, encoding="utf-8")
print(f"{tag} saved to {save_path}")

View File

@@ -1,6 +1,7 @@
from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date

View File

@@ -1,9 +1,12 @@
"""yfinance-based news data fetching functions."""
from typing import Optional
import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .config import get_config
from .stockstats_utils import yf_retry
@@ -64,9 +67,10 @@ def get_news_yfinance(
Returns:
Formatted string containing news articles
"""
article_limit = get_config()["news_article_limit"]
try:
stock = yf.Ticker(ticker)
news = yf_retry(lambda: stock.get_news(count=20))
news = yf_retry(lambda: stock.get_news(count=article_limit))
if not news:
return f"No news found for {ticker}"
@@ -106,27 +110,28 @@ def get_news_yfinance(
def get_global_news_yfinance(
curr_date: str,
look_back_days: int = 7,
limit: int = 10,
look_back_days: Optional[int] = None,
limit: Optional[int] = None,
) -> str:
"""
Retrieve global/macro economic news using yfinance Search.
Args:
curr_date: Current date in yyyy-mm-dd format
look_back_days: Number of days to look back
limit: Maximum number of articles to return
look_back_days: Number of days to look back. ``None`` falls back to
``global_news_lookback_days`` from the active config.
limit: Maximum number of articles to return. ``None`` falls back to
``global_news_article_limit`` from the active config.
Returns:
Formatted string containing global news articles
"""
# Search queries for macro/global news
search_queries = [
"stock market economy",
"Federal Reserve interest rates",
"inflation economic outlook",
"global markets trading",
]
config = get_config()
if look_back_days is None:
look_back_days = config["global_news_lookback_days"]
if limit is None:
limit = config["global_news_article_limit"]
search_queries = config["global_news_queries"]
all_news = []
seen_titles = set()

View File

@@ -1,21 +1,71 @@
import os
DEFAULT_CONFIG = {
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
# Single source of truth for env-var → config-key overrides. To expose
# a new config key for environment-based override, add a row here — no
# entry-point script changes required. Coercion is driven by the type
# of the existing default, so users can keep writing plain strings in
# their .env file.
_ENV_OVERRIDES = {
"TRADINGAGENTS_LLM_PROVIDER": "llm_provider",
"TRADINGAGENTS_DEEP_THINK_LLM": "deep_think_llm",
"TRADINGAGENTS_QUICK_THINK_LLM": "quick_think_llm",
"TRADINGAGENTS_LLM_BACKEND_URL": "backend_url",
"TRADINGAGENTS_OUTPUT_LANGUAGE": "output_language",
"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds",
"TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
"TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
}
def _coerce(value: str, reference):
"""Coerce env-var string to the type of the existing default value."""
if isinstance(reference, bool):
return value.strip().lower() in ("true", "1", "yes", "on")
if isinstance(reference, int) and not isinstance(reference, bool):
return int(value)
if isinstance(reference, float):
return float(value)
return value
def _apply_env_overrides(config: dict) -> dict:
"""Apply TRADINGAGENTS_* env vars to the config dict in-place."""
for env_var, key in _ENV_OVERRIDES.items():
raw = os.environ.get(env_var)
if raw is None or raw == "":
continue
config[key] = _coerce(raw, config.get(key))
return config
DEFAULT_CONFIG = _apply_env_overrides({
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
"memory_log_path": os.getenv("TRADINGAGENTS_MEMORY_LOG_PATH", os.path.join(_TRADINGAGENTS_HOME, "memory", "trading_memory.md")),
# Optional cap on the number of resolved memory log entries. When set,
# the oldest resolved entries are pruned once this limit is exceeded.
# Pending entries are never pruned. None disables rotation entirely.
"memory_log_max_entries": None,
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "gpt-5.4",
"quick_think_llm": "gpt-5.4-mini",
"backend_url": "https://api.openai.com/v1",
# When None, each provider's client falls back to its own default endpoint
# (api.openai.com for OpenAI, generativelanguage.googleapis.com for Gemini, ...).
# The CLI overrides this per provider when the user picks one. Keeping a
# provider-specific URL here would leak (e.g. OpenAI's /v1 was previously
# being forwarded to Gemini, producing malformed request URLs).
"backend_url": None,
# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
"anthropic_effort": None, # "high", "medium", "low"
# Checkpoint/resume: when True, LangGraph saves state after each node
# so a crashed run can resume from the last successful step.
"checkpoint_enabled": False,
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
@@ -24,6 +74,21 @@ DEFAULT_CONFIG = {
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
"analyst_concurrency_limit": 1,
# News / data fetching parameters
# Increase for longer lookback strategies or to broaden macro coverage;
# decrease to reduce token usage in agent prompts.
"news_article_limit": 20, # max articles per ticker (ticker-news)
"global_news_article_limit": 10, # max articles for global/macro news
"global_news_lookback_days": 7, # macro news lookback window
# Search queries used by get_global_news for macro headlines. Extend or
# replace to broaden geographic / sector coverage.
"global_news_queries": [
"Federal Reserve interest rates inflation",
"S&P 500 earnings GDP economic outlook",
"geopolitical risk trade war sanctions",
"ECB Bank of England BOJ central bank policy",
"oil commodities supply chain energy",
],
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
@@ -36,4 +101,4 @@ DEFAULT_CONFIG = {
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
},
}
})

View File

@@ -0,0 +1,90 @@
"""LangGraph checkpoint support for resumable analysis runs.
Per-ticker SQLite databases so concurrent tickers don't contend.
"""
from __future__ import annotations
import hashlib
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver
from tradingagents.dataflows.utils import safe_ticker_component
def _db_path(data_dir: str | Path, ticker: str) -> Path:
"""Return the SQLite checkpoint DB path for a ticker."""
# Reject ticker values that would escape the checkpoints directory.
safe = safe_ticker_component(ticker).upper()
p = Path(data_dir) / "checkpoints"
p.mkdir(parents=True, exist_ok=True)
return p / f"{safe}.db"
def thread_id(ticker: str, date: str) -> str:
"""Deterministic thread ID for a ticker+date pair."""
return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
@contextmanager
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
"""Context manager yielding a SqliteSaver backed by a per-ticker DB."""
db = _db_path(data_dir, ticker)
conn = sqlite3.connect(str(db), check_same_thread=False)
try:
saver = SqliteSaver(conn)
saver.setup()
yield saver
finally:
conn.close()
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
"""Check whether a resumable checkpoint exists for ticker+date."""
return checkpoint_step(data_dir, ticker, date) is not None
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
"""Return the step number of the latest checkpoint, or None if none exists."""
db = _db_path(data_dir, ticker)
if not db.exists():
return None
tid = thread_id(ticker, date)
with get_checkpointer(data_dir, ticker) as saver:
config = {"configurable": {"thread_id": tid}}
cp = saver.get_tuple(config)
if cp is None:
return None
return cp.metadata.get("step")
def clear_all_checkpoints(data_dir: str | Path) -> int:
"""Remove all checkpoint DBs. Returns number of files deleted."""
cp_dir = Path(data_dir) / "checkpoints"
if not cp_dir.exists():
return 0
dbs = list(cp_dir.glob("*.db"))
for db in dbs:
db.unlink()
return len(dbs)
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
"""Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
db = _db_path(data_dir, ticker)
if not db.exists():
return
tid = thread_id(ticker, date)
conn = sqlite3.connect(str(db))
try:
for table in ("writes", "checkpoints"):
conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
conn.commit()
except sqlite3.OperationalError:
pass
finally:
conn.close()

View File

@@ -16,13 +16,14 @@ class Propagator:
self.max_recur_limit = max_recur_limit
def create_initial_state(
self, company_name: str, trade_date: str
self, company_name: str, trade_date: str, past_context: str = ""
) -> Dict[str, Any]:
"""Create the initial state for the agent graph."""
return {
"messages": [("human", company_name)],
"company_of_interest": company_name,
"trade_date": str(trade_date),
"past_context": past_context,
"investment_debate_state": InvestDebateState(
{
"bull_history": "",

View File

@@ -1,121 +1,53 @@
# TradingAgents/graph/reflection.py
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from typing import Any
class Reflector:
"""Handles reflection on decisions and updating memory."""
"""Handles reflection on trading decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
def __init__(self, quick_thinking_llm: Any):
"""Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt()
self.log_reflection_prompt = self._get_log_reflection_prompt()
def _get_reflection_prompt(self) -> str:
"""Get the system prompt for reflection."""
return """
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
def _get_log_reflection_prompt(self) -> str:
"""Concise prompt for reflect_on_final_decision (Phase B log entries).
1. Reasoning:
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite.
- Analyze the contributing factors to each success or mistake. Consider:
- Market intelligence.
- Technical indicators.
- Technical signals.
- Price movement analysis.
- Overall market data analysis
- News analysis.
- Social media and sentiment analysis.
- Fundamental data analysis.
- Weight the importance of each factor in the decision-making process.
Produces 2-4 sentences of plain prose — compact enough to be re-injected
into future agent prompts without bloating the context window.
"""
return (
"You are a trading analyst reviewing your own past decision now that the outcome is known.\n"
"Write exactly 2-4 sentences of plain prose (no bullets, no headers, no markdown).\n\n"
"Cover in order:\n"
"1. Was the directional call correct? (cite the alpha figure)\n"
"2. Which part of the investment thesis held or failed?\n"
"3. One concrete lesson to apply to the next similar analysis.\n\n"
"Be specific and terse. Your output will be stored verbatim in a decision log "
"and re-read by future analysts, so every word must earn its place."
)
2. Improvement:
- For any incorrect decisions, propose revisions to maximize returns.
- Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date).
3. Summary:
- Summarize the lessons learned from the successes and mistakes.
- Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained.
4. Query:
- Extract key insights from the summary into a concise sentence of no more than 1000 tokens.
- Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference.
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
"""
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
"""Extract the current market situation from the state."""
curr_market_report = current_state["market_report"]
curr_sentiment_report = current_state["sentiment_report"]
curr_news_report = current_state["news_report"]
curr_fundamentals_report = current_state["fundamentals_report"]
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses
def reflect_on_final_decision(
self,
final_decision: str,
raw_return: float,
alpha_return: float,
) -> str:
"""Generate reflection for a component."""
"""Single reflection call on the final trade decision with outcome context.
Used by Phase B deferred reflection. The final_trade_decision already
synthesises all analyst insights, so no separate market context is needed.
"""
messages = [
("system", self.reflection_system_prompt),
("system", self.log_reflection_prompt),
(
"human",
f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}",
(
f"Raw return: {raw_return:+.1%}\n"
f"Alpha vs SPY: {alpha_return:+.1%}\n\n"
f"Final Decision:\n{final_decision}"
),
),
]
result = self.quick_thinking_llm.invoke(messages).content
return result
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
"""Reflect on bull researcher's analysis and update memory."""
situation = self._extract_current_situation(current_state)
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
result = self._reflect_on_component(
"BULL", bull_debate_history, situation, returns_losses
)
bull_memory.add_situations([(situation, result)])
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
"""Reflect on bear researcher's analysis and update memory."""
situation = self._extract_current_situation(current_state)
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
result = self._reflect_on_component(
"BEAR", bear_debate_history, situation, returns_losses
)
bear_memory.add_situations([(situation, result)])
def reflect_trader(self, current_state, returns_losses, trader_memory):
"""Reflect on trader's decision and update memory."""
situation = self._extract_current_situation(current_state)
trader_decision = current_state["trader_investment_plan"]
result = self._reflect_on_component(
"TRADER", trader_decision, situation, returns_losses
)
trader_memory.add_situations([(situation, result)])
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
"""Reflect on investment judge's decision and update memory."""
situation = self._extract_current_situation(current_state)
judge_decision = current_state["investment_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"INVEST JUDGE", judge_decision, situation, returns_losses
)
invest_judge_memory.add_situations([(situation, result)])
def reflect_portfolio_manager(self, current_state, returns_losses, portfolio_manager_memory):
"""Reflect on portfolio manager's decision and update memory."""
situation = self._extract_current_situation(current_state)
judge_decision = current_state["risk_debate_state"]["judge_decision"]
result = self._reflect_on_component(
"PORTFOLIO MANAGER", judge_decision, situation, returns_losses
)
portfolio_manager_memory.add_situations([(situation, result)])
return self.quick_thinking_llm.invoke(messages).content

View File

@@ -1,8 +1,7 @@
# TradingAgents/graph/setup.py
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from typing import Any, Dict
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
@@ -17,14 +16,9 @@ class GraphSetup:
def __init__(
self,
quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI,
quick_thinking_llm: Any,
deep_thinking_llm: Any,
tool_nodes: Dict[str, ToolNode],
bull_memory,
bear_memory,
trader_memory,
invest_judge_memory,
portfolio_manager_memory,
conditional_logic: ConditionalLogic,
analyst_concurrency_limit: int = 1,
):
@@ -32,11 +26,6 @@ class GraphSetup:
self.quick_thinking_llm = quick_thinking_llm
self.deep_thinking_llm = deep_thinking_llm
self.tool_nodes = tool_nodes
self.bull_memory = bull_memory
self.bear_memory = bear_memory
self.trader_memory = trader_memory
self.invest_judge_memory = invest_judge_memory
self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic
self.analyst_concurrency_limit = analyst_concurrency_limit
@@ -59,30 +48,22 @@ class GraphSetup:
analyst_factories = {
"market": lambda: create_market_analyst(self.quick_thinking_llm),
"social": lambda: create_social_media_analyst(self.quick_thinking_llm),
"social": lambda: create_sentiment_analyst(self.quick_thinking_llm),
"news": lambda: create_news_analyst(self.quick_thinking_llm),
"fundamentals": lambda: create_fundamentals_analyst(self.quick_thinking_llm),
}
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory
)
bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory
)
research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory
)
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
bull_researcher_node = create_bull_researcher(self.quick_thinking_llm)
bear_researcher_node = create_bear_researcher(self.quick_thinking_llm)
research_manager_node = create_research_manager(self.deep_thinking_llm)
trader_node = create_trader(self.quick_thinking_llm)
# Create risk analysis nodes
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
portfolio_manager_node = create_portfolio_manager(
self.deep_thinking_llm, self.portfolio_manager_memory
)
portfolio_manager_node = create_portfolio_manager(self.deep_thinking_llm)
# Create workflow
workflow = StateGraph(AgentState)
@@ -173,5 +154,4 @@ class GraphSetup:
workflow.add_edge("Portfolio Manager", END)
# Compile and return
return workflow.compile()
return workflow

View File

@@ -1,33 +1,31 @@
# TradingAgents/graph/signal_processing.py
"""Extract the 5-tier portfolio rating from the Portfolio Manager's decision.
from langchain_openai import ChatOpenAI
The Portfolio Manager produces a typed ``PortfolioDecision`` via structured
output and renders it to markdown that always carries a ``**Rating**: X``
header (see :func:`tradingagents.agents.schemas.render_pm_decision`). The
deterministic heuristic in :mod:`tradingagents.agents.utils.rating` is more
than sufficient to extract that rating; no extra LLM call is needed.
This module exists for backwards compatibility with callers that expect a
``SignalProcessor.process_signal(text)`` interface.
"""
from __future__ import annotations
from typing import Any
from tradingagents.agents.utils.rating import parse_rating
class SignalProcessor:
"""Processes trading signals to extract actionable decisions."""
"""Read the 5-tier rating out of a Portfolio Manager decision."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
"""Initialize with an LLM for processing."""
def __init__(self, quick_thinking_llm: Any = None):
# The LLM argument is accepted for backwards compatibility but no
# longer used: the PM's structured output guarantees the rating is
# parseable from the rendered markdown without a second LLM call.
self.quick_thinking_llm = quick_thinking_llm
def process_signal(self, full_signal: str) -> str:
"""
Process a full trading signal to extract the core decision.
Args:
full_signal: Complete trading signal text
Returns:
Extracted rating (BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, or SELL)
"""
messages = [
(
"system",
"You are an efficient assistant that extracts the trading decision from analyst reports. "
"Extract the rating as exactly one of: BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, SELL. "
"Output only the single rating word, nothing else.",
),
("human", full_signal),
]
return self.quick_thinking_llm.invoke(messages).content
"""Return one of Buy / Overweight / Hold / Underweight / Sell."""
return parse_rating(full_signal)

View File

@@ -1,18 +1,24 @@
# TradingAgents/graph/trading_graph.py
import logging
import os
from pathlib import Path
import json
from datetime import date
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple, List, Optional
import yfinance as yf
logger = logging.getLogger(__name__)
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.dataflows.utils import safe_ticker_component
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@@ -33,6 +39,7 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@@ -66,10 +73,8 @@ class TradingAgentsGraph:
set_config(self.config)
# Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
os.makedirs(self.config["results_dir"], exist_ok=True)
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
@@ -94,12 +99,7 @@ class TradingAgentsGraph:
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.portfolio_manager_memory = FinancialSituationMemory("portfolio_manager_memory", self.config)
self.memory_log = TradingMemoryLog(self.config)
# Create tool nodes
self.tool_nodes = self._create_tool_nodes()
@@ -113,16 +113,13 @@ class TradingAgentsGraph:
self.quick_thinking_llm,
self.deep_thinking_llm,
self.tool_nodes,
self.bull_memory,
self.bear_memory,
self.trader_memory,
self.invest_judge_memory,
self.portfolio_manager_memory,
self.conditional_logic,
analyst_concurrency_limit=self.config.get("analyst_concurrency_limit", 1),
)
self.propagator = Propagator()
self.propagator = Propagator(
max_recur_limit=self.config.get("max_recur_limit", 100),
)
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
@@ -131,8 +128,10 @@ class TradingAgentsGraph:
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
# Set up the graph: keep the workflow for recompilation with a checkpointer.
self.workflow = self.graph_setup.setup_graph(selected_analysts)
self.graph = self.workflow.compile()
self._checkpointer_ctx = None
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
@@ -192,19 +191,133 @@ class TradingAgentsGraph:
),
}
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date."""
def _fetch_returns(
self, ticker: str, trade_date: str, holding_days: int = 5
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
Returns (raw_return, alpha_return, actual_holding_days) or
(None, None, None) if price data is unavailable (too recent, delisted,
or network error).
"""
try:
start = datetime.strptime(trade_date, "%Y-%m-%d")
end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays
end_str = end.strftime("%Y-%m-%d")
stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
spy = yf.Ticker("SPY").history(start=trade_date, end=end_str)
if len(stock) < 2 or len(spy) < 2:
return None, None, None
actual_days = min(holding_days, len(stock) - 1, len(spy) - 1)
raw = float(
(stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
/ stock["Close"].iloc[0]
)
spy_ret = float(
(spy["Close"].iloc[actual_days] - spy["Close"].iloc[0])
/ spy["Close"].iloc[0]
)
alpha = raw - spy_ret
return raw, alpha, actual_days
except Exception as e:
logger.warning(
"Could not resolve outcome for %s on %s (will retry next run): %s",
ticker, trade_date, e,
)
return None, None, None
def _resolve_pending_entries(self, ticker: str) -> None:
"""Resolve pending log entries for ticker at the start of a new run.
Fetches returns for each same-ticker pending entry, generates reflections,
then writes all updates in a single atomic batch write to avoid redundant I/O.
Skips entries whose price data is not yet available (too recent or delisted).
Trade-off: only same-ticker entries are resolved per run. Entries for
other tickers accumulate until that ticker is run again.
"""
pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker]
if not pending:
return
updates = []
for entry in pending:
raw, alpha, days = self._fetch_returns(ticker, entry["date"])
if raw is None:
continue # price not available yet — try again next run
reflection = self.reflector.reflect_on_final_decision(
final_decision=entry.get("decision", ""),
raw_return=raw,
alpha_return=alpha,
)
updates.append({
"ticker": ticker,
"trade_date": entry["date"],
"raw_return": raw,
"alpha_return": alpha,
"holding_days": days,
"reflection": reflection,
})
if updates:
self.memory_log.batch_update_with_outcomes(updates)
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date.
When ``checkpoint_enabled`` is set in config, the graph is recompiled
with a per-ticker SqliteSaver so a crashed run can resume from the last
successful node on a subsequent invocation with the same ticker+date.
"""
self.ticker = company_name
# Initialize state
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
self._resolve_pending_entries(company_name)
# Recompile with a checkpointer if the user opted in.
if self.config.get("checkpoint_enabled"):
self._checkpointer_ctx = get_checkpointer(
self.config["data_cache_dir"], company_name
)
saver = self._checkpointer_ctx.__enter__()
self.graph = self.workflow.compile(checkpointer=saver)
step = checkpoint_step(
self.config["data_cache_dir"], company_name, str(trade_date)
)
if step is not None:
logger.info(
"Resuming from step %d for %s on %s", step, company_name, trade_date
)
else:
logger.info("Starting fresh for %s on %s", company_name, trade_date)
try:
return self._run_graph(company_name, trade_date)
finally:
if self._checkpointer_ctx is not None:
self._checkpointer_ctx.__exit__(None, None, None)
self._checkpointer_ctx = None
self.graph = self.workflow.compile()
def _run_graph(self, company_name, trade_date):
"""Execute the graph and write the resulting state to disk and memory log."""
# Initialize state — inject memory log context for PM.
past_context = self.memory_log.get_past_context(company_name)
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
company_name, trade_date, past_context=past_context
)
args = self.propagator.get_graph_args()
# Inject thread_id so same ticker+date resumes, different date starts fresh.
if self.config.get("checkpoint_enabled"):
tid = thread_id(company_name, str(trade_date))
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@@ -212,19 +325,33 @@ class TradingAgentsGraph:
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
# Streamed chunks are per-node deltas. Merge them so the returned
# state matches what graph.invoke() yields in the non-debug path.
final_state = {}
for chunk in trace:
final_state.update(chunk)
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
# Store current state for reflection
# Store current state for reflection.
self.curr_state = final_state
# Log state
# Log state to disk.
self._log_state(trade_date, final_state)
# Return decision and processed signal
# Store decision for deferred reflection on the next same-ticker run.
self.memory_log.store_decision(
ticker=company_name,
trade_date=trade_date,
final_trade_decision=final_state["final_trade_decision"],
)
# Clear checkpoint on successful completion to avoid stale state.
if self.config.get("checkpoint_enabled"):
clear_checkpoint(
self.config["data_cache_dir"], company_name, str(trade_date)
)
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):
@@ -259,34 +386,15 @@ class TradingAgentsGraph:
"final_trade_decision": final_state["final_trade_decision"],
}
# Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
# Save to file. Reject ticker values that would escape the
# results directory when joined as a path component.
safe_ticker = safe_ticker_component(self.ticker)
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True)
with open(
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
"w",
encoding="utf-8",
) as f:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
)
self.reflector.reflect_bear_researcher(
self.curr_state, returns_losses, self.bear_memory
)
self.reflector.reflect_trader(
self.curr_state, returns_losses, self.trader_memory
)
self.reflector.reflect_invest_judge(
self.curr_state, returns_losses, self.invest_judge_memory
)
self.reflector.reflect_portfolio_manager(
self.curr_state, returns_losses, self.portfolio_manager_memory
)
log_path = directory / f"full_states_log_{trade_date}.json"
with open(log_path, "w", encoding="utf-8") as f:
json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
def process_signal(self, full_signal):
"""Process a signal to extract the core decision."""

View File

@@ -0,0 +1,44 @@
"""Canonical provider -> API-key env-var mapping.
A single source of truth for which environment variable holds the API
key for each supported LLM provider. Used by the CLI's interactive key
prompt (cli/utils.ensure_api_key) and by anything else that needs to
ask "does this provider require a key, and which env var is it?".
When adding a new provider, register its env var here so the CLI flow
prompts for it automatically instead of failing on first API call.
"""
from __future__ import annotations
from typing import Optional
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"google": "GOOGLE_API_KEY",
"azure": "AZURE_OPENAI_API_KEY",
"xai": "XAI_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
# Dual-region providers each carry their own account; keys are not
# interchangeable between the international and China endpoints.
"qwen": "DASHSCOPE_API_KEY",
"qwen-cn": "DASHSCOPE_CN_API_KEY",
"glm": "ZHIPU_API_KEY",
"glm-cn": "ZHIPU_CN_API_KEY",
"minimax": "MINIMAX_API_KEY",
"minimax-cn": "MINIMAX_CN_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
# Local runtimes do not authenticate.
"ollama": None,
}
def get_api_key_env(provider: str) -> Optional[str]:
"""Return the env var name for `provider`'s API key, or None if not applicable.
Unknown providers also return None — callers should treat that as
"no key check possible" rather than as "no key required".
"""
return PROVIDER_API_KEY_ENV.get(provider.lower())

View File

@@ -0,0 +1,52 @@
import os
from typing import Any, Optional
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "api_key", "reasoning_effort",
"callbacks", "http_client", "http_async_client",
)
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
"""AzureChatOpenAI with normalized content output."""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
class AzureOpenAIClient(BaseLLMClient):
"""Client for Azure OpenAI deployments.
Requires environment variables:
AZURE_OPENAI_API_KEY: API key
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
"""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured AzureChatOpenAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {
"model": self.model,
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
}
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return NormalizedAzureChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Azure accepts any deployed model name."""
return True

View File

@@ -0,0 +1,120 @@
"""Declarative per-model capability table for OpenAI-compatible providers.
This is the single place that knows which model IDs reject which API
parameters or require which structured-output method. The LLM client
subclasses consult ``get_capabilities(model_name)`` instead of hardcoding
model-name ``if`` ladders, so adding a new model (or a new provider quirk)
means editing this table — not the client code.
Pattern adapted from the per-model ``compat:`` flags DeepSeek themselves
publish in their integration guides (e.g. the Oh My Pi config schema
documents ``supportsToolChoice``, ``requiresReasoningContentForToolCalls``
as declarative per-model fields).
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Literal
StructuredMethod = Literal[
"function_calling", # uses tools; respects supports_tool_choice
"json_mode", # uses response_format={"type":"json_object"}
"json_schema", # uses response_format={"type":"json_schema",...}
"none", # no structured output available; caller falls back to free-text
]
@dataclass(frozen=True)
class ModelCapabilities:
"""What an OpenAI-compatible model accepts at the API level."""
supports_tool_choice: bool
supports_json_mode: bool
supports_json_schema: bool
preferred_structured_method: StructuredMethod
# DeepSeek thinking-mode models 400 if reasoning_content from prior
# assistant turns is not echoed back on the next request.
requires_reasoning_content_roundtrip: bool = False
# DeepSeek's thinking models accept the ``tools`` array but reject the
# ``tool_choice`` parameter (official Oh My Pi integration guide and the
# 400 response in issue #678). Their official tool-calling examples
# (api-docs.deepseek.com/guides/tool_calls) pass ``tools=[...]`` without
# ``tool_choice`` — we mirror that pattern by setting supports_tool_choice
# to False and letting the client suppress the kwarg.
_DEEPSEEK_THINKING = ModelCapabilities(
supports_tool_choice=False,
supports_json_mode=True,
supports_json_schema=False,
preferred_structured_method="function_calling",
requires_reasoning_content_roundtrip=True,
)
_DEEPSEEK_CHAT = ModelCapabilities(
supports_tool_choice=True,
supports_json_mode=True,
supports_json_schema=False,
preferred_structured_method="function_calling",
)
# MiniMax M2.x reasoning models accept the tools array, but their
# tool_choice parameter is restricted to the enum {"none", "auto"}
# (platform.minimax.io/docs/api-reference/text-post). Langchain's
# function_calling path sends tool_choice as a function-spec dict, which
# MiniMax 400s — same shape as the DeepSeek bug. supports_tool_choice=False
# makes the dispatch in NormalizedChatOpenAI suppress the kwarg; the schema
# still ships as a tool. json_mode response_format is only for
# MiniMax-Text-01, not M2.x.
_MINIMAX_THINKING = ModelCapabilities(
supports_tool_choice=False,
supports_json_mode=False,
supports_json_schema=False,
preferred_structured_method="function_calling",
)
_DEFAULT = ModelCapabilities(
supports_tool_choice=True,
supports_json_mode=True,
supports_json_schema=True,
preferred_structured_method="function_calling",
)
# Exact-ID matches take precedence over pattern matches.
_BY_ID: dict[str, ModelCapabilities] = {
"deepseek-chat": _DEEPSEEK_CHAT,
"deepseek-reasoner": _DEEPSEEK_THINKING,
"deepseek-v4-flash": _DEEPSEEK_THINKING,
"deepseek-v4-pro": _DEEPSEEK_THINKING,
# MiniMax — full official model lineup per
# platform.minimax.io/docs/api-reference/text-openai-api
"MiniMax-M2.7": _MINIMAX_THINKING,
"MiniMax-M2.7-highspeed": _MINIMAX_THINKING,
"MiniMax-M2.5": _MINIMAX_THINKING,
"MiniMax-M2.5-highspeed": _MINIMAX_THINKING,
"MiniMax-M2.1": _MINIMAX_THINKING,
"MiniMax-M2.1-highspeed": _MINIMAX_THINKING,
"MiniMax-M2": _MINIMAX_THINKING,
}
# Forward-compat patterns. New ``deepseek-v5-*`` / ``deepseek-reasoner-*``
# or ``MiniMax-M3*`` variants inherit the thinking-mode quirks automatically.
_BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [
(re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING),
(re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING),
(re.compile(r"^MiniMax-M\d"), _MINIMAX_THINKING),
]
def get_capabilities(model_name: str) -> ModelCapabilities:
"""Resolve capabilities by exact ID, then pattern, then default."""
if model_name in _BY_ID:
return _BY_ID[model_name]
for pattern, caps in _BY_PATTERN:
if pattern.match(model_name):
return caps
return _DEFAULT

View File

@@ -1,9 +1,15 @@
from typing import Optional
from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek",
"qwen", "qwen-cn",
"glm", "glm-cn",
"minimax", "minimax-cn",
"ollama", "openrouter",
)
def create_llm_client(
@@ -14,17 +20,15 @@ def create_llm_client(
) -> BaseLLMClient:
"""Create an LLM client for the specified provider.
Provider modules are imported lazily so that simply importing this
factory (e.g. during test collection) does not pull in heavy LLM SDKs
or fail when their API keys are absent.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
provider: LLM provider name
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
- http_client: Custom httpx.Client for SSL proxy or certificate customization
- http_async_client: Custom httpx.AsyncClient for async operations
- timeout: Request timeout in seconds
- max_retries: Maximum retry attempts
- api_key: API key for the provider
- callbacks: LangChain callbacks
Returns:
Configured BaseLLMClient instance
@@ -34,16 +38,20 @@ def create_llm_client(
"""
provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"):
if provider_lower in _OPENAI_COMPATIBLE:
from .openai_client import OpenAIClient
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower == "anthropic":
from .anthropic_client import AnthropicClient
return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google":
from .google_client import GoogleClient
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure":
from .azure_client import AzureOpenAIClient
return AzureOpenAIClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -8,71 +8,152 @@ ModelOption = Tuple[str, str]
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
# Shared model list for GLM via Z.AI (international) and BigModel (China).
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
_GLM_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
("GLM-4.5-Air - Lightweight, cost-efficient", "glm-4.5-air"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1 - Latest flagship, 204K ctx", "glm-5.1"),
("GLM-5 - Flagship, 204K ctx", "glm-5"),
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
("Custom model ID", "custom"),
],
}
# Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints.
# Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized).
#
# Only versioned IDs are exposed in the dropdown. The version-less aliases
# (qwen-plus, qwen-flash) are documented by Alibaba as auto-upgrading
# pointers ("backbone, latest, and snapshot ... have been upgraded to the
# Qwen3 series"), which means their behavior shifts when Alibaba rotates
# the backing model. Users who want a specific generation pick it
# explicitly; users who really want auto-latest can enter the alias via
# "Custom model ID".
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus - Flagship vision-language, agentic coding SOTA", "qwen3.6-plus"),
("Qwen 3.5 Plus - Previous-gen flagship", "qwen3.5-plus"),
("Qwen 3 Max - Specialized for agent programming + tool use", "qwen3-max"),
("Custom model ID", "custom"),
],
}
# Shared model list for MiniMax's global and CN endpoints (same IDs).
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
# All M2.x models share a 204,800-token context window.
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
("MiniMax-M2.1-highspeed - M2.1 highspeed, 204K ctx", "MiniMax-M2.1-highspeed"),
("Custom model ID", "custom"),
],
"deep": [
("MiniMax-M2.7 - Flagship, SOTA on coding/agent benchmarks, 204K ctx", "MiniMax-M2.7"),
("MiniMax-M2.7-highspeed - Same quality as M2.7, ~100 TPS", "MiniMax-M2.7-highspeed"),
("MiniMax-M2.5 - Previous-gen flagship, 204K ctx", "MiniMax-M2.5"),
("MiniMax-M2.1 - Earlier M2 line, 204K ctx", "MiniMax-M2.1"),
("MiniMax-M2 - Base M2, 204K ctx", "MiniMax-M2"),
("Custom model ID", "custom"),
],
}
MODEL_OPTIONS: ProviderModeOptions = {
"openai": {
"quick": [
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
],
"deep": [
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
("GPT-5.4 - Previous-gen frontier, 1M context, cost-effective", "gpt-5.4"),
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
("GPT-5.5 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.5-pro"),
],
},
"anthropic": {
"quick": [
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
("Claude Haiku 4.5 - Fastest with near-frontier intelligence", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - High-performance for agents and coding", "claude-sonnet-4-5"),
],
"deep": [
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
("Claude Opus 4.7 - Latest frontier, long-running agents and coding", "claude-opus-4-7"),
("Claude Opus 4.6 - Frontier intelligence, agents and coding", "claude-opus-4-6"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
},
"google": {
"quick": [
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
("Gemini 3.1 Flash Lite - Most cost-efficient (GA)", "gemini-3.1-flash-lite"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"deep": [
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 3.1 Pro - Reasoning-first, complex workflows (preview)", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
],
},
"xai": {
"quick": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4.20 (Non-Reasoning) - Latest, speed-optimized", "grok-4.20-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
],
"deep": [
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
],
},
"openrouter": {
"quick": [
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
],
"deep": [
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
("Grok 4.20 (Reasoning) - Latest frontier reasoning model", "grok-4.20-reasoning"),
("Grok 4 - Flagship (dated build)", "grok-4-0709"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.20 - Auto-select reasoning behavior", "grok-4.20"),
],
},
"deepseek": {
"quick": [
("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
"deep": [
("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
},
# Qwen: same model IDs across global (dashscope-intl) and China
# (dashscope) endpoints, so the two provider keys share one model list.
"qwen": _QWEN_MODELS,
"qwen-cn": _QWEN_MODELS,
# GLM: Z.AI (international) and BigModel (China) host the same model
# IDs; the two provider keys share one model list.
"glm": _GLM_MODELS,
"glm-cn": _GLM_MODELS,
# MiniMax: same model IDs across global (.io) and China (.com) regions,
# so the two provider keys share one model list.
"minimax": _MINIMAX_MODELS,
"minimax-cn": _MINIMAX_MODELS,
# OpenRouter: fetched dynamically. Azure: any deployed model name.
"ollama": {
"quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"),

View File

@@ -1,34 +1,157 @@
import os
from typing import Any, Optional
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from .api_key_env import get_api_key_env
from .base_client import BaseLLMClient, normalize_content
from .capabilities import get_capabilities
from .validators import validate_model
class NormalizedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with normalized content output.
"""ChatOpenAI with normalized content output and capability-aware binding.
The Responses API returns content as a list of typed blocks
(reasoning, text, etc.). This normalizes to string for consistent
downstream handling.
(reasoning, text, etc.). ``invoke`` normalizes to string for
consistent downstream handling.
``with_structured_output`` consults the per-model capability table
(``capabilities.get_capabilities``) to pick the method and to decide
whether ``tool_choice`` may be sent. Models that reject ``tool_choice``
(e.g. DeepSeek V4 and reasoner — per their official tool-calling
guide) still bind the schema as a tool, but no ``tool_choice``
parameter is sent.
Provider-specific quirks beyond structured-output (e.g. DeepSeek's
reasoning_content roundtrip) live in subclasses so this base class
stays small.
"""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
def with_structured_output(self, schema, *, method=None, **kwargs):
caps = get_capabilities(self.model_name)
if caps.preferred_structured_method == "none":
raise NotImplementedError(
f"{self.model_name} has no structured-output method available; "
f"agent factories will fall back to free-text generation."
)
method = method or caps.preferred_structured_method
# When the model rejects tool_choice, suppress langchain's hardcoded
# value. The schema is still bound as a tool — exactly what
# DeepSeek's official tool-calling examples do.
if method == "function_calling" and not caps.supports_tool_choice:
kwargs.setdefault("tool_choice", None)
return super().with_structured_output(schema, method=method, **kwargs)
def _input_to_messages(input_: Any) -> list:
"""Normalise a langchain LLM input to a list of message objects.
Accepts a list of messages, a ``ChatPromptValue`` (from a
ChatPromptTemplate), or anything else (treated as no messages).
Used by providers that need to walk the outgoing message history;
in particular DeepSeek thinking-mode propagation must work for
both bare-list invocations and ChatPromptTemplate-driven ones, so
treating only ``list`` here would silently skip half the call sites.
"""
if isinstance(input_, list):
return input_
if hasattr(input_, "to_messages"):
return input_.to_messages()
return []
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
Thinking-mode round-trip is the only DeepSeek-specific behavior that
stays here. When DeepSeek's thinking models return a response with
``reasoning_content``, that field must be echoed back as part of the
assistant message on the next turn or the API fails with HTTP 400.
``_create_chat_result`` captures it on receive and
``_get_request_payload`` re-attaches it on send.
Tool-choice handling for V4 and reasoner — those models reject the
``tool_choice`` parameter — is handled by the capability dispatch in
``NormalizedChatOpenAI.with_structured_output``, not here.
"""
def _get_request_payload(self, input_, *, stop=None, **kwargs):
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
outgoing = payload.get("messages", [])
for message_dict, message in zip(outgoing, _input_to_messages(input_)):
if not isinstance(message, AIMessage):
continue
reasoning = message.additional_kwargs.get("reasoning_content")
if reasoning is not None:
message_dict["reasoning_content"] = reasoning
return payload
def _create_chat_result(self, response, generation_info=None):
chat_result = super()._create_chat_result(response, generation_info)
response_dict = (
response
if isinstance(response, dict)
else response.model_dump(
exclude={"choices": {"__all__": {"message": {"parsed"}}}}
)
)
for generation, choice in zip(
chat_result.generations, response_dict.get("choices", [])
):
reasoning = choice.get("message", {}).get("reasoning_content")
if reasoning is not None:
generation.message.additional_kwargs["reasoning_content"] = reasoning
return chat_result
class MinimaxChatOpenAI(NormalizedChatOpenAI):
"""MiniMax-specific overrides on top of the OpenAI-compatible client.
M2.x reasoning models embed ``<think>...</think>`` blocks directly in
``message.content`` by default, which would pollute saved reports.
Per platform.minimax.io/docs/api-reference/text-openai-api, setting
``reasoning_split=True`` in the request body redirects the thinking
block into ``reasoning_details`` so ``content`` stays clean.
Tool-choice handling for M2.x — those models accept only the string
enum ``{"none", "auto"}`` and reject langchain's function-spec dict —
is handled by the capability dispatch in
``NormalizedChatOpenAI.with_structured_output``, not here.
"""
def _get_request_payload(self, input_, *, stop=None, **kwargs):
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload.setdefault("reasoning_split", True)
return payload
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "reasoning_effort",
"api_key", "callbacks", "http_client", "http_async_client",
)
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
# (one canonical mapping consulted by both this client and the CLI's
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
# separate endpoints because international and China accounts cannot share
# credentials (#758).
_PROVIDER_BASE_URL = {
"xai": "https://api.x.ai/v1",
"deepseek": "https://api.deepseek.com",
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"glm": "https://api.z.ai/api/paas/v4/",
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
"minimax": "https://api.minimax.io/v1",
"minimax-cn": "https://api.minimaxi.com/v1",
"openrouter": "https://openrouter.ai/api/v1",
"ollama": "http://localhost:11434/v1",
}
@@ -56,14 +179,22 @@ class OpenAIClient(BaseLLMClient):
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model}
# Provider-specific base URL and auth
if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url
# Provider-specific base URL and auth. An explicit base_url on the
# client (e.g. a corporate proxy) takes precedence over the
# provider default so users can route through their own gateway.
if self.provider in _PROVIDER_BASE_URL:
llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider]
api_key_env = get_api_key_env(self.provider)
if api_key_env:
api_key = os.environ.get(api_key_env)
if api_key:
llm_kwargs["api_key"] = api_key
else:
raise ValueError(
f"API key for provider '{self.provider}' is not set. "
f"Please set the {api_key_env} environment variable "
f"(e.g. add {api_key_env}=your_key to your .env file)."
)
else:
llm_kwargs["api_key"] = "ollama"
elif self.base_url:
@@ -79,7 +210,15 @@ class OpenAIClient(BaseLLMClient):
if self.provider == "openai":
llm_kwargs["use_responses_api"] = True
return NormalizedChatOpenAI(**llm_kwargs)
# Provider-specific quirks live in their own subclasses so the
# base NormalizedChatOpenAI stays free of provider branches.
if self.provider == "deepseek":
chat_cls = DeepSeekChatOpenAI
elif self.provider in ("minimax", "minimax-cn"):
chat_cls = MinimaxChatOpenAI
else:
chat_cls = NormalizedChatOpenAI
return chat_cls(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for the provider."""