mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
Merge remote-tracking branch 'upstream/main' into crypto-analysis-mvp
# Conflicts: # cli/utils.py # tradingagents/agents/analysts/social_media_analyst.py # tradingagents/agents/researchers/bear_researcher.py
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
import warnings
|
||||
|
||||
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
||||
# (and every llm_clients consumer) sees the user's keys regardless of
|
||||
# which entry point started the process. find_dotenv(usecwd=True) walks
|
||||
# from the CWD, so the installed `tradingagents` console script picks up
|
||||
# the project's .env instead of stepping up from site-packages.
|
||||
# load_dotenv defaults to override=False, so it never clobbers values
|
||||
# the caller has already exported.
|
||||
try:
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
load_dotenv(find_dotenv(usecwd=True))
|
||||
load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in
|
||||
# its own __init__, which prepends default-action filters for its
|
||||
# subclassed warning categories. To suppress a specific warning we must
|
||||
# install our filter AFTER langchain-core has installed its own, so import
|
||||
# it first. The package is a guaranteed transitive dep via langgraph.
|
||||
try:
|
||||
import langchain_core # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
||||
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
||||
# warning from langchain-core 1.3.3 on every interpreter start. The fix
|
||||
# is already merged upstream (langchain-ai/langgraph#7743, 2026-05-08)
|
||||
# and will arrive in the next langgraph-checkpoint release. Remove this
|
||||
# block (and the langchain_core preload above) when we bump past it.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"The default value of `allowed_objects`.*",
|
||||
category=PendingDeprecationWarning,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,10 @@ from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
from .analysts.sentiment_analyst import (
|
||||
create_sentiment_analyst,
|
||||
create_social_media_analyst, # deprecated alias kept for back-compat
|
||||
)
|
||||
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
@@ -33,6 +36,7 @@ __all__ = [
|
||||
"create_aggressive_debator",
|
||||
"create_portfolio_manager",
|
||||
"create_conservative_debator",
|
||||
"create_social_media_analyst",
|
||||
"create_sentiment_analyst",
|
||||
"create_social_media_analyst", # deprecated; will be removed in a future version
|
||||
"create_trader",
|
||||
]
|
||||
|
||||
184
tradingagents/agents/analysts/sentiment_analyst.py
Normal file
184
tradingagents/agents/analysts/sentiment_analyst.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Sentiment analyst — multi-source sentiment analysis for a target ticker.
|
||||
|
||||
Previously named ``social_media_analyst``. Renamed and redesigned because
|
||||
the old version had a prompt that demanded social-media analysis but the
|
||||
only tool available was Yahoo Finance news — which led LLMs to fabricate
|
||||
Reddit/X/StockTwits content under prompt pressure (verified live).
|
||||
|
||||
The redesigned agent pre-fetches three complementary data sources before
|
||||
the LLM is invoked and injects them into the prompt as structured blocks:
|
||||
|
||||
1. News headlines — Yahoo Finance (institutional framing)
|
||||
2. StockTwits messages — retail-trader posts indexed by cashtag, with
|
||||
user-labeled Bullish/Bearish sentiment tags
|
||||
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
||||
|
||||
The agent does not use tool-calling; the data is in the prompt from
|
||||
turn 0. The LLM produces the sentiment report in a single invocation.
|
||||
|
||||
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
||||
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
||||
|
||||
|
||||
def _seven_days_back(trade_date: str) -> str:
|
||||
return (datetime.strptime(trade_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def create_sentiment_analyst(llm):
|
||||
"""Create a sentiment analyst node for the trading graph.
|
||||
|
||||
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
||||
prompt as structured blocks, and produces a sentiment report in a
|
||||
single LLM call.
|
||||
"""
|
||||
|
||||
def sentiment_analyst_node(state):
|
||||
ticker = state["company_of_interest"]
|
||||
end_date = state["trade_date"]
|
||||
start_date = _seven_days_back(end_date)
|
||||
instrument_context = build_instrument_context(ticker)
|
||||
|
||||
# Pre-fetch all three sources. Each fetcher degrades gracefully and
|
||||
# returns a string (no exceptions surface from here), so the LLM
|
||||
# always sees something — either real data or a clear placeholder.
|
||||
news_block = get_news.func(ticker, start_date, end_date)
|
||||
stocktwits_block = fetch_stocktwits_messages(ticker, limit=30)
|
||||
reddit_block = fetch_reddit_posts(ticker)
|
||||
|
||||
system_message = _build_system_message(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
news_block=news_block,
|
||||
stocktwits_block=stocktwits_block,
|
||||
reddit_block=reddit_block,
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
"\n{system_message}\n"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(current_date=end_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
# No bind_tools — the data is already in the prompt; a single LLM
|
||||
# call produces the report directly.
|
||||
chain = prompt | llm
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": result.content,
|
||||
}
|
||||
|
||||
return sentiment_analyst_node
|
||||
|
||||
|
||||
def _build_system_message(
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
news_block: str,
|
||||
stocktwits_block: str,
|
||||
reddit_block: str,
|
||||
) -> str:
|
||||
"""Assemble the sentiment-analyst system message with structured data blocks."""
|
||||
return f"""You are a financial market sentiment analyst. Your task is to produce a comprehensive sentiment report for {ticker} covering the period from {start_date} to {end_date}, drawing on three complementary data sources that have already been collected for you.
|
||||
|
||||
## Data sources (pre-fetched, in this prompt)
|
||||
|
||||
### News headlines — Yahoo Finance, past 7 days
|
||||
Institutional framing. Fact-driven, slower-moving signal.
|
||||
|
||||
<start_of_news>
|
||||
{news_block}
|
||||
<end_of_news>
|
||||
|
||||
### StockTwits messages — retail-trader social platform indexed by cashtag
|
||||
Fast-moving signal. Each message carries a user-labeled sentiment tag (Bullish / Bearish / no-label) plus the message body.
|
||||
|
||||
<start_of_stocktwits>
|
||||
{stocktwits_block}
|
||||
<end_of_stocktwits>
|
||||
|
||||
### Reddit posts — r/wallstreetbets, r/stocks, r/investing (past 7 days)
|
||||
Community discussion. Engagement signal via upvote score and comment count. Subreddit character matters (r/wallstreetbets is often contrarian/exuberant; r/stocks more measured; r/investing longer-term).
|
||||
|
||||
<start_of_reddit>
|
||||
{reddit_block}
|
||||
<end_of_reddit>
|
||||
|
||||
## How to analyze this data (best practices)
|
||||
|
||||
1. **Read the StockTwits Bullish/Bearish ratio as a leading retail-sentiment signal.** A 70/30 bullish/bearish split is moderately bullish; ≥90/10 may indicate over-extension and contrarian risk; 50/50 is uncertainty. Sample size matters — base rates on the actual message count, not percentages alone.
|
||||
|
||||
2. **Look for cross-source divergences.** If news framing is bearish but StockTwits is overwhelmingly bullish, that mismatch is itself a signal — it can mean retail is leaning into a thesis the news flow hasn't caught up to (or vice versa, that retail is chasing while institutions are cautious).
|
||||
|
||||
3. **Weight Reddit posts by engagement.** A 400-upvote / 200-comment thread reflects community attention; a 3-upvote post is noise. Read the body excerpts for context — the title alone often misleads.
|
||||
|
||||
4. **Distinguish opinion from event.** A news headline ("Nvidia announces $500M Corning deal") is an event; a StockTwits post ("buying NVDA, this is going to moon") is opinion. Both are inputs but should be weighted differently in your conclusions.
|
||||
|
||||
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
||||
|
||||
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
|
||||
|
||||
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
||||
|
||||
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
||||
|
||||
## Output
|
||||
|
||||
Produce a sentiment report covering, in order:
|
||||
|
||||
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
|
||||
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
|
||||
3. **Divergences, alignments, and key narratives** across sources.
|
||||
4. **Catalysts and risks** surfaced by the data.
|
||||
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
|
||||
|
||||
{get_language_instruction()}"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backwards-compatibility shim
|
||||
# ---------------------------------------------------------------------------
|
||||
def create_social_media_analyst(llm):
|
||||
"""Deprecated alias for :func:`create_sentiment_analyst`.
|
||||
|
||||
Kept so existing code that imports ``create_social_media_analyst``
|
||||
continues to work.
|
||||
|
||||
.. deprecated::
|
||||
Import :func:`create_sentiment_analyst` directly instead.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"create_social_media_analyst is deprecated and will be removed in a "
|
||||
"future version. Use create_sentiment_analyst instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return create_sentiment_analyst(llm)
|
||||
@@ -1,61 +1,23 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||
from tradingagents.dataflows.config import get_config
|
||||
"""Backwards-compatibility shim for the renamed module.
|
||||
|
||||
The agent is now ``sentiment_analyst`` and aggregates Yahoo Finance news,
|
||||
StockTwits cashtag streams, and Reddit posts into a single sentiment
|
||||
report. Import from ``tradingagents.agents.analysts.sentiment_analyst``
|
||||
going forward; this module will be removed in a future release.
|
||||
|
||||
def create_social_media_analyst(llm):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
asset_type = state.get("asset_type", "stock")
|
||||
subject_label = "company" if asset_type == "stock" else "asset"
|
||||
instrument_context = build_instrument_context(
|
||||
state["company_of_interest"], asset_type
|
||||
)
|
||||
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||
"""
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
]
|
||||
import warnings as _warnings
|
||||
|
||||
system_message = (
|
||||
f"You are a social media and targeted news researcher/analyst tasked with analyzing social media posts, recent {subject_label} news, and public sentiment for a specific {subject_label} over the past week. You will be given an asset identifier and your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors after looking at social media, sentiment, and recent news related to that {subject_label}. Use the get_news(query, start_date, end_date) tool to search for {subject_label}-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
+ get_language_instruction()
|
||||
)
|
||||
from tradingagents.agents.analysts.sentiment_analyst import ( # noqa: F401
|
||||
create_sentiment_analyst,
|
||||
create_social_media_analyst,
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": report,
|
||||
}
|
||||
|
||||
return social_media_analyst_node
|
||||
_warnings.warn(
|
||||
"tradingagents.agents.analysts.social_media_analyst is deprecated. "
|
||||
"Import from tradingagents.agents.analysts.sentiment_analyst instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
)
|
||||
from tradingagents.agents.utils.structured import (
|
||||
bind_structured,
|
||||
invoke_structured_or_freetext,
|
||||
@@ -37,7 +40,7 @@ Commit to a clear stance whenever the debate's strongest arguments warrant one;
|
||||
---
|
||||
|
||||
**Debate History:**
|
||||
{history}"""
|
||||
{history}""" + get_language_instruction()
|
||||
|
||||
investment_plan = invoke_structured_or_freetext(
|
||||
structured_llm,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||
|
||||
|
||||
def create_bear_researcher(llm):
|
||||
@@ -38,7 +39,7 @@ Latest world affairs news: {news_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the {target_label}.
|
||||
"""
|
||||
""" + get_language_instruction()
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||
|
||||
|
||||
def create_bull_researcher(llm):
|
||||
@@ -36,7 +37,7 @@ Latest world affairs news: {news_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
|
||||
"""
|
||||
""" + get_language_instruction()
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||
|
||||
|
||||
def create_aggressive_debator(llm):
|
||||
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||
|
||||
|
||||
def create_conservative_debator(llm):
|
||||
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ import functools
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
)
|
||||
from tradingagents.agents.utils.structured import (
|
||||
bind_structured,
|
||||
invoke_structured_or_freetext,
|
||||
@@ -30,6 +33,7 @@ def create_trader(llm):
|
||||
"You are a trading agent analyzing market data to make investment decisions. "
|
||||
"Based on your analysis, provide a specific recommendation to buy, sell, or hold. "
|
||||
"Anchor your reasoning in the analysts' reports and the research plan."
|
||||
+ get_language_instruction()
|
||||
),
|
||||
},
|
||||
{
|
||||
|
||||
@@ -52,7 +52,7 @@ class AgentState(MessagesState):
|
||||
|
||||
# research step
|
||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Sentiment Analyst"]
|
||||
news_report: Annotated[
|
||||
str, "Report from the News Researcher of current world affairs"
|
||||
]
|
||||
|
||||
@@ -24,8 +24,10 @@ def get_language_instruction() -> str:
|
||||
"""Return a prompt instruction for the configured output language.
|
||||
|
||||
Returns empty string when English (default), so no extra tokens are used.
|
||||
Only applied to user-facing agents (analysts, portfolio manager).
|
||||
Internal debate agents stay in English for reasoning quality.
|
||||
Applied to every agent whose output reaches the saved report —
|
||||
analysts, researchers, debaters, research manager, trader, and
|
||||
portfolio manager — so a non-English run produces a fully localized
|
||||
report rather than a mix of languages.
|
||||
"""
|
||||
from tradingagents.dataflows.config import get_config
|
||||
lang = get_config().get("output_language", "English")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Optional
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
@tool
|
||||
@@ -23,16 +23,20 @@ def get_news(
|
||||
@tool
|
||||
def get_global_news(
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "Number of days to look back"] = 7,
|
||||
limit: Annotated[int, "Maximum number of articles to return"] = 5,
|
||||
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
|
||||
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global news data.
|
||||
Uses the configured news_data vendor.
|
||||
Uses the configured news_data vendor. Defaults for look_back_days and
|
||||
limit come from DEFAULT_CONFIG (global_news_lookback_days,
|
||||
global_news_article_limit); pass explicit values to override.
|
||||
|
||||
Args:
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
look_back_days (int): Number of days to look back (default 7)
|
||||
limit (int): Maximum number of articles to return (default 5)
|
||||
look_back_days (int): Number of days to look back; omit to inherit config
|
||||
limit (int): Maximum number of articles to return; omit to inherit config
|
||||
|
||||
Returns:
|
||||
str: A formatted string containing global news data
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import tradingagents.default_config as default_config
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional
|
||||
|
||||
import tradingagents.default_config as default_config
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
|
||||
@@ -9,22 +11,31 @@ def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
"""Update the configuration with custom values.
|
||||
|
||||
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||
partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}``
|
||||
keeps the other nested keys from the default; scalar keys are replaced.
|
||||
"""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
initialize_config()
|
||||
incoming = deepcopy(config)
|
||||
for key, value in incoming.items():
|
||||
if isinstance(value, dict) and isinstance(_config.get(key), dict):
|
||||
_config[key].update(value)
|
||||
else:
|
||||
_config[key] = value
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
return _config.copy()
|
||||
return deepcopy(_config)
|
||||
|
||||
|
||||
# Initialize with default config
|
||||
|
||||
106
tradingagents/dataflows/reddit.py
Normal file
106
tradingagents/dataflows/reddit.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Reddit search fetcher for ticker-specific discussion posts.
|
||||
|
||||
Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``)
|
||||
which do not require an API key. Public throughput is ~10 requests per
|
||||
minute per IP, well within budget for a single agent run that queries
|
||||
a handful of finance subreddits per ticker.
|
||||
|
||||
Returns formatted plaintext blocks ready for prompt injection. Degrades
|
||||
gracefully — returns a placeholder string rather than raising, so callers
|
||||
never have to special-case missing data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterable
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import urlencode
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
|
||||
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
||||
|
||||
# Default subreddits ordered roughly by signal density for ticker-specific
|
||||
# discussion. wallstreetbets has the most volume but most noise; stocks /
|
||||
# investing trend more measured. Caller can override.
|
||||
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
|
||||
|
||||
|
||||
def _fetch_subreddit(
|
||||
ticker: str,
|
||||
sub: str,
|
||||
limit: int,
|
||||
timeout: float,
|
||||
) -> list[dict]:
|
||||
qs = urlencode({
|
||||
"q": ticker,
|
||||
"restrict_sr": "on",
|
||||
"sort": "new",
|
||||
"t": "week", # last 7 days
|
||||
"limit": limit,
|
||||
})
|
||||
url = _API.format(sub=sub, qs=qs)
|
||||
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
||||
try:
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
payload = json.loads(resp.read())
|
||||
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
||||
logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc)
|
||||
return []
|
||||
children = (payload.get("data") or {}).get("children") or []
|
||||
return [c.get("data", {}) for c in children if isinstance(c, dict)]
|
||||
|
||||
|
||||
def fetch_reddit_posts(
|
||||
ticker: str,
|
||||
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
|
||||
limit_per_sub: int = 5,
|
||||
timeout: float = 10.0,
|
||||
inter_request_delay: float = 0.4,
|
||||
) -> str:
|
||||
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
|
||||
subreddits and return them as a formatted plaintext block.
|
||||
|
||||
``inter_request_delay`` keeps us under Reddit's public rate limit
|
||||
(~10 req/min per IP) even if the caller queries many subreddits.
|
||||
"""
|
||||
blocks = []
|
||||
total_posts = 0
|
||||
for i, sub in enumerate(subreddits):
|
||||
if i > 0:
|
||||
time.sleep(inter_request_delay)
|
||||
posts = _fetch_subreddit(ticker, sub, limit_per_sub, timeout)
|
||||
total_posts += len(posts)
|
||||
if not posts:
|
||||
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
|
||||
continue
|
||||
|
||||
lines = [f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}:"]
|
||||
for p in posts:
|
||||
title = (p.get("title") or "").replace("\n", " ").strip()
|
||||
score = p.get("score", 0)
|
||||
comments = p.get("num_comments", 0)
|
||||
created = p.get("created_utc")
|
||||
created_str = (
|
||||
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
|
||||
)
|
||||
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
|
||||
if len(selftext) > 240:
|
||||
selftext = selftext[:240] + "…"
|
||||
lines.append(
|
||||
f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}"
|
||||
+ (f"\n body excerpt: {selftext}" if selftext else "")
|
||||
)
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
if total_posts == 0:
|
||||
return (
|
||||
f"<no Reddit posts found mentioning {ticker.upper()} across "
|
||||
f"{', '.join(f'r/{s}' for s in subreddits)} in the past 7 days>"
|
||||
)
|
||||
return "\n\n".join(blocks)
|
||||
83
tradingagents/dataflows/stocktwits.py
Normal file
83
tradingagents/dataflows/stocktwits.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""StockTwits public symbol-stream fetcher.
|
||||
|
||||
StockTwits exposes a per-symbol message stream at
|
||||
``api.stocktwits.com/api/2/streams/symbol/{ticker}.json`` that requires no
|
||||
API key, no OAuth, and no registration. Each message includes a
|
||||
user-labeled sentiment field (``Bullish``/``Bearish``/null), the message
|
||||
body, timestamp, and posting user.
|
||||
|
||||
The function is deliberately self-contained: short timeout, graceful
|
||||
degradation on any HTTP or parse failure, and a string return type so
|
||||
the calling agent gets a uniform interface regardless of whether the
|
||||
network call succeeded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_API = "https://api.stocktwits.com/api/2/streams/symbol/{ticker}.json"
|
||||
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
||||
|
||||
|
||||
def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.0) -> str:
|
||||
"""Fetch recent StockTwits messages for ``ticker`` and return them as a
|
||||
formatted plaintext block ready for prompt injection.
|
||||
|
||||
Returns a placeholder string when the endpoint is unreachable, the
|
||||
symbol has no messages, or the response shape is unexpected — the
|
||||
caller never has to special-case None or exceptions.
|
||||
"""
|
||||
url = _API.format(ticker=ticker.upper())
|
||||
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
||||
try:
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read())
|
||||
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
||||
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
|
||||
return f"<stocktwits unavailable: {type(exc).__name__}>"
|
||||
|
||||
messages = data.get("messages", []) if isinstance(data, dict) else []
|
||||
if not messages:
|
||||
return f"<no StockTwits messages found for ${ticker.upper()}>"
|
||||
|
||||
lines = []
|
||||
bullish = bearish = unlabeled = 0
|
||||
for m in messages[:limit]:
|
||||
created = m.get("created_at", "")
|
||||
user = (m.get("user") or {}).get("username", "?")
|
||||
entities = m.get("entities") or {}
|
||||
sentiment_obj = entities.get("sentiment") or {}
|
||||
sentiment = sentiment_obj.get("basic") if isinstance(sentiment_obj, dict) else None
|
||||
body = (m.get("body") or "").replace("\n", " ").strip()
|
||||
if len(body) > 280:
|
||||
body = body[:280] + "…"
|
||||
|
||||
if sentiment == "Bullish":
|
||||
bullish += 1
|
||||
tag = "Bullish"
|
||||
elif sentiment == "Bearish":
|
||||
bearish += 1
|
||||
tag = "Bearish"
|
||||
else:
|
||||
unlabeled += 1
|
||||
tag = "no-label"
|
||||
lines.append(f"[{created} · @{user} · {tag}] {body}")
|
||||
|
||||
total = bullish + bearish + unlabeled
|
||||
bull_pct = round(100 * bullish / total) if total else 0
|
||||
bear_pct = round(100 * bearish / total) if total else 0
|
||||
summary = (
|
||||
f"Bullish: {bullish} ({bull_pct}%) · "
|
||||
f"Bearish: {bearish} ({bear_pct}%) · "
|
||||
f"Unlabeled: {unlabeled} · "
|
||||
f"Total: {total} most-recent messages"
|
||||
)
|
||||
return summary + "\n\n" + "\n".join(lines)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""yfinance-based news data fetching functions."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import yfinance as yf
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .config import get_config
|
||||
from .stockstats_utils import yf_retry
|
||||
|
||||
|
||||
@@ -64,9 +67,10 @@ def get_news_yfinance(
|
||||
Returns:
|
||||
Formatted string containing news articles
|
||||
"""
|
||||
article_limit = get_config()["news_article_limit"]
|
||||
try:
|
||||
stock = yf.Ticker(ticker)
|
||||
news = yf_retry(lambda: stock.get_news(count=20))
|
||||
news = yf_retry(lambda: stock.get_news(count=article_limit))
|
||||
|
||||
if not news:
|
||||
return f"No news found for {ticker}"
|
||||
@@ -106,27 +110,28 @@ def get_news_yfinance(
|
||||
|
||||
def get_global_news_yfinance(
|
||||
curr_date: str,
|
||||
look_back_days: int = 7,
|
||||
limit: int = 10,
|
||||
look_back_days: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global/macro economic news using yfinance Search.
|
||||
|
||||
Args:
|
||||
curr_date: Current date in yyyy-mm-dd format
|
||||
look_back_days: Number of days to look back
|
||||
limit: Maximum number of articles to return
|
||||
look_back_days: Number of days to look back. ``None`` falls back to
|
||||
``global_news_lookback_days`` from the active config.
|
||||
limit: Maximum number of articles to return. ``None`` falls back to
|
||||
``global_news_article_limit`` from the active config.
|
||||
|
||||
Returns:
|
||||
Formatted string containing global news articles
|
||||
"""
|
||||
# Search queries for macro/global news
|
||||
search_queries = [
|
||||
"stock market economy",
|
||||
"Federal Reserve interest rates",
|
||||
"inflation economic outlook",
|
||||
"global markets trading",
|
||||
]
|
||||
config = get_config()
|
||||
if look_back_days is None:
|
||||
look_back_days = config["global_news_lookback_days"]
|
||||
if limit is None:
|
||||
limit = config["global_news_article_limit"]
|
||||
search_queries = config["global_news_queries"]
|
||||
|
||||
all_news = []
|
||||
seen_titles = set()
|
||||
|
||||
@@ -2,7 +2,45 @@ import os
|
||||
|
||||
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
# Single source of truth for env-var → config-key overrides. To expose
|
||||
# a new config key for environment-based override, add a row here — no
|
||||
# entry-point script changes required. Coercion is driven by the type
|
||||
# of the existing default, so users can keep writing plain strings in
|
||||
# their .env file.
|
||||
_ENV_OVERRIDES = {
|
||||
"TRADINGAGENTS_LLM_PROVIDER": "llm_provider",
|
||||
"TRADINGAGENTS_DEEP_THINK_LLM": "deep_think_llm",
|
||||
"TRADINGAGENTS_QUICK_THINK_LLM": "quick_think_llm",
|
||||
"TRADINGAGENTS_LLM_BACKEND_URL": "backend_url",
|
||||
"TRADINGAGENTS_OUTPUT_LANGUAGE": "output_language",
|
||||
"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds",
|
||||
"TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
|
||||
"TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
|
||||
}
|
||||
|
||||
|
||||
def _coerce(value: str, reference):
|
||||
"""Coerce env-var string to the type of the existing default value."""
|
||||
if isinstance(reference, bool):
|
||||
return value.strip().lower() in ("true", "1", "yes", "on")
|
||||
if isinstance(reference, int) and not isinstance(reference, bool):
|
||||
return int(value)
|
||||
if isinstance(reference, float):
|
||||
return float(value)
|
||||
return value
|
||||
|
||||
|
||||
def _apply_env_overrides(config: dict) -> dict:
|
||||
"""Apply TRADINGAGENTS_* env vars to the config dict in-place."""
|
||||
for env_var, key in _ENV_OVERRIDES.items():
|
||||
raw = os.environ.get(env_var)
|
||||
if raw is None or raw == "":
|
||||
continue
|
||||
config[key] = _coerce(raw, config.get(key))
|
||||
return config
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _apply_env_overrides({
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||
@@ -35,6 +73,21 @@ DEFAULT_CONFIG = {
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"max_recur_limit": 100,
|
||||
# News / data fetching parameters
|
||||
# Increase for longer lookback strategies or to broaden macro coverage;
|
||||
# decrease to reduce token usage in agent prompts.
|
||||
"news_article_limit": 20, # max articles per ticker (ticker-news)
|
||||
"global_news_article_limit": 10, # max articles for global/macro news
|
||||
"global_news_lookback_days": 7, # macro news lookback window
|
||||
# Search queries used by get_global_news for macro headlines. Extend or
|
||||
# replace to broaden geographic / sector coverage.
|
||||
"global_news_queries": [
|
||||
"Federal Reserve interest rates inflation",
|
||||
"S&P 500 earnings GDP economic outlook",
|
||||
"geopolitical risk trade war sanctions",
|
||||
"ECB Bank of England BOJ central bank policy",
|
||||
"oil commodities supply chain energy",
|
||||
],
|
||||
# Data vendor configuration
|
||||
# Category-level configuration (default for all tools in category)
|
||||
"data_vendors": {
|
||||
@@ -47,4 +100,4 @@ DEFAULT_CONFIG = {
|
||||
"tool_vendors": {
|
||||
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
@@ -54,7 +54,11 @@ class GraphSetup:
|
||||
tool_nodes["market"] = self.tool_nodes["market"]
|
||||
|
||||
if "social" in selected_analysts:
|
||||
analyst_nodes["social"] = create_social_media_analyst(
|
||||
# "social" selector key preserved for back-compat with existing
|
||||
# user configs; the underlying agent has been renamed to
|
||||
# sentiment_analyst (the old name advertised social-media data
|
||||
# the agent never had access to — see issue #557).
|
||||
analyst_nodes["social"] = create_sentiment_analyst(
|
||||
self.quick_thinking_llm
|
||||
)
|
||||
delete_nodes["social"] = create_msg_delete()
|
||||
|
||||
@@ -116,7 +116,9 @@ class TradingAgentsGraph:
|
||||
self.conditional_logic,
|
||||
)
|
||||
|
||||
self.propagator = Propagator()
|
||||
self.propagator = Propagator(
|
||||
max_recur_limit=self.config.get("max_recur_limit", 100),
|
||||
)
|
||||
self.reflector = Reflector(self.quick_thinking_llm)
|
||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||
|
||||
@@ -322,7 +324,11 @@ class TradingAgentsGraph:
|
||||
else:
|
||||
chunk["messages"][-1].pretty_print()
|
||||
trace.append(chunk)
|
||||
final_state = trace[-1]
|
||||
# Streamed chunks are per-node deltas. Merge them so the returned
|
||||
# state matches what graph.invoke() yields in the non-debug path.
|
||||
final_state = {}
|
||||
for chunk in trace:
|
||||
final_state.update(chunk)
|
||||
else:
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
|
||||
|
||||
44
tradingagents/llm_clients/api_key_env.py
Normal file
44
tradingagents/llm_clients/api_key_env.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Canonical provider -> API-key env-var mapping.
|
||||
|
||||
A single source of truth for which environment variable holds the API
|
||||
key for each supported LLM provider. Used by the CLI's interactive key
|
||||
prompt (cli/utils.ensure_api_key) and by anything else that needs to
|
||||
ask "does this provider require a key, and which env var is it?".
|
||||
|
||||
When adding a new provider, register its env var here so the CLI flow
|
||||
prompts for it automatically instead of failing on first API call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
"xai": "XAI_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
# Dual-region providers each carry their own account; keys are not
|
||||
# interchangeable between the international and China endpoints.
|
||||
"qwen": "DASHSCOPE_API_KEY",
|
||||
"qwen-cn": "DASHSCOPE_CN_API_KEY",
|
||||
"glm": "ZHIPU_API_KEY",
|
||||
"glm-cn": "ZHIPU_CN_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
# Local runtimes do not authenticate.
|
||||
"ollama": None,
|
||||
}
|
||||
|
||||
|
||||
def get_api_key_env(provider: str) -> Optional[str]:
|
||||
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
||||
|
||||
Unknown providers also return None — callers should treat that as
|
||||
"no key check possible" rather than as "no key required".
|
||||
"""
|
||||
return PROVIDER_API_KEY_ENV.get(provider.lower())
|
||||
120
tradingagents/llm_clients/capabilities.py
Normal file
120
tradingagents/llm_clients/capabilities.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Declarative per-model capability table for OpenAI-compatible providers.
|
||||
|
||||
This is the single place that knows which model IDs reject which API
|
||||
parameters or require which structured-output method. The LLM client
|
||||
subclasses consult ``get_capabilities(model_name)`` instead of hardcoding
|
||||
model-name ``if`` ladders, so adding a new model (or a new provider quirk)
|
||||
means editing this table — not the client code.
|
||||
|
||||
Pattern adapted from the per-model ``compat:`` flags DeepSeek themselves
|
||||
publish in their integration guides (e.g. the Oh My Pi config schema
|
||||
documents ``supportsToolChoice``, ``requiresReasoningContentForToolCalls``
|
||||
as declarative per-model fields).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
StructuredMethod = Literal[
|
||||
"function_calling", # uses tools; respects supports_tool_choice
|
||||
"json_mode", # uses response_format={"type":"json_object"}
|
||||
"json_schema", # uses response_format={"type":"json_schema",...}
|
||||
"none", # no structured output available; caller falls back to free-text
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelCapabilities:
|
||||
"""What an OpenAI-compatible model accepts at the API level."""
|
||||
|
||||
supports_tool_choice: bool
|
||||
supports_json_mode: bool
|
||||
supports_json_schema: bool
|
||||
preferred_structured_method: StructuredMethod
|
||||
# DeepSeek thinking-mode models 400 if reasoning_content from prior
|
||||
# assistant turns is not echoed back on the next request.
|
||||
requires_reasoning_content_roundtrip: bool = False
|
||||
|
||||
|
||||
# DeepSeek's thinking models accept the ``tools`` array but reject the
|
||||
# ``tool_choice`` parameter (official Oh My Pi integration guide and the
|
||||
# 400 response in issue #678). Their official tool-calling examples
|
||||
# (api-docs.deepseek.com/guides/tool_calls) pass ``tools=[...]`` without
|
||||
# ``tool_choice`` — we mirror that pattern by setting supports_tool_choice
|
||||
# to False and letting the client suppress the kwarg.
|
||||
_DEEPSEEK_THINKING = ModelCapabilities(
|
||||
supports_tool_choice=False,
|
||||
supports_json_mode=True,
|
||||
supports_json_schema=False,
|
||||
preferred_structured_method="function_calling",
|
||||
requires_reasoning_content_roundtrip=True,
|
||||
)
|
||||
|
||||
_DEEPSEEK_CHAT = ModelCapabilities(
|
||||
supports_tool_choice=True,
|
||||
supports_json_mode=True,
|
||||
supports_json_schema=False,
|
||||
preferred_structured_method="function_calling",
|
||||
)
|
||||
|
||||
# MiniMax M2.x reasoning models accept the tools array, but their
|
||||
# tool_choice parameter is restricted to the enum {"none", "auto"}
|
||||
# (platform.minimax.io/docs/api-reference/text-post). Langchain's
|
||||
# function_calling path sends tool_choice as a function-spec dict, which
|
||||
# MiniMax 400s — same shape as the DeepSeek bug. supports_tool_choice=False
|
||||
# makes the dispatch in NormalizedChatOpenAI suppress the kwarg; the schema
|
||||
# still ships as a tool. json_mode response_format is only for
|
||||
# MiniMax-Text-01, not M2.x.
|
||||
_MINIMAX_THINKING = ModelCapabilities(
|
||||
supports_tool_choice=False,
|
||||
supports_json_mode=False,
|
||||
supports_json_schema=False,
|
||||
preferred_structured_method="function_calling",
|
||||
)
|
||||
|
||||
_DEFAULT = ModelCapabilities(
|
||||
supports_tool_choice=True,
|
||||
supports_json_mode=True,
|
||||
supports_json_schema=True,
|
||||
preferred_structured_method="function_calling",
|
||||
)
|
||||
|
||||
|
||||
# Exact-ID matches take precedence over pattern matches.
|
||||
_BY_ID: dict[str, ModelCapabilities] = {
|
||||
"deepseek-chat": _DEEPSEEK_CHAT,
|
||||
"deepseek-reasoner": _DEEPSEEK_THINKING,
|
||||
"deepseek-v4-flash": _DEEPSEEK_THINKING,
|
||||
"deepseek-v4-pro": _DEEPSEEK_THINKING,
|
||||
# MiniMax — full official model lineup per
|
||||
# platform.minimax.io/docs/api-reference/text-openai-api
|
||||
"MiniMax-M2.7": _MINIMAX_THINKING,
|
||||
"MiniMax-M2.7-highspeed": _MINIMAX_THINKING,
|
||||
"MiniMax-M2.5": _MINIMAX_THINKING,
|
||||
"MiniMax-M2.5-highspeed": _MINIMAX_THINKING,
|
||||
"MiniMax-M2.1": _MINIMAX_THINKING,
|
||||
"MiniMax-M2.1-highspeed": _MINIMAX_THINKING,
|
||||
"MiniMax-M2": _MINIMAX_THINKING,
|
||||
}
|
||||
|
||||
# Forward-compat patterns. New ``deepseek-v5-*`` / ``deepseek-reasoner-*``
|
||||
# or ``MiniMax-M3*`` variants inherit the thinking-mode quirks automatically.
|
||||
_BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [
|
||||
(re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING),
|
||||
(re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING),
|
||||
(re.compile(r"^MiniMax-M\d"), _MINIMAX_THINKING),
|
||||
]
|
||||
|
||||
|
||||
def get_capabilities(model_name: str) -> ModelCapabilities:
|
||||
"""Resolve capabilities by exact ID, then pattern, then default."""
|
||||
if model_name in _BY_ID:
|
||||
return _BY_ID[model_name]
|
||||
for pattern, caps in _BY_PATTERN:
|
||||
if pattern.match(model_name):
|
||||
return caps
|
||||
return _DEFAULT
|
||||
@@ -4,7 +4,11 @@ from .base_client import BaseLLMClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
|
||||
"openai", "xai", "deepseek",
|
||||
"qwen", "qwen-cn",
|
||||
"glm", "glm-cn",
|
||||
"minimax", "minimax-cn",
|
||||
"ollama", "openrouter",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,59 +8,124 @@ ModelOption = Tuple[str, str]
|
||||
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
||||
|
||||
|
||||
# Shared model list for GLM via Z.AI (international) and BigModel (China).
|
||||
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
|
||||
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
|
||||
_GLM_MODELS: Dict[str, List[ModelOption]] = {
|
||||
"quick": [
|
||||
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
|
||||
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||
("GLM-4.5-Air - Lightweight, cost-efficient", "glm-4.5-air"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-5.1 - Latest flagship, 204K ctx", "glm-5.1"),
|
||||
("GLM-5 - Flagship, 204K ctx", "glm-5"),
|
||||
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints.
|
||||
# Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized).
|
||||
#
|
||||
# Only versioned IDs are exposed in the dropdown. The version-less aliases
|
||||
# (qwen-plus, qwen-flash) are documented by Alibaba as auto-upgrading
|
||||
# pointers ("backbone, latest, and snapshot ... have been upgraded to the
|
||||
# Qwen3 series"), which means their behavior shifts when Alibaba rotates
|
||||
# the backing model. Users who want a specific generation pick it
|
||||
# explicitly; users who really want auto-latest can enter the alias via
|
||||
# "Custom model ID".
|
||||
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
||||
"quick": [
|
||||
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
|
||||
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("Qwen 3.6 Plus - Flagship vision-language, agentic coding SOTA", "qwen3.6-plus"),
|
||||
("Qwen 3.5 Plus - Previous-gen flagship", "qwen3.5-plus"),
|
||||
("Qwen 3 Max - Specialized for agent programming + tool use", "qwen3-max"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Shared model list for MiniMax's global and CN endpoints (same IDs).
|
||||
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
|
||||
# All M2.x models share a 204,800-token context window.
|
||||
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
|
||||
"quick": [
|
||||
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
|
||||
("MiniMax-M2.1-highspeed - M2.1 highspeed, 204K ctx", "MiniMax-M2.1-highspeed"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("MiniMax-M2.7 - Flagship, SOTA on coding/agent benchmarks, 204K ctx", "MiniMax-M2.7"),
|
||||
("MiniMax-M2.7-highspeed - Same quality as M2.7, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||
("MiniMax-M2.5 - Previous-gen flagship, 204K ctx", "MiniMax-M2.5"),
|
||||
("MiniMax-M2.1 - Earlier M2 line, 204K ctx", "MiniMax-M2.1"),
|
||||
("MiniMax-M2 - Base M2, 204K ctx", "MiniMax-M2"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
MODEL_OPTIONS: ProviderModeOptions = {
|
||||
"openai": {
|
||||
"quick": [
|
||||
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
|
||||
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
|
||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||
],
|
||||
"deep": [
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
|
||||
("GPT-5.4 - Previous-gen frontier, 1M context, cost-effective", "gpt-5.4"),
|
||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
|
||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||
("GPT-5.5 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.5-pro"),
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"quick": [
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
("Claude Haiku 4.5 - Fastest with near-frontier intelligence", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - High-performance for agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"deep": [
|
||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.7 - Latest frontier, long-running agents and coding", "claude-opus-4-7"),
|
||||
("Claude Opus 4.6 - Frontier intelligence, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"quick": [
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient (GA)", "gemini-3.1-flash-lite"),
|
||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||
],
|
||||
"deep": [
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows (preview)", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
],
|
||||
},
|
||||
"xai": {
|
||||
"quick": [
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4.20 (Non-Reasoning) - Latest, speed-optimized", "grok-4.20-non-reasoning"),
|
||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
],
|
||||
"deep": [
|
||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4.20 (Reasoning) - Latest frontier reasoning model", "grok-4.20-reasoning"),
|
||||
("Grok 4 - Flagship (dated build)", "grok-4-0709"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4.20 - Auto-select reasoning behavior", "grok-4.20"),
|
||||
],
|
||||
},
|
||||
"deepseek": {
|
||||
@@ -76,31 +141,18 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"qwen": {
|
||||
"quick": [
|
||||
("Qwen 3.5 Flash", "qwen3.5-flash"),
|
||||
("Qwen Plus", "qwen-plus"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("Qwen 3.6 Plus", "qwen3.6-plus"),
|
||||
("Qwen 3.5 Plus", "qwen3.5-plus"),
|
||||
("Qwen 3 Max", "qwen3-max"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"glm": {
|
||||
"quick": [
|
||||
("GLM-4.7", "glm-4.7"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-5.1", "glm-5.1"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
# Qwen: same model IDs across global (dashscope-intl) and China
|
||||
# (dashscope) endpoints, so the two provider keys share one model list.
|
||||
"qwen": _QWEN_MODELS,
|
||||
"qwen-cn": _QWEN_MODELS,
|
||||
# GLM: Z.AI (international) and BigModel (China) host the same model
|
||||
# IDs; the two provider keys share one model list.
|
||||
"glm": _GLM_MODELS,
|
||||
"glm-cn": _GLM_MODELS,
|
||||
# MiniMax: same model IDs across global (.io) and China (.com) regions,
|
||||
# so the two provider keys share one model list.
|
||||
"minimax": _MINIMAX_MODELS,
|
||||
"minimax-cn": _MINIMAX_MODELS,
|
||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||
"ollama": {
|
||||
"quick": [
|
||||
|
||||
@@ -4,31 +4,47 @@ from typing import Any, Optional
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .api_key_env import get_api_key_env
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .capabilities import get_capabilities
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
class NormalizedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with normalized content output.
|
||||
"""ChatOpenAI with normalized content output and capability-aware binding.
|
||||
|
||||
The Responses API returns content as a list of typed blocks
|
||||
(reasoning, text, etc.). ``invoke`` normalizes to string for
|
||||
consistent downstream handling. ``with_structured_output`` defaults
|
||||
to function-calling so the Responses-API parse path is avoided
|
||||
(langchain-openai's parse path emits noisy
|
||||
PydanticSerializationUnexpectedValue warnings per call without
|
||||
affecting correctness).
|
||||
consistent downstream handling.
|
||||
|
||||
Provider-specific quirks (e.g. DeepSeek's thinking mode) live in
|
||||
purpose-built subclasses below so this base class stays small.
|
||||
``with_structured_output`` consults the per-model capability table
|
||||
(``capabilities.get_capabilities``) to pick the method and to decide
|
||||
whether ``tool_choice`` may be sent. Models that reject ``tool_choice``
|
||||
(e.g. DeepSeek V4 and reasoner — per their official tool-calling
|
||||
guide) still bind the schema as a tool, but no ``tool_choice``
|
||||
parameter is sent.
|
||||
|
||||
Provider-specific quirks beyond structured-output (e.g. DeepSeek's
|
||||
reasoning_content roundtrip) live in subclasses so this base class
|
||||
stays small.
|
||||
"""
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
def with_structured_output(self, schema, *, method=None, **kwargs):
|
||||
if method is None:
|
||||
method = "function_calling"
|
||||
caps = get_capabilities(self.model_name)
|
||||
if caps.preferred_structured_method == "none":
|
||||
raise NotImplementedError(
|
||||
f"{self.model_name} has no structured-output method available; "
|
||||
f"agent factories will fall back to free-text generation."
|
||||
)
|
||||
method = method or caps.preferred_structured_method
|
||||
# When the model rejects tool_choice, suppress langchain's hardcoded
|
||||
# value. The schema is still bound as a tool — exactly what
|
||||
# DeepSeek's official tool-calling examples do.
|
||||
if method == "function_calling" and not caps.supports_tool_choice:
|
||||
kwargs.setdefault("tool_choice", None)
|
||||
return super().with_structured_output(schema, method=method, **kwargs)
|
||||
|
||||
|
||||
@@ -52,18 +68,16 @@ def _input_to_messages(input_: Any) -> list:
|
||||
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
|
||||
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
|
||||
|
||||
Two quirks that don't apply to other OpenAI-compatible providers:
|
||||
Thinking-mode round-trip is the only DeepSeek-specific behavior that
|
||||
stays here. When DeepSeek's thinking models return a response with
|
||||
``reasoning_content``, that field must be echoed back as part of the
|
||||
assistant message on the next turn or the API fails with HTTP 400.
|
||||
``_create_chat_result`` captures it on receive and
|
||||
``_get_request_payload`` re-attaches it on send.
|
||||
|
||||
1. **Thinking-mode round-trip.** When DeepSeek's thinking models return
|
||||
a response with ``reasoning_content``, that field must be echoed
|
||||
back as part of the assistant message on the next turn or the API
|
||||
fails with HTTP 400. ``_create_chat_result`` captures the field on
|
||||
receive and ``_get_request_payload`` re-attaches it on send.
|
||||
|
||||
2. **deepseek-reasoner has no tool_choice.** Structured output via
|
||||
function-calling is unavailable, so we raise NotImplementedError
|
||||
and let the agent factories fall back to free-text generation
|
||||
(see ``tradingagents/agents/utils/structured.py``).
|
||||
Tool-choice handling for V4 and reasoner — those models reject the
|
||||
``tool_choice`` parameter — is handled by the capability dispatch in
|
||||
``NormalizedChatOpenAI.with_structured_output``, not here.
|
||||
"""
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
@@ -94,14 +108,27 @@ class DeepSeekChatOpenAI(NormalizedChatOpenAI):
|
||||
generation.message.additional_kwargs["reasoning_content"] = reasoning
|
||||
return chat_result
|
||||
|
||||
def with_structured_output(self, schema, *, method=None, **kwargs):
|
||||
if self.model_name == "deepseek-reasoner":
|
||||
raise NotImplementedError(
|
||||
"deepseek-reasoner does not support tool_choice; structured "
|
||||
"output is unavailable. Agent factories fall back to "
|
||||
"free-text generation automatically."
|
||||
)
|
||||
return super().with_structured_output(schema, method=method, **kwargs)
|
||||
|
||||
class MinimaxChatOpenAI(NormalizedChatOpenAI):
|
||||
"""MiniMax-specific overrides on top of the OpenAI-compatible client.
|
||||
|
||||
M2.x reasoning models embed ``<think>...</think>`` blocks directly in
|
||||
``message.content`` by default, which would pollute saved reports.
|
||||
Per platform.minimax.io/docs/api-reference/text-openai-api, setting
|
||||
``reasoning_split=True`` in the request body redirects the thinking
|
||||
block into ``reasoning_details`` so ``content`` stays clean.
|
||||
|
||||
Tool-choice handling for M2.x — those models accept only the string
|
||||
enum ``{"none", "auto"}`` and reject langchain's function-spec dict —
|
||||
is handled by the capability dispatch in
|
||||
``NormalizedChatOpenAI.with_structured_output``, not here.
|
||||
"""
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
payload.setdefault("reasoning_split", True)
|
||||
return payload
|
||||
|
||||
|
||||
# Kwargs forwarded from user config to ChatOpenAI
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
@@ -109,14 +136,22 @@ _PASSTHROUGH_KWARGS = (
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
# Provider base URLs and API key env vars
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
|
||||
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
|
||||
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
|
||||
# (one canonical mapping consulted by both this client and the CLI's
|
||||
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
|
||||
# separate endpoints because international and China accounts cannot share
|
||||
# credentials (#758).
|
||||
_PROVIDER_BASE_URL = {
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"deepseek": "https://api.deepseek.com",
|
||||
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"glm": "https://api.z.ai/api/paas/v4/",
|
||||
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"minimax": "https://api.minimax.io/v1",
|
||||
"minimax-cn": "https://api.minimaxi.com/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"ollama": "http://localhost:11434/v1",
|
||||
}
|
||||
|
||||
|
||||
@@ -147,13 +182,19 @@ class OpenAIClient(BaseLLMClient):
|
||||
# Provider-specific base URL and auth. An explicit base_url on the
|
||||
# client (e.g. a corporate proxy) takes precedence over the
|
||||
# provider default so users can route through their own gateway.
|
||||
if self.provider in _PROVIDER_CONFIG:
|
||||
default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = self.base_url or default_base
|
||||
if self.provider in _PROVIDER_BASE_URL:
|
||||
llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider]
|
||||
api_key_env = get_api_key_env(self.provider)
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
llm_kwargs["api_key"] = api_key
|
||||
else:
|
||||
raise ValueError(
|
||||
f"API key for provider '{self.provider}' is not set. "
|
||||
f"Please set the {api_key_env} environment variable "
|
||||
f"(e.g. add {api_key_env}=your_key to your .env file)."
|
||||
)
|
||||
else:
|
||||
llm_kwargs["api_key"] = "ollama"
|
||||
elif self.base_url:
|
||||
@@ -169,9 +210,14 @@ class OpenAIClient(BaseLLMClient):
|
||||
if self.provider == "openai":
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
|
||||
# DeepSeek's thinking-mode quirks live in their own subclass so the
|
||||
# base NormalizedChatOpenAI stays free of provider-specific branches.
|
||||
chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI
|
||||
# Provider-specific quirks live in their own subclasses so the
|
||||
# base NormalizedChatOpenAI stays free of provider branches.
|
||||
if self.provider == "deepseek":
|
||||
chat_cls = DeepSeekChatOpenAI
|
||||
elif self.provider in ("minimax", "minimax-cn"):
|
||||
chat_cls = MinimaxChatOpenAI
|
||||
else:
|
||||
chat_cls = NormalizedChatOpenAI
|
||||
return chat_cls(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user