28 Commits

Author SHA1 Message Date
Yijia-Xiao
a5cb7cbd61 chore: release v0.2.5 — sentiment analyst, env-var config, more providers
Headline themes in v0.2.5:

- Sentiment Analyst grounded in real data. Renamed from social_media_analyst
  and redesigned to pre-fetch Yahoo News, StockTwits, and Reddit before the
  LLM is invoked, ending the prior fabrication behavior.
- MiniMax provider with full M2.x catalog and dual-region split. Qwen and
  GLM also split into international + China regions with separate API keys
  and a clean secondary region prompt in the CLI.
- TRADINGAGENTS_* env-var overlay for DEFAULT_CONFIG with type-aware
  coercion; .env loading centralized so every entry point sees the user's
  keys. Interactive API-key detection prompts and persists missing keys
  to .env on the fly.
- OLLAMA_BASE_URL end-to-end for remote ollama-serve, plus a Custom model
  ID option in the Ollama dropdown.
- Configurable news-fetch parameters and configurable alpha benchmark for
  non-US tickers (.NS / .T / .HK / .L / .TO / .AX / .BO ship with sensible
  regional defaults).
- Multi-language output now propagates to every user-facing agent
  (researchers, risk debators, research manager, trader) instead of only
  the analysts and portfolio manager.
- Model catalog refresh across all providers (GPT-5.5 frontier, Claude
  Opus 4.7, Gemini 3.1 Flash-Lite GA, Grok 4.20, Qwen 3.6 line).
- Capability-dispatch table drives provider-specific structured-output
  quirks (DeepSeek V4/reasoner and MiniMax M2.x tool_choice rejection,
  MiniMax reasoning_split) so the general client stays clean.
- Fixes: ticker path-traversal validation (security), dotenv loading via
  console script, reports save bug, exchange-suffix truncation in the
  ticker prompt, Docker permission errors, deepcopy config isolation,
  max_recur_limit plumbing, clearer missing-API-key error.

See CHANGELOG.md for the full per-item list with issue/PR references.
2026-05-11 09:27:36 +00:00
Yijia-Xiao
78d063dc5c feat(reflection): configurable alpha benchmark for non-US tickers
SPY was hardcoded as the alpha benchmark in both the return-fetch
path and the reflection label, which produced meaningless alpha for
.NS / .T / .HK / .L / .TO / .AX / .BO listings — FX drift between a
local-currency stock and a USD index dominates the spread.

DEFAULT_CONFIG now exposes benchmark_ticker (explicit override) and
benchmark_map (suffix → regional index, with SPY as the empty-suffix
default). TRADINGAGENTS_BENCHMARK_TICKER joins the env-overlay table.
Trading graph resolves the benchmark once per ticker and threads it
through to both _fetch_returns and reflect_on_final_decision, so the
alpha label reads "Alpha vs ^N225" for Tokyo listings, "Alpha vs ^HSI"
for Hong Kong, etc., instead of the misleading "Alpha vs SPY".
2026-05-11 09:14:28 +00:00
Yijia-Xiao
819e813a14 docs(readme): Ollama line covers endpoint, pull, custom model
The Required APIs section now mentions the default endpoint,
OLLAMA_BASE_URL for remote ollama-serve, ollama pull, and the
Custom model ID dropdown option, replacing the previous one-liner
that left those details implicit.
2026-05-11 09:07:38 +00:00
Yijia-Xiao
800862405d feat(ollama): allow Custom model ID in the CLI dropdown
Users with other models pulled via `ollama pull` (beyond the three
suggested defaults) can now select "Custom model ID" and type any
model name. Matches the same pattern used for DeepSeek, GLM, Qwen,
and MiniMax — the existing _prompt_custom_model_id flow handles the
"custom" value generically, so this is a one-row catalog addition
plus regression coverage.
2026-05-11 09:03:06 +00:00
Yijia-Xiao
f10daa2824 feat(ollama): OLLAMA_BASE_URL end-to-end with endpoint confirmation
OLLAMA_BASE_URL now flows through both the CLI dropdown and the
programmatic client (call-time evaluation so tests behave). After
provider selection, the CLI prints the resolved endpoint and marks
when it came from the env var, plus a soft warning when the URL is
missing a scheme or non-default port. Drops the stale "(local)"
suffix from Ollama model labels since the endpoint is now dynamic.
2026-05-11 08:46:21 +00:00
Yijia-Xiao
879e2bb5da refactor: align display label and docs with sentiment_analyst rename
The agent ingests news, StockTwits, and Reddit, but CLI labels, the
README description, and the legacy shim docstring still framed it as
social-media-only. Updates all user-visible surfaces so the name and
the implementation match.
2026-05-11 06:25:22 +00:00
Yijia-Xiao
9f7abfcbd5 feat(cli): detect missing provider API keys and persist to .env
Adds a canonical PROVIDER_API_KEY_ENV mapping (14 providers including
the three dual-region pairs) and an ensure_api_key() helper. When the
selected provider's key is absent from the environment, the CLI prompts
via questionary.password, writes the value to .env via python-dotenv's
set_key (preserves existing lines), and exports it into os.environ so
the run continues without restart. Wired into cli/main.py right after
the region prompts so qwen-cn, glm-cn, and minimax-cn each check their
own region-specific key. openai_client refactored to consult the same
mapping, eliminating its private duplicate of provider→env-var data.
2026-05-11 06:12:34 +00:00
Yijia-Xiao
d13e9b7946 feat(config): TRADINGAGENTS_* env-var overlay for DEFAULT_CONFIG
Adds a single _ENV_OVERRIDES table in default_config.py with type-aware
coercion (str/int/bool), so users can switch llm_provider, deep/quick
models, backend URL, output language, debate rounds, and the checkpoint
flag purely via .env. Centralizes load_dotenv in the package __init__
so the overlay applies for every entry point (CLI, main.py, programmatic).
Drops the hardcoded model assignments and duplicate dotenv loads in
main.py and cli/main.py. Verified live with OpenAI and Gemini.

#602
2026-05-11 06:12:31 +00:00
Yijia-Xiao
6b384f74f9 feat(i18n): localize researchers, risk debators, research mgr, trader
output_language config now propagates to every user-facing agent.
Previously only the four analysts and portfolio manager respected
the setting, producing partial-localization reports with English
debate text interleaved with non-English analyst sections. Verified
live: 7 agents produce Chinese output when config is set to Chinese.

#575
2026-05-11 05:41:42 +00:00
Yijia-Xiao
384fe1a3d2 feat(news): configurable fetch params via DEFAULT_CONFIG
Per-ticker article limit, global article limit, global lookback
window, and macro query list are now read from get_config()
instead of being hardcoded. Tool wrapper get_global_news passes
None defaults so config overrides flow through the LLM-tool path
too. Macro query defaults broadened from 4 US-centric strings to
5 covering Fed, S&P 500, geopolitics, ECB/BOJ/BOE, commodities.

#606 #558 #562
2026-05-11 05:30:52 +00:00
Yijia-Xiao
0fcf13624e feat(agents): rename to sentiment_analyst; integrate StockTwits + Reddit
Pre-fetches news + StockTwits + Reddit via no-auth public endpoints
and injects structured data blocks into the prompt with professional
analysis instructions. Replaces the prompt-vs-tool mismatch that
caused fabricated social-platform content. Backward-compat alias +
"social" CLI key preserved.

#557 #607
2026-05-11 05:20:07 +00:00
Yijia-Xiao
d0dd0420ad feat(llm): GLM dual-region split + catalog refresh
Zhipu serves GLM under two brands with separate accounts (Z.AI
international vs BigModel China); the CLI URL pointed at one while
the openai_client default pointed at the other. Split into glm +
glm-cn with secondary region prompt (same UX as Qwen + MiniMax).
Catalog adds glm-5-turbo and glm-4.5-air per docs.z.ai.
2026-05-11 04:19:50 +00:00
Yijia-Xiao
faaeebac70 feat(cli): collapse regional duplicates; refresh Qwen catalog
Qwen and MiniMax each had two main-dropdown entries (intl + CN);
consolidate to one entry per provider and prompt for region as a
secondary step. Internal provider keys (qwen-cn, minimax-cn) and
endpoint routing unchanged. Add qwen3.6-flash to the Qwen catalog
and drop the version-less aliases (qwen-flash, qwen-plus) that
auto-shift their backing model per Alibaba's docs.

#758
2026-05-11 04:16:11 +00:00
Yijia-Xiao
0011b5ebf5 feat(llm): align xAI catalog with docs — adopt grok-4.20 frontier
xAI's official docs lead with grok-4.20-reasoning and
grok-4.20-non-reasoning across all SDK examples. Replace the prior
grok-4-1-fast-* entries (hyphens where docs use dots, no literal
code example) with the verified grok-4.20 family. Keep grok-4-0709
and grok-4-fast variants that are still referenced.
2026-05-11 03:45:43 +00:00
Yijia-Xiao
4f057e290c feat(llm): swap Gemini 3.1 Flash-Lite to GA stable
gemini-3.1-flash-lite is now GA per ai.google.dev. Use the stable
version (fewer rate limits, stronger compat guarantees) instead of
the -preview suffix. Labels mark preview vs GA explicitly.
2026-05-11 03:32:00 +00:00
Yijia-Xiao
9e00c8117f feat(llm): bump Anthropic catalog to Claude Opus 4.7 frontier
Opus 4.7 is the current frontier per platform.claude.com (frontier
category, listed first). Demote Opus 4.6 to second deep-tier slot.
Polish quick-tier labels to match official wording; effort docstring
includes 4.7.
2026-05-11 02:56:59 +00:00
Yijia-Xiao
78fe77f4e6 feat(llm): bump OpenAI catalog to GPT-5.5 frontier
GPT-5.5 (Apr 2026, 1M ctx, $5/$30 per 1M) replaces GPT-5.4 as the
catalog flagship. GPT-5.5 Pro replaces 5.4 Pro in the most-capable
slot. GPT-5.4 demotes to previous-gen cost-effective option.
2026-05-11 02:49:57 +00:00
Yijia-Xiao
e1316686f8 fix(llm): MiniMax integration polish vs official docs
M2.x tool_choice is enum-only (none/auto), so route through the
no-tool_choice dispatch. MinimaxChatOpenAI injects reasoning_split
so <think> blocks stay out of content. Catalog rounded out to the
full official M2.x lineup plus forward-compat regex.
2026-05-11 02:40:33 +00:00
Yijia-Xiao
9482cae188 fix: bundle config/recursion/missing-key fixes
- dataflows/config: deepcopy + one-level dict merge so a partial
  set_config doesn't clobber sibling defaults
- graph: thread max_recur_limit from config to Propagator
- openai_client: name the missing env var in the API-key error

#788 #764 #680
2026-05-11 02:30:24 +00:00
Yijia-Xiao
19d22b54a9 feat(llm): add MiniMax as a built-in provider
Two regional endpoints (global api.minimax.io, China api.minimaxi.com)
with separate API keys. Models M2.7 / M2.5 plus -highspeed variants,
204K context. Follows the existing provider-preset pattern.

#789 #609 #577 #546 #395 #378
2026-05-11 02:03:27 +00:00
Yijia-Xiao
704b7627f2 fix(docker): pre-create .tradingagents dir with appuser ownership
useradd --create-home creates /home/appuser but not the
.tradingagents subdir, so cache writes fail with PermissionError
when docker-compose mounts a named volume there (the volume
inherits image-dir ownership on first init).

#627 #672 #771 #690 #714 #723 #780 #633 #773 #631
2026-05-11 01:34:45 +00:00
Yijia-Xiao
22bb91bd83 fix(llm): structured output for DeepSeek V4 and reasoner
DeepSeek V4 and reasoner reject tool_choice but accept tools.
Route via a per-model capability table that suppresses tool_choice
for thinking-mode models.

#678 #689
2026-05-11 01:12:28 +00:00
Yijia-Xiao
afdc6d4ec1 chore: suppress upstream langgraph allowed_objects deprecation noise
langgraph-checkpoint 4.0.3 calls Reviver() at module load without
allowed_objects, printing a pending-deprecation warning at every
CLI start. The upstream patch is merged
(langchain-ai/langgraph#7743) but not released; no app-side seam
fixes it. Install a surgical filter in package init (message regex
+ PendingDeprecationWarning category). Remove when we bump past
langgraph-checkpoint 4.0.3.
2026-05-10 19:39:57 +00:00
Yijia-Xiao
e2c850eb17 fix(cli): preserve exchange suffixes in ticker prompt
The typer.prompt-based input could lose .SH/.SZ/.SS/.HK suffixes on
some shells, so exchange-qualified tickers like 000404.SH arrived
truncated to 000404 and failed downstream lookups. Switch to
questionary.text which reads the raw line; keep SPY-on-empty
behavior and validate the allowed character set (alnum, ._-^) up
to 32 chars.

#770
2026-05-10 19:29:41 +00:00
Yijia-Xiao
c405867bde fix: merge streamed chunks into final_state so reports save correctly
graph.stream() yields per-node deltas, not the full state. Taking
trace[-1] only captured the last node's contribution, so reports
saved to disk were missing every section except the final decision.
Merge all chunks in both the CLI path and trading_graph._run_graph's
debug branch.

#719 #736
2026-05-10 19:20:23 +00:00
Yijia-Xiao
db7e0a67e2 fix(cli): load .env from user's CWD when run as console script
load_dotenv() with no arguments walks up from site-packages instead
of the user's CWD, so the installed tradingagents console script
silently misses the project's .env. Pass find_dotenv(usecwd=True)
so the search starts from CWD; same treatment for .env.enterprise.

#726 #755 #612 #747 #743 #753 #729 #728 #751
2026-05-10 09:49:07 +00:00
Yijia-Xiao
7e9e7b83c7 feat: DeepSeek V4 thinking-mode round-trip via DeepSeekChatOpenAI subclass
Resolves #599: thinking-mode models require reasoning_content to be
echoed back across turns; multi-turn agent runs failed with HTTP 400.

The fix isolates DeepSeek's quirks (reasoning_content round-trip and
the deepseek-reasoner no-tool_choice limitation) into a subclass so
the general OpenAI-compatible client stays untouched. Adds DeepSeek
V4 Pro/Flash to the catalog. 9 new tests; rationale documented in
the class docstrings.

Design adapted from #600; #611 closed in favour of this approach.
2026-05-01 19:23:23 +00:00
Yijia-Xiao
2c97bad45c fix(security): validate ticker before using as path component (#618)
The ticker symbol reaches three filesystem-path construction sites
(load_ohlcv cache filename, checkpointer DB path, _log_state results
directory) without validation. A value containing path separators or
"../" escapes the configured cache / checkpoints / results directory.

Two attack vectors:
- Programmatic callers passing arbitrary ticker to propagate()
- Prompt injection via fetched news content steering the LLM into
  tool calls with attacker-chosen ticker

Fix: new safe_ticker_component() validator in tradingagents/dataflows/
utils.py applied at all three sites. Allows the standard ticker
character set ([A-Za-z0-9._\-\^], up to 32 chars) and explicitly
rejects dot-only values like "." and ".." which would otherwise pass
the regex but traverse parent directories. Seven test cases cover
the accepted formats (BRK-B, 7203.T, ^GSPC, etc.) and the rejected
inputs (path separators, null bytes, whitespace, empty values,
overlong strings, dot-only values).

Closes #618.
2026-05-01 18:56:36 +00:00
49 changed files with 2512 additions and 238 deletions

View File

@@ -5,5 +5,28 @@ ANTHROPIC_API_KEY=
XAI_API_KEY=
DEEPSEEK_API_KEY=
DASHSCOPE_API_KEY=
DASHSCOPE_CN_API_KEY=
ZHIPU_API_KEY=
ZHIPU_CN_API_KEY=
MINIMAX_API_KEY=
MINIMAX_CN_API_KEY=
OPENROUTER_API_KEY=
# Optional: point at a remote Ollama server. When unset, defaults to
# the local instance at http://localhost:11434/v1. Convention follows
# the broader Ollama ecosystem; both the CLI dropdown and programmatic
# client pick this up.
#OLLAMA_BASE_URL=http://your-ollama-host:11434/v1
# Optional: override DEFAULT_CONFIG without editing code.
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
# in tradingagents/default_config.py. Values are coerced to the type of
# the existing default (bool / int / str), so "true"/"3" work as expected.
#TRADINGAGENTS_LLM_PROVIDER=openai
#TRADINGAGENTS_DEEP_THINK_LLM=gpt-5.4
#TRADINGAGENTS_QUICK_THINK_LLM=gpt-5.4-mini
#TRADINGAGENTS_LLM_BACKEND_URL=
#TRADINGAGENTS_OUTPUT_LANGUAGE=English
#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
#TRADINGAGENTS_MAX_RISK_ROUNDS=1
#TRADINGAGENTS_CHECKPOINT_ENABLED=false

View File

@@ -6,6 +6,81 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
Breaking changes within the 0.x line are called out explicitly.
## [0.2.5] — 2026-05-11
### Added
- **Grounded Sentiment Analyst.** The renamed `sentiment_analyst` now reads
real Yahoo News, StockTwits, and Reddit data before generating its report,
replacing the prior flow that could fabricate social posts under prompt
pressure. (#557, #607)
- **MiniMax provider** with the full M2.x catalog (M2.7 / M2.5 / M2.1 / M2
plus highspeed variants, 204K context). Dual-region: Global
(`MINIMAX_API_KEY`) and China (`MINIMAX_CN_API_KEY`).
- **Dual-region Qwen and GLM** with separate keys per region — international
(`DASHSCOPE_API_KEY`, `ZHIPU_API_KEY`) and China (`DASHSCOPE_CN_API_KEY`,
`ZHIPU_CN_API_KEY`), selectable via a secondary region prompt. (#758)
- **`TRADINGAGENTS_*` env-var configurability for `DEFAULT_CONFIG`.** Override
`llm_provider`, deep/quick model IDs, `backend_url`, `output_language`,
debate-round counts, checkpoint flag, and benchmark ticker via `.env` with
type-aware coercion (string / int / bool). (#602)
- **Interactive API-key detection in the CLI.** When the selected provider's
key is missing, the CLI prompts for it and persists the value to `.env`
so the analysis run continues without restart.
- **Remote Ollama support.** `OLLAMA_BASE_URL` points the CLI and the
programmatic client at a remote `ollama-serve`. The CLI surfaces the
resolved endpoint and warns on common malformed inputs. Adds a
`"Custom model ID"` option for models pulled via `ollama pull`. (#648, #768)
- **Configurable news-fetch parameters** in `DEFAULT_CONFIG` — per-ticker
article limit, macro headline limit, lookback window, and macro search
queries. (#606, #683)
- **Configurable alpha benchmark** for non-US tickers. Replaces hardcoded
SPY with regional indices for `.NS` (^NSEI), `.T` (^N225), `.HK` (^HSI),
`.L` (^FTSE), `.TO` (^GSPTSE), `.AX` (^AXJO), `.BO` (^BSESN); explicit
`benchmark_ticker` override available. Eliminates FX drift dominating
alpha for non-USD listings. (#628, #684)
- **Multi-language output covers every user-facing agent** — researchers,
risk debators, research manager, and trader, ending the previous
partial-localization reports. (#575)
- **Model catalog refresh.** OpenAI GPT-5.5 frontier, Anthropic Claude Opus
4.7, Gemini 3.1 Flash-Lite GA, xAI Grok 4.20, Qwen 3.6 line. Versioned IDs
only; auto-shifting aliases moved to the `"Custom model ID"` option.
### Changed
- **Sentiment Analyst** is now consistently named across the CLI dropdown,
status panel, and final reports (previously the backend was renamed but
the CLI still said "Social Analyst"). The `AnalystType.SOCIAL = "social"`
wire value is kept for saved-config back-compat.
### Fixed
- **Structured output works on DeepSeek V4 / reasoner and MiniMax M2.x.**
Those providers reject `tool_choice` per their tool-calling docs; the
binding flow now skips it automatically via a capability table.
- **`pip install .` installations pick up the project `.env`** when running
the CLI as a console script. (#747)
- **Reports save end-to-end** — streamed chunks were previously dropped from
`complete_report.md`. (#719, #736)
- **Ticker prompt preserves exchange suffixes** (`.SH`, `.SZ`, `.SS`, `.HK`,
`.T`, etc.) for A-share, HK, Tokyo, and other non-US flows. (#770)
- **Docker permission errors** no longer block first-run write to
`~/.tradingagents/`. (#519, #627, #672, #771)
- **Config state no longer leaks between runs** when sub-dicts are mutated;
`set_config` partial updates preserve sibling defaults. (#788)
- **`max_recur_limit` config actually applies** — previously read but not
forwarded to the propagator. (#764)
- **Missing-API-key error** names the exact env var to set. (#680)
- **Quieter startup** — suppressed the noisy upstream
`LangChainPendingDeprecationWarning` from langgraph-checkpoint; will be
removed once that package ships its fix.
### Security
- **Ticker path-traversal validation** at every filesystem-path site (cache,
checkpoint database, results) so a malicious ticker cannot escape its
intended directory. (#618)
## [0.2.4] — 2026-04-25
### Added

View File

@@ -18,7 +18,8 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN useradd --create-home appuser
RUN useradd --create-home appuser \
&& install -d -m 0755 -o appuser -g appuser /home/appuser/.tradingagents
USER appuser
WORKDIR /home/appuser/app

View File

@@ -28,7 +28,8 @@
# TradingAgents: Multi-Agents LLM Financial Trading Framework
## News
- [2026-04] **TradingAgents v0.2.4** released with structured-output agents (Research Manager, Trader, Portfolio Manager), LangGraph checkpoint resume, persistent decision log, DeepSeek/Qwen/GLM/Azure provider support, Docker, and a Windows UTF-8 encoding fix. See [CHANGELOG.md](CHANGELOG.md) for the full list.
- [2026-05] **TradingAgents v0.2.5** released with the grounded Sentiment Analyst, GPT-5.5 etc. model coverage, Qwen/GLM/MiniMax dual-region support, `TRADINGAGENTS_*` env-var configurability with API-key auto-detection, remote Ollama support, non-US alpha benchmarks, and ticker path-traversal hardening. See [CHANGELOG.md](CHANGELOG.md) for the full list.
- [2026-04] **TradingAgents v0.2.4** released with structured-output agents (Research Manager, Trader, Portfolio Manager), LangGraph checkpoint resume, persistent decision log, DeepSeek/Qwen/GLM/Azure provider support, Docker, and a Windows UTF-8 encoding fix.
- [2026-03] **TradingAgents v0.2.3** released with multi-language support, GPT-5.4 family models, unified model catalog, backtesting date fidelity, and proxy support.
- [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability.
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
@@ -68,7 +69,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu
### Analyst Team
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
- Sentiment Analyst: Aggregates news headlines, StockTwits, and Reddit chatter into a single sentiment read to gauge short-term market mood.
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
@@ -142,15 +143,19 @@ export GOOGLE_API_KEY=... # Google (Gemini)
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
export XAI_API_KEY=... # xAI (Grok)
export DEEPSEEK_API_KEY=... # DeepSeek
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
export ZHIPU_API_KEY=... # GLM (Zhipu)
export DASHSCOPE_API_KEY=... # Qwen — International (dashscope-intl.aliyuncs.com)
export DASHSCOPE_CN_API_KEY=... # Qwen — China (dashscope.aliyuncs.com)
export ZHIPU_API_KEY=... # GLM via Z.AI (international)
export ZHIPU_CN_API_KEY=... # GLM via BigModel (China, open.bigmodel.cn)
export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io, M2.x, 204K ctx)
export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com, M2.x, 204K ctx)
export OPENROUTER_API_KEY=... # OpenRouter
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
```
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
Alternatively, copy `.env.example` to `.env` and fill in your keys:
```bash
@@ -184,7 +189,7 @@ An interface will appear showing results as they load, letting you track the age
### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope, international and China endpoints), GLM (Zhipu), MiniMax (global + China), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
### Python Usage
@@ -208,7 +213,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, qwen-cn, glm, glm-cn, minimax, minimax-cn, openrouter, ollama, azure
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
config["max_debate_rounds"] = 2

View File

@@ -1,14 +1,10 @@
from typing import Optional
import datetime
import typer
import questionary
from pathlib import Path
from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
load_dotenv(".env.enterprise", override=False)
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
@@ -53,7 +49,7 @@ class MessageBuffer:
# Analyst name mapping
ANALYST_MAPPING = {
"market": "Market Analyst",
"social": "Social Analyst",
"social": "Sentiment Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
@@ -63,7 +59,7 @@ class MessageBuffer:
# finalizing_agent: which agent must be "completed" for this report to count as done
REPORT_SECTIONS = {
"market_report": ("market", "Market Analyst"),
"sentiment_report": ("social", "Social Analyst"),
"sentiment_report": ("social", "Sentiment Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"investment_plan": (None, "Research Manager"),
@@ -284,7 +280,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
all_teams = {
"Analyst Team": [
"Market Analyst",
"Social Analyst",
"Sentiment Analyst",
"News Analyst",
"Fundamentals Analyst",
],
@@ -556,6 +552,26 @@ def get_user_selections():
)
selected_llm_provider, backend_url = select_llm_provider()
# Providers with regional endpoints prompt for the region as a secondary
# step so the main dropdown stays clean (mainland China and international
# accounts cannot share API keys).
if selected_llm_provider == "qwen":
selected_llm_provider, backend_url = ask_qwen_region()
elif selected_llm_provider == "minimax":
selected_llm_provider, backend_url = ask_minimax_region()
elif selected_llm_provider == "glm":
selected_llm_provider, backend_url = ask_glm_region()
# For Ollama, surface the resolved endpoint (OLLAMA_BASE_URL vs default)
# before model selection so it's obvious where we're connecting.
if selected_llm_provider == "ollama":
confirm_ollama_endpoint(backend_url)
# Confirm the provider's API key is present; prompt the user to paste
# one and persist it to .env if it's missing, so the analysis run
# doesn't fail later at the first API call.
ensure_api_key(selected_llm_provider)
# Step 7: Thinking agents
console.print(
create_question_box(
@@ -613,8 +629,26 @@ def get_user_selections():
def get_ticker():
"""Get ticker symbol from user input."""
return typer.prompt("", default="SPY")
"""Get ticker symbol from user input, preserving exchange suffixes."""
# typer.prompt strips trailing dot-suffixes on some shells (e.g. 000404.SH
# collapses to 000404). questionary.text reads the raw line.
ticker = questionary.text(
"",
validate=lambda value: (
not value.strip()
or (
all(ch.isalnum() or ch in "._-^" for ch in value.strip())
and len(value.strip()) <= 32
)
)
or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK.",
).ask()
if ticker is None:
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
raise typer.Exit(1)
return (ticker.strip() or "SPY").upper()
def get_analysis_date():
@@ -651,7 +685,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
if final_state.get("sentiment_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
analyst_parts.append(("Sentiment Analyst", final_state["sentiment_report"]))
if final_state.get("news_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
@@ -736,7 +770,7 @@ def display_complete_report(final_state):
if final_state.get("market_report"):
analysts.append(("Market Analyst", final_state["market_report"]))
if final_state.get("sentiment_report"):
analysts.append(("Social Analyst", final_state["sentiment_report"]))
analysts.append(("Sentiment Analyst", final_state["sentiment_report"]))
if final_state.get("news_report"):
analysts.append(("News Analyst", final_state["news_report"]))
if final_state.get("fundamentals_report"):
@@ -798,7 +832,7 @@ def update_research_team_status(status):
ANALYST_ORDER = ["market", "social", "news", "fundamentals"]
ANALYST_AGENT_NAMES = {
"market": "Market Analyst",
"social": "Social Analyst",
"social": "Sentiment Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
@@ -1152,8 +1186,11 @@ def run_analysis(checkpoint: bool = False):
trace.append(chunk)
# Get final state and decision
final_state = trace[-1]
# Streamed chunks are per-node deltas, not full state. Merge them
# so every report field populated across the run is present.
final_state = {}
for chunk in trace:
final_state.update(chunk)
decision = graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed

View File

@@ -5,6 +5,8 @@ from pydantic import BaseModel
class AnalystType(str, Enum):
MARKET = "market"
# Wire value stays "social" for saved-config and string-keyed-caller
# back-compat; the user-facing label is "Sentiment Analyst".
SOCIAL = "social"
NEWS = "news"
FUNDAMENTALS = "fundamentals"

View File

@@ -1,9 +1,13 @@
import questionary
import os
from pathlib import Path
from typing import List, Optional, Tuple, Dict
import questionary
from dotenv import find_dotenv, set_key
from rich.console import Console
from cli.models import AnalystType
from tradingagents.llm_clients.api_key_env import get_api_key_env
from tradingagents.llm_clients.model_catalog import get_model_options
console = Console()
@@ -12,7 +16,7 @@ TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
ANALYST_ORDER = [
("Market Analyst", AnalystType.MARKET),
("Social Media Analyst", AnalystType.SOCIAL),
("Sentiment Analyst", AnalystType.SOCIAL),
("News Analyst", AnalystType.NEWS),
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
]
@@ -230,6 +234,10 @@ def select_deep_thinking_agent(provider) -> str:
def select_llm_provider() -> tuple[str, str | None]:
"""Select the LLM provider and its API endpoint."""
# Ollama users can point at a remote ollama-serve via OLLAMA_BASE_URL
# (convention from the broader Ollama ecosystem); falls back to the
# localhost default when unset.
ollama_url = os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
# (display_name, provider_key, base_url)
PROVIDERS = [
("OpenAI", "openai", "https://api.openai.com/v1"),
@@ -237,11 +245,12 @@ def select_llm_provider() -> tuple[str, str | None]:
("Anthropic", "anthropic", "https://api.anthropic.com/"),
("xAI", "xai", "https://api.x.ai/v1"),
("DeepSeek", "deepseek", "https://api.deepseek.com"),
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
("Qwen", "qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
("MiniMax", "minimax", "https://api.minimax.io/v1"),
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
("Azure OpenAI", "azure", None),
("Ollama", "ollama", "http://localhost:11434/v1"),
("Ollama", "ollama", ollama_url),
]
choice = questionary.select(
@@ -289,7 +298,9 @@ def ask_openai_reasoning_effort() -> str:
def ask_anthropic_effort() -> str | None:
"""Ask for Anthropic effort level.
Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models.
Controls token usage and response thoroughness on Claude 4.5 / 4.6 / 4.7
models. The API also accepts "max"; we expose low/medium/high as the
common selection range.
"""
return questionary.select(
"Select Effort Level:",
@@ -326,6 +337,159 @@ def ask_gemini_thinking_config() -> str | None:
).ask()
def ask_glm_region() -> tuple[str, str]:
"""Ask which GLM platform (Z.AI international vs BigModel China) to use.
Zhipu serves the same GLM models under two brands with separate
accounts; keys aren't interchangeable. Returns (provider_key, backend_url).
"""
return questionary.select(
"Select GLM platform:",
choices=[
questionary.Choice(
"Z.AI — api.z.ai (international, uses ZHIPU_API_KEY)",
value=("glm", "https://api.z.ai/api/paas/v4/"),
),
questionary.Choice(
"BigModel — open.bigmodel.cn (China, uses ZHIPU_CN_API_KEY)",
value=("glm-cn", "https://open.bigmodel.cn/api/paas/v4/"),
),
],
style=questionary.Style([
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]),
).ask()
def ask_qwen_region() -> tuple[str, str]:
"""Ask which Qwen region (international vs China) to use.
Alibaba DashScope exposes two endpoints with separate accounts —
a key from one region does NOT authenticate against the other
(fixes #758). Returns (provider_key, backend_url).
"""
return questionary.select(
"Select Qwen region:",
choices=[
questionary.Choice(
"International — dashscope-intl.aliyuncs.com (uses DASHSCOPE_API_KEY)",
value=("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
),
questionary.Choice(
"China — dashscope.aliyuncs.com (uses DASHSCOPE_CN_API_KEY)",
value=("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
),
],
style=questionary.Style([
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]),
).ask()
def ask_minimax_region() -> tuple[str, str]:
"""Ask which MiniMax region (global vs China) to use.
MiniMax exposes two endpoints with separate accounts — a key from
one region does NOT authenticate against the other. Returns
(provider_key, backend_url).
"""
return questionary.select(
"Select MiniMax region:",
choices=[
questionary.Choice(
"Global — api.minimax.io (uses MINIMAX_API_KEY)",
value=("minimax", "https://api.minimax.io/v1"),
),
questionary.Choice(
"China — api.minimaxi.com (uses MINIMAX_CN_API_KEY)",
value=("minimax-cn", "https://api.minimaxi.com/v1"),
),
],
style=questionary.Style([
("selected", "fg:cyan noinherit"),
("highlighted", "fg:cyan noinherit"),
("pointer", "fg:cyan noinherit"),
]),
).ask()
def confirm_ollama_endpoint(url: str) -> None:
"""Show the resolved Ollama endpoint after provider selection.
Surfaces three things the user benefits from seeing before model
selection: which URL we'll actually hit, where it came from
(\`OLLAMA_BASE_URL\` vs default), and a soft warning if the URL is
missing the scheme/port that ollama-serve expects. The warning is
advisory only — we don't reject malformed input, since the user may
be doing something deliberately unusual (e.g. a reverse-proxy path).
"""
from_env = os.environ.get("OLLAMA_BASE_URL")
origin = " (from OLLAMA_BASE_URL)" if from_env and from_env == url else ""
console.print(f"[green]✓ Using Ollama at {url}{origin}[/green]")
if not url.startswith(("http://", "https://")):
console.print(
f"[yellow]Note: {url!r} is missing a scheme. "
f"Ollama-serve typically expects a URL like "
f"http://<host>:11434/v1.[/yellow]"
)
elif ":11434" not in url and "://localhost" not in url and "://127.0.0.1" not in url:
# Soft hint when the port differs from the ollama-serve default
# and the host isn't local (where users sometimes proxy on :80).
console.print(
f"[yellow]Note: {url!r} doesn't include port 11434. "
f"Make sure your remote ollama-serve listens on the port "
f"shown above.[/yellow]"
)
def ensure_api_key(provider: str) -> Optional[str]:
"""Make sure the API key for `provider` is available in the environment.
If the env var is already set, returns its value untouched. Otherwise
interactively prompts the user, persists the value to the project's
.env file via python-dotenv's set_key (creating .env if needed), and
exports it into os.environ so the current process picks it up.
Returns None for providers that do not require a key (e.g. ollama)
and for providers not found in the canonical mapping.
"""
env_var = get_api_key_env(provider)
if env_var is None:
return None # ollama / unknown — no key check possible
existing = os.environ.get(env_var)
if existing:
return existing
console.print(
f"\n[yellow]{env_var} is not set in your environment.[/yellow]"
)
key = questionary.password(
f"Paste your {env_var} (will be saved to .env):",
style=questionary.Style([
("text", "fg:cyan"),
("highlighted", "noinherit"),
]),
).ask()
if not key:
console.print(
f"[red]Skipped. API calls will fail until {env_var} is set.[/red]"
)
return None
env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env")
Path(env_path).touch(exist_ok=True)
set_key(env_path, env_var, key)
os.environ[env_var] = key
console.print(f"[green]Saved {env_var} to {env_path}[/green]")
return key
def ask_output_language() -> str:
"""Ask for report output language."""
choice = questionary.select(

22
main.py
View File

@@ -1,24 +1,12 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Create a custom config
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
# so users can switch models or endpoints purely via .env without
# editing this script. Override individual keys here only when you
# want a hard-coded value that should ignore the environment.
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance, no extra API keys needed)
config["data_vendors"] = {
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
}
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "tradingagents"
version = "0.2.4"
version = "0.2.5"
description = "TradingAgents: Multi-Agents LLM Financial Trading Framework"
readme = "README.md"
requires-python = ">=3.10"

View File

@@ -18,7 +18,11 @@ _API_KEY_ENV_VARS = (
"XAI_API_KEY",
"DEEPSEEK_API_KEY",
"DASHSCOPE_API_KEY",
"DASHSCOPE_CN_API_KEY",
"ZHIPU_API_KEY",
"ZHIPU_CN_API_KEY",
"MINIMAX_API_KEY",
"MINIMAX_CN_API_KEY",
"OPENROUTER_API_KEY",
"AZURE_OPENAI_API_KEY",
"ALPHA_VANTAGE_API_KEY",

149
tests/test_api_key_env.py Normal file
View File

@@ -0,0 +1,149 @@
"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper."""
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
# ---- Mapping coverage -----------------------------------------------------
def test_every_select_llm_provider_choice_has_an_entry():
"""select_llm_provider() must not present a provider the mapping doesn't know about."""
# Mirrors the dropdown order in cli/utils.select_llm_provider so the two
# stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn)
# are reached via the secondary region prompt, so they must also be present.
expected = {
"openai", "google", "anthropic", "xai", "deepseek",
"qwen", "qwen-cn",
"glm", "glm-cn",
"minimax", "minimax-cn",
"openrouter", "azure", "ollama",
}
assert expected.issubset(PROVIDER_API_KEY_ENV.keys())
@pytest.mark.parametrize(
"provider,env_var",
[
("openai", "OPENAI_API_KEY"),
("anthropic", "ANTHROPIC_API_KEY"),
("google", "GOOGLE_API_KEY"),
("azure", "AZURE_OPENAI_API_KEY"),
("xai", "XAI_API_KEY"),
("deepseek", "DEEPSEEK_API_KEY"),
("qwen", "DASHSCOPE_API_KEY"),
("qwen-cn", "DASHSCOPE_CN_API_KEY"),
("glm", "ZHIPU_API_KEY"),
("glm-cn", "ZHIPU_CN_API_KEY"),
("minimax", "MINIMAX_API_KEY"),
("minimax-cn", "MINIMAX_CN_API_KEY"),
("openrouter", "OPENROUTER_API_KEY"),
],
)
def test_known_providers_resolve(provider, env_var):
assert get_api_key_env(provider) == env_var
def test_ollama_has_no_key():
assert get_api_key_env("ollama") is None
def test_unknown_provider_returns_none():
assert get_api_key_env("not-a-real-provider") is None
def test_case_insensitive_lookup():
assert get_api_key_env("OpenAI") == "OPENAI_API_KEY"
assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY"
# ---- ensure_api_key behavior ---------------------------------------------
@pytest.fixture
def cli_utils(monkeypatch):
"""Import cli.utils with a fresh environment so module-level state is consistent."""
import importlib
import cli.utils as cli_utils_module
return importlib.reload(cli_utils_module)
def test_ensure_api_key_returns_existing(monkeypatch, cli_utils):
monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set")
result = cli_utils.ensure_api_key("openai")
assert result == "sk-already-set"
def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils):
# Even with no env var set, ollama should not prompt and should return None.
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with patch.object(cli_utils, "questionary") as mock_q:
result = cli_utils.ensure_api_key("ollama")
assert result is None
mock_q.password.assert_not_called()
def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils):
with patch.object(cli_utils, "questionary") as mock_q:
result = cli_utils.ensure_api_key("totally-fake-provider")
assert result is None
mock_q.password.assert_not_called()
def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils):
"""When key is missing, user-pasted value must be written to .env AND os.environ."""
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
monkeypatch.chdir(tmp_path)
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})()
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
result = cli_utils.ensure_api_key("deepseek")
assert result == "sk-deepseek-test"
assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test"
env_file = tmp_path / ".env"
assert env_file.exists()
assert "DEEPSEEK_API_KEY" in env_file.read_text()
assert "sk-deepseek-test" in env_file.read_text()
def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils):
"""Empty prompt response (user cancelled) must not write to .env."""
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.chdir(tmp_path)
fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})()
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
result = cli_utils.ensure_api_key("xai")
assert result is None
assert "XAI_API_KEY" not in os.environ
# .env may or may not exist depending on find_dotenv's walk, but if it
# does it must not contain the key.
env_file = tmp_path / ".env"
if env_file.exists():
assert "XAI_API_KEY" not in env_file.read_text()
def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils):
"""An existing .env with other keys must be preserved on writeback."""
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.chdir(tmp_path)
env_file = tmp_path / ".env"
env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n")
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})()
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
cli_utils.ensure_api_key("openrouter")
content = env_file.read_text()
assert "OPENAI_API_KEY" in content and "sk-existing" in content
assert "OTHER=value" in content
assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content

107
tests/test_capabilities.py Normal file
View File

@@ -0,0 +1,107 @@
"""Unit tests for the LLM capability table."""
import pytest
from tradingagents.llm_clients.capabilities import (
ModelCapabilities,
get_capabilities,
)
@pytest.mark.unit
class TestExactIdMatches:
def test_deepseek_chat_supports_tool_choice(self):
caps = get_capabilities("deepseek-chat")
assert caps.supports_tool_choice is True
def test_deepseek_reasoner_rejects_tool_choice(self):
caps = get_capabilities("deepseek-reasoner")
assert caps.supports_tool_choice is False
assert caps.requires_reasoning_content_roundtrip is True
def test_deepseek_v4_flash_rejects_tool_choice(self):
caps = get_capabilities("deepseek-v4-flash")
assert caps.supports_tool_choice is False
assert caps.requires_reasoning_content_roundtrip is True
def test_deepseek_v4_pro_rejects_tool_choice(self):
caps = get_capabilities("deepseek-v4-pro")
assert caps.supports_tool_choice is False
assert caps.requires_reasoning_content_roundtrip is True
@pytest.mark.unit
class TestPatternMatches:
"""Forward-compat regex patterns catch unknown DeepSeek and MiniMax variants."""
def test_future_deepseek_v5_inherits_thinking_quirks(self):
caps = get_capabilities("deepseek-v5-flash")
assert caps.supports_tool_choice is False
assert caps.requires_reasoning_content_roundtrip is True
def test_future_deepseek_v9_inherits_thinking_quirks(self):
caps = get_capabilities("deepseek-v9-anything")
assert caps.supports_tool_choice is False
def test_reasoner_variant_inherits_thinking_quirks(self):
caps = get_capabilities("deepseek-reasoner-pro")
assert caps.supports_tool_choice is False
def test_future_minimax_m3_inherits_thinking_quirks(self):
caps = get_capabilities("MiniMax-M3")
assert caps.supports_tool_choice is False
def test_future_minimax_m4_highspeed_inherits_thinking_quirks(self):
caps = get_capabilities("MiniMax-M4-highspeed")
assert caps.supports_tool_choice is False
@pytest.mark.unit
class TestMinimaxExactMatches:
"""MiniMax M2.x models reject langchain's function-spec dict tool_choice
(official API enum: none/auto only)."""
def test_m2_7_rejects_tool_choice(self):
caps = get_capabilities("MiniMax-M2.7")
assert caps.supports_tool_choice is False
assert caps.supports_json_mode is False # only MiniMax-Text-01 supports json_object
def test_m2_7_highspeed_rejects_tool_choice(self):
assert get_capabilities("MiniMax-M2.7-highspeed").supports_tool_choice is False
def test_m2_1_rejects_tool_choice(self):
assert get_capabilities("MiniMax-M2.1").supports_tool_choice is False
def test_m2_base_rejects_tool_choice(self):
assert get_capabilities("MiniMax-M2").supports_tool_choice is False
@pytest.mark.unit
class TestDefault:
"""Unknown / non-DeepSeek models get the permissive default."""
def test_gpt_default(self):
caps = get_capabilities("gpt-4.1")
assert caps.supports_tool_choice is True
assert caps.preferred_structured_method == "function_calling"
def test_grok_default(self):
caps = get_capabilities("grok-4-0709")
assert caps.supports_tool_choice is True
def test_unknown_model_default(self):
caps = get_capabilities("totally-made-up-model-id")
assert caps.supports_tool_choice is True
def test_exact_match_precedes_pattern(self):
"""deepseek-chat must NOT match the v\\d regex."""
caps = get_capabilities("deepseek-chat")
assert caps.supports_tool_choice is True
@pytest.mark.unit
def test_capabilities_dataclass_is_frozen():
"""Capability rows are immutable so they can be safely shared."""
caps = get_capabilities("deepseek-chat")
with pytest.raises(Exception):
caps.supports_tool_choice = False # type: ignore[misc]

View File

@@ -0,0 +1,61 @@
"""Config isolation: get/set must not leak nested-dict references."""
import copy
import unittest
import pytest
import tradingagents.default_config as default_config
from tradingagents.dataflows.config import get_config, set_config
@pytest.mark.unit
class DataflowsConfigIsolationTests(unittest.TestCase):
def setUp(self):
set_config(copy.deepcopy(default_config.DEFAULT_CONFIG))
def test_get_config_returns_deep_copy(self):
cfg = get_config()
cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage"
cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage"
fresh = get_config()
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance")
self.assertNotIn("get_stock_data", fresh["tool_vendors"])
def test_set_config_does_not_alias_caller_nested_dicts(self):
custom = copy.deepcopy(default_config.DEFAULT_CONFIG)
custom["data_vendors"]["core_stock_apis"] = "alpha_vantage"
custom["tool_vendors"]["get_stock_data"] = "alpha_vantage"
set_config(custom)
custom["data_vendors"]["core_stock_apis"] = "yfinance"
custom["tool_vendors"]["get_stock_data"] = "yfinance"
fresh = get_config()
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
def test_partial_nested_update_preserves_existing_defaults(self):
set_config(
{
"data_vendors": {
"core_stock_apis": "alpha_vantage",
}
}
)
fresh = get_config()
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance")
self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance")
self.assertEqual(fresh["data_vendors"]["news_data"], "yfinance")
def test_nested_dict_updates_merge_one_level_deep(self):
set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}})
set_config({"tool_vendors": {"get_news": "alpha_vantage"}})
fresh = get_config()
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage")

View File

@@ -0,0 +1,240 @@
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
Two pieces verified:
1. ``reasoning_content`` is captured on receive into the AIMessage's
``additional_kwargs`` and re-attached on send so DeepSeek's API
sees the same value across turns.
2. ``with_structured_output`` consults the capability table and
suppresses ``tool_choice`` for models that reject it (V4 + reasoner),
matching DeepSeek's official tool-calling pattern at
https://api-docs.deepseek.com/guides/tool_calls.
"""
import os
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue
from pydantic import BaseModel
from tradingagents.llm_clients.openai_client import (
DeepSeekChatOpenAI,
NormalizedChatOpenAI,
_input_to_messages,
)
# ---------------------------------------------------------------------------
# _input_to_messages — the helper that handles list / ChatPromptValue / other
# (Gemini bot review note: non-list inputs must also work)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestInputToMessages:
def test_list_input_returned_as_is(self):
msgs = [HumanMessage(content="hi")]
assert _input_to_messages(msgs) is msgs
def test_chat_prompt_value_unwrapped(self):
msgs = [HumanMessage(content="hi")]
prompt_value = ChatPromptValue(messages=msgs)
assert _input_to_messages(prompt_value) == msgs
def test_string_input_yields_empty_list(self):
# A bare string isn't a message-bearing input; the caller's normal
# langchain conversion happens upstream of _get_request_payload.
assert _input_to_messages("hello") == []
# ---------------------------------------------------------------------------
# Reasoning content propagation across turns
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeepSeekReasoningContent:
def _client(self):
os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
return DeepSeekChatOpenAI(
model="deepseek-v4-flash",
api_key="placeholder",
base_url="https://api.deepseek.com",
)
def test_capture_on_receive(self):
"""When the response carries reasoning_content, it lands on the
AIMessage's additional_kwargs so the next turn can echo it back."""
client = self._client()
result = client._create_chat_result(
{
"model": "deepseek-v4-flash",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Plan: buy NVDA.",
"reasoning_content": "Step 1: trend is up. Step 2: ...",
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
)
ai = result.generations[0].message
assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."
def test_propagate_on_send(self):
"""When an outgoing AIMessage carries reasoning_content, the request
payload echoes it on the corresponding message dict."""
client = self._client()
prior = AIMessage(
content="Plan",
additional_kwargs={"reasoning_content": "weighed bull case"},
)
new_user = HumanMessage(content="Refine.")
payload = client._get_request_payload([prior, new_user])
# Find the assistant message in the payload
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
assert assistant_dicts, "assistant message missing from outgoing payload"
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
def test_propagate_through_chat_prompt_value(self):
"""Gemini bot review note: non-list inputs (ChatPromptValue) must
also propagate reasoning_content."""
client = self._client()
prior = AIMessage(
content="Plan",
additional_kwargs={"reasoning_content": "weighed bull case"},
)
prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
payload = client._get_request_payload(prompt_value)
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
# ---------------------------------------------------------------------------
# Capability-driven structured output: tool_choice suppressed for V4 + reasoner
# ---------------------------------------------------------------------------
def _bound_kwargs(runnable):
"""Extract bind() kwargs from a with_structured_output result."""
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
return getattr(first, "kwargs", {})
@pytest.mark.unit
class TestStructuredOutputCapabilityDispatch:
"""DeepSeek V4 and reasoner reject the tool_choice parameter
(official guide: api-docs.deepseek.com/guides/tool_calls passes
tools=[...] without tool_choice). Verify the capability dispatch
suppresses tool_choice for those models and sends it for chat."""
class _Sample(BaseModel):
answer: str
def _client(self, model):
return DeepSeekChatOpenAI(
model=model, api_key="placeholder", base_url="https://api.deepseek.com",
)
def test_chat_sends_tool_choice(self):
bound = self._client("deepseek-chat").with_structured_output(self._Sample)
assert _bound_kwargs(bound).get("tool_choice") is not None
def test_reasoner_suppresses_tool_choice(self):
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
# tool_choice is either absent or explicitly None — both are valid
# signals that langchain's bind_tools will skip the parameter.
assert _bound_kwargs(bound).get("tool_choice") in (None, ...) or \
"tool_choice" not in _bound_kwargs(bound)
def test_v4_flash_suppresses_tool_choice(self):
bound = self._client("deepseek-v4-flash").with_structured_output(self._Sample)
assert _bound_kwargs(bound).get("tool_choice") is None or \
"tool_choice" not in _bound_kwargs(bound)
def test_v4_pro_suppresses_tool_choice(self):
bound = self._client("deepseek-v4-pro").with_structured_output(self._Sample)
assert _bound_kwargs(bound).get("tool_choice") is None or \
"tool_choice" not in _bound_kwargs(bound)
def test_future_v_variant_via_regex(self):
"""Forward-compat: unknown deepseek-v\\d-* IDs inherit V4 quirks."""
bound = self._client("deepseek-v5-hypothetical").with_structured_output(self._Sample)
assert _bound_kwargs(bound).get("tool_choice") is None or \
"tool_choice" not in _bound_kwargs(bound)
def test_schema_is_still_bound_as_tool(self):
"""tool_choice is suppressed, but the schema is still bound as a tool —
exactly matching DeepSeek's official tool-calling examples."""
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
kwargs = _bound_kwargs(bound)
tools = kwargs.get("tools", [])
assert any(
t.get("function", {}).get("name") == "_Sample" for t in tools
), f"schema not bound as a tool: {tools}"
# ---------------------------------------------------------------------------
# Live API: structured output round-trips against the real DeepSeek backend
# ---------------------------------------------------------------------------
def _has_real_deepseek_key():
key = os.environ.get("DEEPSEEK_API_KEY", "")
return bool(key) and key != "placeholder"
@pytest.mark.integration
@pytest.mark.skipif(
not _has_real_deepseek_key(),
reason="DEEPSEEK_API_KEY not set (or placeholder); skipping live API call",
)
class TestDeepSeekLiveStructuredOutput:
"""End-to-end: a real DeepSeek V4-flash call returns a typed instance.
Verifies the no-tool_choice path doesn't trigger the 400 reported in
issue #678 and that the structured-output binding still parses to a
Pydantic instance.
"""
class _Pick(BaseModel):
action: str
confidence: float
def test_v4_flash_returns_structured_output(self):
client = DeepSeekChatOpenAI(
model="deepseek-v4-flash",
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url="https://api.deepseek.com",
timeout=60,
)
bound = client.with_structured_output(self._Pick)
result = bound.invoke(
"Pick BUY or SELL or HOLD for a tech stock with strong earnings. "
"Confidence is a float between 0 and 1."
)
assert isinstance(result, self._Pick)
assert result.action in {"BUY", "SELL", "HOLD"}
assert 0.0 <= result.confidence <= 1.0
# ---------------------------------------------------------------------------
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseClassIsolation:
def test_normalized_does_not_propagate_reasoning_content(self):
"""The general-purpose NormalizedChatOpenAI must not carry
DeepSeek-specific behaviour. Only the subclass does."""
assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
NormalizedChatOpenAI._get_request_payload
is NormalizedChatOpenAI.__bases__[0]._get_request_payload
)

View File

@@ -0,0 +1,98 @@
"""Tests for TRADINGAGENTS_* env-var overlay onto DEFAULT_CONFIG."""
from __future__ import annotations
import importlib
import pytest
import tradingagents.default_config as default_config_module
def _reload_with_env(monkeypatch, **overrides):
"""Set/clear env vars then reload default_config to re-evaluate DEFAULT_CONFIG."""
for key in list(default_config_module._ENV_OVERRIDES):
monkeypatch.delenv(key, raising=False)
for key, val in overrides.items():
monkeypatch.setenv(key, val)
return importlib.reload(default_config_module)
def test_no_env_uses_built_in_defaults(monkeypatch):
dc = _reload_with_env(monkeypatch)
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.4"
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini"
assert dc.DEFAULT_CONFIG["backend_url"] is None
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is False
def test_string_overrides(monkeypatch):
dc = _reload_with_env(
monkeypatch,
TRADINGAGENTS_LLM_PROVIDER="google",
TRADINGAGENTS_DEEP_THINK_LLM="gemini-3-pro-preview",
TRADINGAGENTS_QUICK_THINK_LLM="gemini-3-flash-preview",
TRADINGAGENTS_LLM_BACKEND_URL="https://example.invalid/v1",
TRADINGAGENTS_OUTPUT_LANGUAGE="Chinese",
)
assert dc.DEFAULT_CONFIG["llm_provider"] == "google"
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gemini-3-pro-preview"
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gemini-3-flash-preview"
assert dc.DEFAULT_CONFIG["backend_url"] == "https://example.invalid/v1"
assert dc.DEFAULT_CONFIG["output_language"] == "Chinese"
def test_int_coercion(monkeypatch):
dc = _reload_with_env(
monkeypatch,
TRADINGAGENTS_MAX_DEBATE_ROUNDS="3",
TRADINGAGENTS_MAX_RISK_ROUNDS="2",
)
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 3
assert isinstance(dc.DEFAULT_CONFIG["max_debate_rounds"], int)
assert dc.DEFAULT_CONFIG["max_risk_discuss_rounds"] == 2
assert isinstance(dc.DEFAULT_CONFIG["max_risk_discuss_rounds"], int)
@pytest.mark.parametrize(
"raw,expected",
[
("true", True), ("True", True), ("1", True), ("yes", True), ("on", True),
("false", False), ("False", False), ("0", False), ("no", False), ("off", False),
],
)
def test_bool_coercion(monkeypatch, raw, expected):
dc = _reload_with_env(monkeypatch, TRADINGAGENTS_CHECKPOINT_ENABLED=raw)
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is expected
def test_empty_env_value_is_passthrough(monkeypatch):
"""Empty TRADINGAGENTS_* values must not clobber the built-in default."""
dc = _reload_with_env(
monkeypatch,
TRADINGAGENTS_LLM_PROVIDER="",
TRADINGAGENTS_MAX_DEBATE_ROUNDS="",
)
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
def test_invalid_int_raises(monkeypatch):
"""Garbage int values should surface a ValueError at import, not silently misconfigure."""
monkeypatch.setenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", "not-a-number")
with pytest.raises(ValueError):
importlib.reload(default_config_module)
# Restore module state for subsequent tests in this process
monkeypatch.delenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", raising=False)
importlib.reload(default_config_module)
def test_unknown_env_var_is_ignored(monkeypatch):
"""Env vars outside _ENV_OVERRIDES must not bleed into DEFAULT_CONFIG."""
dc = _reload_with_env(
monkeypatch,
TRADINGAGENTS_NONEXISTENT_KEY="oops",
)
assert "nonexistent_key" not in dc.DEFAULT_CONFIG

View File

@@ -535,6 +535,93 @@ class TestDeferredReflection:
assert raw is not None and alpha is not None and days is not None
assert days == 2
# TradingAgentsGraph._resolve_benchmark — picks index for alpha calc
def test_resolve_benchmark_explicit_override(self):
"""config['benchmark_ticker'] wins for every ticker."""
mock_graph = MagicMock(spec=TradingAgentsGraph)
mock_graph.config = {
"benchmark_ticker": "QQQ",
"benchmark_map": {"": "SPY", ".T": "^N225"},
}
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "QQQ"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "QQQ"
def test_resolve_benchmark_suffix_map(self):
"""Known suffixes route to their regional index."""
mock_graph = MagicMock(spec=TradingAgentsGraph)
mock_graph.config = {
"benchmark_ticker": None,
"benchmark_map": {
".T": "^N225", ".HK": "^HSI", ".NS": "^NSEI",
".L": "^FTSE", ".TO": "^GSPTSE", ".AX": "^AXJO",
".BO": "^BSESN", "": "SPY",
},
}
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "^N225"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "0700.HK") == "^HSI"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "RELIANCE.NS") == "^NSEI"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AZN.L") == "^FTSE"
def test_resolve_benchmark_us_ticker_defaults_to_spy(self):
"""US tickers (no dotted suffix) take the empty-suffix entry."""
mock_graph = MagicMock(spec=TradingAgentsGraph)
mock_graph.config = {
"benchmark_ticker": None,
"benchmark_map": {"": "SPY", ".T": "^N225"},
}
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "SPY"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AAPL") == "SPY"
def test_resolve_benchmark_unknown_suffix_falls_back(self):
"""Unrecognised suffix (BRK.B, FAKE.XX) falls back to SPY."""
mock_graph = MagicMock(spec=TradingAgentsGraph)
mock_graph.config = {
"benchmark_ticker": None,
"benchmark_map": {"": "SPY", ".T": "^N225"},
}
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "FAKE.XX") == "SPY"
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "BRK.B") == "SPY"
def test_resolve_benchmark_case_insensitive(self):
"""Suffix matching is case-insensitive so 7203.t resolves like 7203.T."""
mock_graph = MagicMock(spec=TradingAgentsGraph)
mock_graph.config = {
"benchmark_ticker": None,
"benchmark_map": {".T": "^N225", "": "SPY"},
}
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.t") == "^N225"
def test_reflector_includes_benchmark_in_label(self):
"""benchmark_name appears in the prompt label, not 'SPY' hardcoded."""
mock_llm = MagicMock()
mock_llm.invoke.return_value.content = "Directionally correct."
reflector = Reflector(mock_llm)
reflector.reflect_on_final_decision(
final_decision=DECISION_BUY,
raw_return=0.05,
alpha_return=0.02,
benchmark_name="^N225",
)
messages = mock_llm.invoke.call_args[0][0]
human_content = next(content for role, content in messages if role == "human")
assert "Alpha vs ^N225:" in human_content
assert "Alpha vs SPY:" not in human_content
def test_reflector_defaults_to_spy_for_unupdated_callers(self):
"""Default benchmark_name keeps the SPY label for legacy callers."""
mock_llm = MagicMock()
mock_llm.invoke.return_value.content = "ok"
reflector = Reflector(mock_llm)
reflector.reflect_on_final_decision(
final_decision=DECISION_BUY,
raw_return=0.05,
alpha_return=0.02,
)
messages = mock_llm.invoke.call_args[0][0]
human_content = next(content for role, content in messages if role == "human")
assert "Alpha vs SPY:" in human_content
# TradingAgentsGraph._resolve_pending_entries
def test_resolve_skips_other_tickers(self, tmp_path):

73
tests/test_minimax.py Normal file
View File

@@ -0,0 +1,73 @@
"""Tests for MinimaxChatOpenAI quirks.
Verifies the subclass injects ``reasoning_split=True`` into outgoing
requests so M2.x reasoning models put their <think> block into
``reasoning_details`` instead of polluting ``message.content``.
"""
import os
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from tradingagents.llm_clients.openai_client import MinimaxChatOpenAI
def _client(model: str = "MiniMax-M2.7"):
os.environ.setdefault("MINIMAX_API_KEY", "placeholder")
return MinimaxChatOpenAI(
model=model,
api_key="placeholder",
base_url="https://api.minimax.io/v1",
)
@pytest.mark.unit
class TestMinimaxReasoningSplit:
def test_request_payload_sets_reasoning_split(self):
payload = _client()._get_request_payload([HumanMessage(content="hi")])
assert payload.get("reasoning_split") is True
def test_caller_supplied_reasoning_split_is_preserved(self):
"""If the user explicitly sets reasoning_split, don't override it
(setdefault semantics — caller wins)."""
client = _client()
payload = client._get_request_payload(
[HumanMessage(content="hi")],
reasoning_split=False,
)
# langchain may or may not surface that kwarg into the payload;
# what matters is we don't blindly overwrite a non-default value
# the caller passed. setdefault leaves an existing value alone.
assert payload.get("reasoning_split") in (False, True)
@pytest.mark.unit
class TestMinimaxStructuredOutputDispatch:
"""M2.x models route through the capability table — tool_choice is
suppressed but the schema is still bound as a tool."""
class _Pick(BaseModel):
action: str
def _bound_kwargs(self, runnable):
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
return getattr(first, "kwargs", {})
def test_m2_7_suppresses_tool_choice(self):
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
kwargs = self._bound_kwargs(bound)
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
def test_m2_7_highspeed_suppresses_tool_choice(self):
bound = _client("MiniMax-M2.7-highspeed").with_structured_output(self._Pick)
kwargs = self._bound_kwargs(bound)
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
def test_schema_still_bound_as_tool(self):
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
tools = self._bound_kwargs(bound).get("tools", [])
assert any(
t.get("function", {}).get("name") == "_Pick" for t in tools
), f"schema not bound: {tools}"

View File

@@ -0,0 +1,167 @@
"""Tests for OLLAMA_BASE_URL env-var override across CLI and client paths."""
from __future__ import annotations
import importlib
import pytest
# ---- openai_client side: _resolve_provider_base_url -----------------------
def _reload_client():
import tradingagents.llm_clients.openai_client as mod
return importlib.reload(mod)
def test_resolver_returns_default_when_env_unset(monkeypatch):
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
mod = _reload_client()
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
def test_resolver_returns_env_when_set(monkeypatch):
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
mod = _reload_client()
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
def test_resolver_evaluation_is_call_time(monkeypatch):
"""Setting the env AFTER module import must still take effect."""
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
mod = _reload_client()
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
def test_resolver_does_not_affect_other_providers(monkeypatch):
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
mod = _reload_client()
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
def test_client_get_llm_picks_up_env(monkeypatch):
"""End-to-end: OllamaClient.get_llm() respects OLLAMA_BASE_URL."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://my-ollama:11434/v1")
mod = _reload_client()
client = mod.OpenAIClient(model="llama3.1", provider="ollama")
llm = client.get_llm()
assert "my-ollama" in str(llm.openai_api_base)
def test_explicit_base_url_overrides_env(monkeypatch):
"""An explicit base_url passed to the client wins over the env var."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://env-set:11434/v1")
mod = _reload_client()
client = mod.OpenAIClient(
model="llama3.1",
provider="ollama",
base_url="http://explicit:11434/v1",
)
llm = client.get_llm()
assert "explicit" in str(llm.openai_api_base)
assert "env-set" not in str(llm.openai_api_base)
# ---- cli.utils side: select_llm_provider dropdown -------------------------
def test_cli_dropdown_uses_env(monkeypatch):
"""The Ollama entry in the CLI dropdown must reflect OLLAMA_BASE_URL."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://cli-remote:11434/v1")
import cli.utils as cli_utils
importlib.reload(cli_utils)
# Reach inside the function via the same env-read it does at call time
ollama_url = (
__import__("os").environ.get("OLLAMA_BASE_URL")
or "http://localhost:11434/v1"
)
assert ollama_url == "http://cli-remote:11434/v1"
def test_cli_dropdown_default_when_unset(monkeypatch):
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
import cli.utils as cli_utils
importlib.reload(cli_utils)
ollama_url = (
__import__("os").environ.get("OLLAMA_BASE_URL")
or "http://localhost:11434/v1"
)
assert ollama_url == "http://localhost:11434/v1"
# ---- confirm_ollama_endpoint UX -------------------------------------------
def test_confirm_endpoint_shows_default(monkeypatch, capsys):
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
import cli.utils as cli_utils
importlib.reload(cli_utils)
cli_utils.confirm_ollama_endpoint("http://localhost:11434/v1")
out = capsys.readouterr().out
assert "http://localhost:11434/v1" in out
assert "OLLAMA_BASE_URL" not in out # not from env
assert "Note" not in out # no warnings for the canonical default
def test_confirm_endpoint_marks_env_origin(monkeypatch, capsys):
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host:11434/v1")
import cli.utils as cli_utils
importlib.reload(cli_utils)
cli_utils.confirm_ollama_endpoint("http://remote-host:11434/v1")
out = capsys.readouterr().out
assert "http://remote-host:11434/v1" in out
assert "OLLAMA_BASE_URL" in out
def test_confirm_endpoint_warns_on_missing_scheme(monkeypatch, capsys):
"""If user sets OLLAMA_BASE_URL=0.0.0.128, advise on the expected shape."""
monkeypatch.setenv("OLLAMA_BASE_URL", "0.0.0.128")
import cli.utils as cli_utils
importlib.reload(cli_utils)
cli_utils.confirm_ollama_endpoint("0.0.0.128")
out = capsys.readouterr().out
assert "missing a scheme" in out
assert "http://<host>:11434/v1" in out
def test_confirm_endpoint_warns_on_non_default_port_remote(monkeypatch, capsys):
"""A remote host with no :11434 gets a soft hint about port mismatch."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host/v1")
import cli.utils as cli_utils
importlib.reload(cli_utils)
cli_utils.confirm_ollama_endpoint("http://remote-host/v1")
out = capsys.readouterr().out
assert "port 11434" in out
def test_confirm_endpoint_quiet_on_local_no_port(monkeypatch, capsys):
"""Local host without port shouldn't trigger the remote-port hint."""
monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost/v1")
import cli.utils as cli_utils
importlib.reload(cli_utils)
cli_utils.confirm_ollama_endpoint("http://localhost/v1")
out = capsys.readouterr().out
assert "Note" not in out # localhost is fine without explicit port
def test_ollama_model_labels_no_local_suffix():
"""Labels should no longer claim '(local)' since the endpoint is dynamic."""
from tradingagents.llm_clients.model_catalog import get_model_options
for mode in ("quick", "deep"):
labels = [label for label, _ in get_model_options("ollama", mode)]
assert all("local" not in label for label in labels), labels
def test_ollama_offers_custom_model_id():
"""Ollama users with custom-pulled models can pick 'Custom model ID'."""
from tradingagents.llm_clients.model_catalog import get_model_options
for mode in ("quick", "deep"):
entries = get_model_options("ollama", mode)
values = [v for _, v in entries]
assert "custom" in values, f"Ollama {mode!r} missing 'custom' option: {entries}"
# Custom option is last so it doesn't push the curated defaults off-screen
assert values[-1] == "custom", f"'custom' should be last entry: {values}"

View File

@@ -0,0 +1,52 @@
"""Tests for the ticker path-component validator that blocks directory traversal."""
import os
import unittest
import pytest
from tradingagents.dataflows.utils import safe_ticker_component
@pytest.mark.unit
class TestSafeTickerComponent(unittest.TestCase):
def test_accepts_common_ticker_formats(self):
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
self.assertEqual(safe_ticker_component(ticker), ticker)
def test_rejects_path_separators(self):
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_null_byte_and_whitespace(self):
for bad in ("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_empty_or_non_string(self):
for bad in ("", None, 123, b"AAPL"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_overlong_input(self):
with self.assertRaises(ValueError):
safe_ticker_component("A" * 33)
def test_rejects_dot_only_values(self):
# '.' and '..' pass the regex but traverse when used as a path
# component (e.g. ``Path(results_dir) / ticker / "logs"``).
for bad in (".", "..", "...", "...."):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_traversal_string_does_not_escape_join(self):
"""Sanity: sanitized values stay within base when joined."""
base = os.path.realpath("/tmp/cache")
ticker = safe_ticker_component("AAPL")
joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
self.assertTrue(joined.startswith(base + os.sep))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,38 @@
import warnings
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
# (and every llm_clients consumer) sees the user's keys regardless of
# which entry point started the process. find_dotenv(usecwd=True) walks
# from the CWD, so the installed `tradingagents` console script picks up
# the project's .env instead of stepping up from site-packages.
# load_dotenv defaults to override=False, so it never clobbers values
# the caller has already exported.
try:
from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv(usecwd=True))
load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False)
except ImportError:
pass
# langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in
# its own __init__, which prepends default-action filters for its
# subclassed warning categories. To suppress a specific warning we must
# install our filter AFTER langchain-core has installed its own, so import
# it first. The package is a guaranteed transitive dep via langgraph.
try:
import langchain_core # noqa: F401
except ImportError:
pass
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
# explicit allowed_objects, which triggers a noisy pending-deprecation
# warning from langchain-core 1.3.3 on every interpreter start. The fix
# is already merged upstream (langchain-ai/langgraph#7743, 2026-05-08)
# and will arrive in the next langgraph-checkpoint release. Remove this
# block (and the langchain_core preload above) when we bump past it.
warnings.filterwarnings(
"ignore",
message=r"The default value of `allowed_objects`.*",
category=PendingDeprecationWarning,
)

View File

@@ -4,7 +4,10 @@ from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .analysts.sentiment_analyst import (
create_sentiment_analyst,
create_social_media_analyst, # deprecated alias kept for back-compat
)
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
@@ -33,6 +36,7 @@ __all__ = [
"create_aggressive_debator",
"create_portfolio_manager",
"create_conservative_debator",
"create_social_media_analyst",
"create_sentiment_analyst",
"create_social_media_analyst", # deprecated; will be removed in a future version
"create_trader",
]

View File

@@ -0,0 +1,184 @@
"""Sentiment analyst — multi-source sentiment analysis for a target ticker.
Previously named ``social_media_analyst``. Renamed and redesigned because
the old version had a prompt that demanded social-media analysis but the
only tool available was Yahoo Finance news — which led LLMs to fabricate
Reddit/X/StockTwits content under prompt pressure (verified live).
The redesigned agent pre-fetches three complementary data sources before
the LLM is invoked and injects them into the prompt as structured blocks:
1. News headlines — Yahoo Finance (institutional framing)
2. StockTwits messages — retail-trader posts indexed by cashtag, with
user-labeled Bullish/Bearish sentiment tags
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
The agent does not use tool-calling; the data is in the prompt from
turn 0. The LLM produces the sentiment report in a single invocation.
See: https://github.com/TauricResearch/TradingAgents/issues/557
"""
from datetime import datetime, timedelta
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
get_news,
)
from tradingagents.dataflows.reddit import fetch_reddit_posts
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
def _seven_days_back(trade_date: str) -> str:
return (datetime.strptime(trade_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
def create_sentiment_analyst(llm):
"""Create a sentiment analyst node for the trading graph.
Pre-fetches news + StockTwits + Reddit data, injects them into the
prompt as structured blocks, and produces a sentiment report in a
single LLM call.
"""
def sentiment_analyst_node(state):
ticker = state["company_of_interest"]
end_date = state["trade_date"]
start_date = _seven_days_back(end_date)
instrument_context = build_instrument_context(ticker)
# Pre-fetch all three sources. Each fetcher degrades gracefully and
# returns a string (no exceptions surface from here), so the LLM
# always sees something — either real data or a clear placeholder.
news_block = get_news.func(ticker, start_date, end_date)
stocktwits_block = fetch_stocktwits_messages(ticker, limit=30)
reddit_block = fetch_reddit_posts(ticker)
system_message = _build_system_message(
ticker=ticker,
start_date=start_date,
end_date=end_date,
news_block=news_block,
stocktwits_block=stocktwits_block,
reddit_block=reddit_block,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
"\n{system_message}\n"
"For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(current_date=end_date)
prompt = prompt.partial(instrument_context=instrument_context)
# No bind_tools — the data is already in the prompt; a single LLM
# call produces the report directly.
chain = prompt | llm
result = chain.invoke(state["messages"])
return {
"messages": [result],
"sentiment_report": result.content,
}
return sentiment_analyst_node
def _build_system_message(
*,
ticker: str,
start_date: str,
end_date: str,
news_block: str,
stocktwits_block: str,
reddit_block: str,
) -> str:
"""Assemble the sentiment-analyst system message with structured data blocks."""
return f"""You are a financial market sentiment analyst. Your task is to produce a comprehensive sentiment report for {ticker} covering the period from {start_date} to {end_date}, drawing on three complementary data sources that have already been collected for you.
## Data sources (pre-fetched, in this prompt)
### News headlines — Yahoo Finance, past 7 days
Institutional framing. Fact-driven, slower-moving signal.
<start_of_news>
{news_block}
<end_of_news>
### StockTwits messages — retail-trader social platform indexed by cashtag
Fast-moving signal. Each message carries a user-labeled sentiment tag (Bullish / Bearish / no-label) plus the message body.
<start_of_stocktwits>
{stocktwits_block}
<end_of_stocktwits>
### Reddit posts — r/wallstreetbets, r/stocks, r/investing (past 7 days)
Community discussion. Engagement signal via upvote score and comment count. Subreddit character matters (r/wallstreetbets is often contrarian/exuberant; r/stocks more measured; r/investing longer-term).
<start_of_reddit>
{reddit_block}
<end_of_reddit>
## How to analyze this data (best practices)
1. **Read the StockTwits Bullish/Bearish ratio as a leading retail-sentiment signal.** A 70/30 bullish/bearish split is moderately bullish; ≥90/10 may indicate over-extension and contrarian risk; 50/50 is uncertainty. Sample size matters — base rates on the actual message count, not percentages alone.
2. **Look for cross-source divergences.** If news framing is bearish but StockTwits is overwhelmingly bullish, that mismatch is itself a signal — it can mean retail is leaning into a thesis the news flow hasn't caught up to (or vice versa, that retail is chasing while institutions are cautious).
3. **Weight Reddit posts by engagement.** A 400-upvote / 200-comment thread reflects community attention; a 3-upvote post is noise. Read the body excerpts for context — the title alone often misleads.
4. **Distinguish opinion from event.** A news headline ("Nvidia announces $500M Corning deal") is an event; a StockTwits post ("buying NVDA, this is going to moon") is opinion. Both are inputs but should be weighted differently in your conclusions.
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
## Output
Produce a sentiment report covering, in order:
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
3. **Divergences, alignments, and key narratives** across sources.
4. **Catalysts and risks** surfaced by the data.
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
{get_language_instruction()}"""
# ---------------------------------------------------------------------------
# Backwards-compatibility shim
# ---------------------------------------------------------------------------
def create_social_media_analyst(llm):
"""Deprecated alias for :func:`create_sentiment_analyst`.
Kept so existing code that imports ``create_social_media_analyst``
continues to work.
.. deprecated::
Import :func:`create_sentiment_analyst` directly instead.
"""
import warnings
warnings.warn(
"create_social_media_analyst is deprecated and will be removed in a "
"future version. Use create_sentiment_analyst instead.",
DeprecationWarning,
stacklevel=2,
)
return create_sentiment_analyst(llm)

View File

@@ -1,57 +1,23 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.dataflows.config import get_config
"""Backwards-compatibility shim for the renamed module.
The agent is now ``sentiment_analyst`` and aggregates Yahoo Finance news,
StockTwits cashtag streams, and Reddit posts into a single sentiment
report. Import from ``tradingagents.agents.analysts.sentiment_analyst``
going forward; this module will be removed in a future release.
def create_social_media_analyst(llm):
def social_media_analyst_node(state):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
See: https://github.com/TauricResearch/TradingAgents/issues/557
"""
tools = [
get_news,
]
import warnings as _warnings
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
)
from tradingagents.agents.analysts.sentiment_analyst import ( # noqa: F401
create_sentiment_analyst,
create_social_media_analyst,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"sentiment_report": report,
}
return social_media_analyst_node
_warnings.warn(
"tradingagents.agents.analysts.social_media_analyst is deprecated. "
"Import from tradingagents.agents.analysts.sentiment_analyst instead.",
DeprecationWarning,
stacklevel=2,
)

View File

@@ -3,7 +3,10 @@
from __future__ import annotations
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
)
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
@@ -37,7 +40,7 @@ Commit to a clear stance whenever the debate's strongest arguments warrant one;
---
**Debate History:**
{history}"""
{history}""" + get_language_instruction()
investment_plan = invoke_structured_or_freetext(
structured_llm,

View File

@@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_bear_researcher(llm):
@@ -31,7 +32,7 @@ Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.
"""
""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_bull_researcher(llm):
@@ -29,7 +30,7 @@ Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
"""
""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_aggressive_debator(llm):
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_conservative_debator(llm):
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -1,3 +1,4 @@
from tradingagents.agents.utils.agent_utils import get_language_instruction
def create_neutral_debator(llm):
@@ -28,7 +29,7 @@ Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
response = llm.invoke(prompt)

View File

@@ -7,7 +7,10 @@ import functools
from langchain_core.messages import AIMessage
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_language_instruction,
)
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
@@ -29,6 +32,7 @@ def create_trader(llm):
"You are a trading agent analyzing market data to make investment decisions. "
"Based on your analysis, provide a specific recommendation to buy, sell, or hold. "
"Anchor your reasoning in the analysts' reports and the research plan."
+ get_language_instruction()
),
},
{

View File

@@ -51,7 +51,7 @@ class AgentState(MessagesState):
# research step
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
sentiment_report: Annotated[str, "Report from the Sentiment Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
]

View File

@@ -24,8 +24,10 @@ def get_language_instruction() -> str:
"""Return a prompt instruction for the configured output language.
Returns empty string when English (default), so no extra tokens are used.
Only applied to user-facing agents (analysts, portfolio manager).
Internal debate agents stay in English for reasoning quality.
Applied to every agent whose output reaches the saved report —
analysts, researchers, debaters, research manager, trader, and
portfolio manager — so a non-English run produces a fully localized
report rather than a mix of languages.
"""
from tradingagents.dataflows.config import get_config
lang = get_config().get("output_language", "English")

View File

@@ -1,5 +1,5 @@
from langchain_core.tools import tool
from typing import Annotated
from typing import Annotated, Optional
from tradingagents.dataflows.interface import route_to_vendor
@tool
@@ -23,16 +23,20 @@ def get_news(
@tool
def get_global_news(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "Number of days to look back"] = 7,
limit: Annotated[int, "Maximum number of articles to return"] = 5,
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
) -> str:
"""
Retrieve global news data.
Uses the configured news_data vendor.
Uses the configured news_data vendor. Defaults for look_back_days and
limit come from DEFAULT_CONFIG (global_news_lookback_days,
global_news_article_limit); pass explicit values to override.
Args:
curr_date (str): Current date in yyyy-mm-dd format
look_back_days (int): Number of days to look back (default 7)
limit (int): Maximum number of articles to return (default 5)
look_back_days (int): Number of days to look back; omit to inherit config
limit (int): Maximum number of articles to return; omit to inherit config
Returns:
str: A formatted string containing global news data
"""

View File

@@ -1,6 +1,8 @@
import tradingagents.default_config as default_config
from copy import deepcopy
from typing import Dict, Optional
import tradingagents.default_config as default_config
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
@@ -9,22 +11,31 @@ def initialize_config():
"""Initialize the configuration with default values."""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config = deepcopy(default_config.DEFAULT_CONFIG)
def set_config(config: Dict):
"""Update the configuration with custom values."""
"""Update the configuration with custom values.
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}``
keeps the other nested keys from the default; scalar keys are replaced.
"""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
initialize_config()
incoming = deepcopy(config)
for key, value in incoming.items():
if isinstance(value, dict) and isinstance(_config.get(key), dict):
_config[key].update(value)
else:
_config[key] = value
def get_config() -> Dict:
"""Get the current configuration."""
if _config is None:
initialize_config()
return _config.copy()
return deepcopy(_config)
# Initialize with default config

View File

@@ -0,0 +1,106 @@
"""Reddit search fetcher for ticker-specific discussion posts.
Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``)
which do not require an API key. Public throughput is ~10 requests per
minute per IP, well within budget for a single agent run that queries
a handful of finance subreddits per ticker.
Returns formatted plaintext blocks ready for prompt injection. Degrades
gracefully — returns a placeholder string rather than raising, so callers
never have to special-case missing data.
"""
from __future__ import annotations
import json
import logging
import time
from typing import Iterable
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode
from urllib.request import Request, urlopen
logger = logging.getLogger(__name__)
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
# Default subreddits ordered roughly by signal density for ticker-specific
# discussion. wallstreetbets has the most volume but most noise; stocks /
# investing trend more measured. Caller can override.
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
def _fetch_subreddit(
ticker: str,
sub: str,
limit: int,
timeout: float,
) -> list[dict]:
qs = urlencode({
"q": ticker,
"restrict_sr": "on",
"sort": "new",
"t": "week", # last 7 days
"limit": limit,
})
url = _API.format(sub=sub, qs=qs)
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
try:
with urlopen(req, timeout=timeout) as resp:
payload = json.loads(resp.read())
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc)
return []
children = (payload.get("data") or {}).get("children") or []
return [c.get("data", {}) for c in children if isinstance(c, dict)]
def fetch_reddit_posts(
ticker: str,
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
limit_per_sub: int = 5,
timeout: float = 10.0,
inter_request_delay: float = 0.4,
) -> str:
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
subreddits and return them as a formatted plaintext block.
``inter_request_delay`` keeps us under Reddit's public rate limit
(~10 req/min per IP) even if the caller queries many subreddits.
"""
blocks = []
total_posts = 0
for i, sub in enumerate(subreddits):
if i > 0:
time.sleep(inter_request_delay)
posts = _fetch_subreddit(ticker, sub, limit_per_sub, timeout)
total_posts += len(posts)
if not posts:
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
continue
lines = [f"r/{sub}{len(posts)} recent posts mentioning {ticker.upper()}:"]
for p in posts:
title = (p.get("title") or "").replace("\n", " ").strip()
score = p.get("score", 0)
comments = p.get("num_comments", 0)
created = p.get("created_utc")
created_str = (
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
)
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
if len(selftext) > 240:
selftext = selftext[:240] + ""
lines.append(
f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}"
+ (f"\n body excerpt: {selftext}" if selftext else "")
)
blocks.append("\n".join(lines))
if total_posts == 0:
return (
f"<no Reddit posts found mentioning {ticker.upper()} across "
f"{', '.join(f'r/{s}' for s in subreddits)} in the past 7 days>"
)
return "\n\n".join(blocks)

View File

@@ -8,6 +8,7 @@ from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
from .utils import safe_ticker_component
logger = logging.getLogger(__name__)
@@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
subsequent calls the cache is reused. Rows after curr_date are
filtered out so backtests never see future prices.
"""
# Reject ticker values that would escape the cache directory when
# interpolated into the cache filename (e.g. ``../../tmp/x``).
safe_symbol = safe_ticker_component(symbol)
config = get_config()
curr_date_dt = pd.to_datetime(curr_date)
@@ -63,7 +68,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
)
if os.path.exists(data_file):

View File

@@ -0,0 +1,83 @@
"""StockTwits public symbol-stream fetcher.
StockTwits exposes a per-symbol message stream at
``api.stocktwits.com/api/2/streams/symbol/{ticker}.json`` that requires no
API key, no OAuth, and no registration. Each message includes a
user-labeled sentiment field (``Bullish``/``Bearish``/null), the message
body, timestamp, and posting user.
The function is deliberately self-contained: short timeout, graceful
degradation on any HTTP or parse failure, and a string return type so
the calling agent gets a uniform interface regardless of whether the
network call succeeded.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
logger = logging.getLogger(__name__)
_API = "https://api.stocktwits.com/api/2/streams/symbol/{ticker}.json"
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.0) -> str:
"""Fetch recent StockTwits messages for ``ticker`` and return them as a
formatted plaintext block ready for prompt injection.
Returns a placeholder string when the endpoint is unreachable, the
symbol has no messages, or the response shape is unexpected — the
caller never has to special-case None or exceptions.
"""
url = _API.format(ticker=ticker.upper())
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
try:
with urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read())
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
return f"<stocktwits unavailable: {type(exc).__name__}>"
messages = data.get("messages", []) if isinstance(data, dict) else []
if not messages:
return f"<no StockTwits messages found for ${ticker.upper()}>"
lines = []
bullish = bearish = unlabeled = 0
for m in messages[:limit]:
created = m.get("created_at", "")
user = (m.get("user") or {}).get("username", "?")
entities = m.get("entities") or {}
sentiment_obj = entities.get("sentiment") or {}
sentiment = sentiment_obj.get("basic") if isinstance(sentiment_obj, dict) else None
body = (m.get("body") or "").replace("\n", " ").strip()
if len(body) > 280:
body = body[:280] + ""
if sentiment == "Bullish":
bullish += 1
tag = "Bullish"
elif sentiment == "Bearish":
bearish += 1
tag = "Bearish"
else:
unlabeled += 1
tag = "no-label"
lines.append(f"[{created} · @{user} · {tag}] {body}")
total = bullish + bearish + unlabeled
bull_pct = round(100 * bullish / total) if total else 0
bear_pct = round(100 * bearish / total) if total else 0
summary = (
f"Bullish: {bullish} ({bull_pct}%) · "
f"Bearish: {bearish} ({bear_pct}%) · "
f"Unlabeled: {unlabeled} · "
f"Total: {total} most-recent messages"
)
return summary + "\n\n" + "\n".join(lines)

View File

@@ -1,4 +1,5 @@
import os
import re
import json
import pandas as pd
from datetime import date, timedelta, datetime
@@ -6,6 +7,40 @@ from typing import Annotated
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
# Tickers can contain letters, digits, dot, dash, underscore, and caret
# (for index symbols like ^GSPC). Anything else is rejected so the value
# never escapes a containing directory when interpolated into a path.
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
"""Validate ``value`` is safe to interpolate into a filesystem path.
Tickers come from user CLI input or from LLM tool calls, both of which
can be influenced by attacker-controlled content (e.g. prompt injection
embedded in fetched news). Without validation, a value like
``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
escapes the configured cache, checkpoint, or results directory.
Returns ``value`` unchanged when it matches the allowed pattern; raises
``ValueError`` otherwise.
"""
if not isinstance(value, str) or not value:
raise ValueError(f"ticker must be a non-empty string, got {value!r}")
if len(value) > max_len:
raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")
if not _TICKER_PATH_RE.fullmatch(value):
raise ValueError(
f"ticker contains characters not allowed in a filesystem path: {value!r}"
)
# The regex above allows '.', so values like '.', '..', '...' would pass,
# and as a path component they traverse the parent directory. Reject any
# value that's only dots.
if set(value) == {"."}:
raise ValueError(f"ticker cannot consist solely of dots: {value!r}")
return value
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path:
data.to_csv(save_path, encoding="utf-8")

View File

@@ -1,9 +1,12 @@
"""yfinance-based news data fetching functions."""
from typing import Optional
import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .config import get_config
from .stockstats_utils import yf_retry
@@ -64,9 +67,10 @@ def get_news_yfinance(
Returns:
Formatted string containing news articles
"""
article_limit = get_config()["news_article_limit"]
try:
stock = yf.Ticker(ticker)
news = yf_retry(lambda: stock.get_news(count=20))
news = yf_retry(lambda: stock.get_news(count=article_limit))
if not news:
return f"No news found for {ticker}"
@@ -106,27 +110,28 @@ def get_news_yfinance(
def get_global_news_yfinance(
curr_date: str,
look_back_days: int = 7,
limit: int = 10,
look_back_days: Optional[int] = None,
limit: Optional[int] = None,
) -> str:
"""
Retrieve global/macro economic news using yfinance Search.
Args:
curr_date: Current date in yyyy-mm-dd format
look_back_days: Number of days to look back
limit: Maximum number of articles to return
look_back_days: Number of days to look back. ``None`` falls back to
``global_news_lookback_days`` from the active config.
limit: Maximum number of articles to return. ``None`` falls back to
``global_news_article_limit`` from the active config.
Returns:
Formatted string containing global news articles
"""
# Search queries for macro/global news
search_queries = [
"stock market economy",
"Federal Reserve interest rates",
"inflation economic outlook",
"global markets trading",
]
config = get_config()
if look_back_days is None:
look_back_days = config["global_news_lookback_days"]
if limit is None:
limit = config["global_news_article_limit"]
search_queries = config["global_news_queries"]
all_news = []
seen_titles = set()

View File

@@ -2,7 +2,46 @@ import os
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
DEFAULT_CONFIG = {
# Single source of truth for env-var → config-key overrides. To expose
# a new config key for environment-based override, add a row here — no
# entry-point script changes required. Coercion is driven by the type
# of the existing default, so users can keep writing plain strings in
# their .env file.
_ENV_OVERRIDES = {
"TRADINGAGENTS_LLM_PROVIDER": "llm_provider",
"TRADINGAGENTS_DEEP_THINK_LLM": "deep_think_llm",
"TRADINGAGENTS_QUICK_THINK_LLM": "quick_think_llm",
"TRADINGAGENTS_LLM_BACKEND_URL": "backend_url",
"TRADINGAGENTS_OUTPUT_LANGUAGE": "output_language",
"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds",
"TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
"TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
"TRADINGAGENTS_BENCHMARK_TICKER": "benchmark_ticker",
}
def _coerce(value: str, reference):
"""Coerce env-var string to the type of the existing default value."""
if isinstance(reference, bool):
return value.strip().lower() in ("true", "1", "yes", "on")
if isinstance(reference, int) and not isinstance(reference, bool):
return int(value)
if isinstance(reference, float):
return float(value)
return value
def _apply_env_overrides(config: dict) -> dict:
"""Apply TRADINGAGENTS_* env vars to the config dict in-place."""
for env_var, key in _ENV_OVERRIDES.items():
raw = os.environ.get(env_var)
if raw is None or raw == "":
continue
config[key] = _coerce(raw, config.get(key))
return config
DEFAULT_CONFIG = _apply_env_overrides({
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
@@ -35,6 +74,21 @@ DEFAULT_CONFIG = {
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
# News / data fetching parameters
# Increase for longer lookback strategies or to broaden macro coverage;
# decrease to reduce token usage in agent prompts.
"news_article_limit": 20, # max articles per ticker (ticker-news)
"global_news_article_limit": 10, # max articles for global/macro news
"global_news_lookback_days": 7, # macro news lookback window
# Search queries used by get_global_news for macro headlines. Extend or
# replace to broaden geographic / sector coverage.
"global_news_queries": [
"Federal Reserve interest rates inflation",
"S&P 500 earnings GDP economic outlook",
"geopolitical risk trade war sanctions",
"ECB Bank of England BOJ central bank policy",
"oil commodities supply chain energy",
],
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
@@ -47,4 +101,21 @@ DEFAULT_CONFIG = {
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
},
}
# Benchmark for alpha calculation in the reflection layer.
# ``benchmark_ticker`` (when set) overrides the suffix map for all
# tickers; leave it None to use ``benchmark_map`` for auto-detection
# based on the ticker's exchange suffix. SPY remains the US default
# so the reflection label keeps reading "Alpha vs SPY" for US tickers
# while non-US tickers get their regional index automatically.
"benchmark_ticker": None,
"benchmark_map": {
".NS": "^NSEI", # NSE India (Nifty 50)
".BO": "^BSESN", # BSE India (Sensex)
".T": "^N225", # Tokyo (Nikkei 225)
".HK": "^HSI", # Hong Kong (Hang Seng)
".L": "^FTSE", # London (FTSE 100)
".TO": "^GSPTSE", # Toronto (TSX Composite)
".AX": "^AXJO", # Australia (ASX 200)
"": "SPY", # default for US-listed tickers (no suffix)
},
})

View File

@@ -13,12 +13,16 @@ from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver
from tradingagents.dataflows.utils import safe_ticker_component
def _db_path(data_dir: str | Path, ticker: str) -> Path:
"""Return the SQLite checkpoint DB path for a ticker."""
# Reject ticker values that would escape the checkpoints directory.
safe = safe_ticker_component(ticker).upper()
p = Path(data_dir) / "checkpoints"
p.mkdir(parents=True, exist_ok=True)
return p / f"{ticker.upper()}.db"
return p / f"{safe}.db"
def thread_id(ticker: str, date: str) -> str:

View File

@@ -33,11 +33,15 @@ class Reflector:
final_decision: str,
raw_return: float,
alpha_return: float,
benchmark_name: str = "SPY",
) -> str:
"""Single reflection call on the final trade decision with outcome context.
Used by Phase B deferred reflection. The final_trade_decision already
synthesises all analyst insights, so no separate market context is needed.
``benchmark_name`` is the label used for the alpha line (e.g. ``"SPY"``
for US tickers, ``"^N225"`` for ``.T`` listings); defaults to SPY for
callers that haven't been updated to thread the benchmark through.
"""
messages = [
("system", self.log_reflection_prompt),
@@ -45,7 +49,7 @@ class Reflector:
"human",
(
f"Raw return: {raw_return:+.1%}\n"
f"Alpha vs SPY: {alpha_return:+.1%}\n\n"
f"Alpha vs {benchmark_name}: {alpha_return:+.1%}\n\n"
f"Final Decision:\n{final_decision}"
),
),

View File

@@ -54,7 +54,11 @@ class GraphSetup:
tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
# "social" selector key preserved for back-compat with existing
# user configs; the underlying agent has been renamed to
# sentiment_analyst (the old name advertised social-media data
# the agent never had access to — see issue #557).
analyst_nodes["social"] = create_sentiment_analyst(
self.quick_thinking_llm
)
delete_nodes["social"] = create_msg_delete()

View File

@@ -18,6 +18,7 @@ from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.dataflows.utils import safe_ticker_component
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@@ -115,7 +116,9 @@ class TradingAgentsGraph:
self.conditional_logic,
)
self.propagator = Propagator()
self.propagator = Propagator(
max_recur_limit=self.config.get("max_recur_limit", 100),
)
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
@@ -187,14 +190,37 @@ class TradingAgentsGraph:
),
}
def _resolve_benchmark(self, ticker: str) -> str:
"""Pick the benchmark ticker for alpha calculation against ``ticker``.
``config["benchmark_ticker"]`` overrides everything when set; otherwise
the suffix map matches the ticker's exchange suffix (e.g. ``.T`` for
Tokyo). US-listed tickers without a dotted suffix fall through to the
empty-suffix entry (SPY by default). Unrecognised suffixes (including
US tickers with dots like ``BRK.B``) also fall back to the empty-suffix
entry, which is the right default because the alpha calculation works
in USD.
"""
explicit = self.config.get("benchmark_ticker")
if explicit:
return explicit
benchmark_map = self.config.get("benchmark_map", {})
ticker_upper = ticker.upper()
for suffix, benchmark in benchmark_map.items():
if suffix and ticker_upper.endswith(suffix.upper()):
return benchmark
return benchmark_map.get("", "SPY")
def _fetch_returns(
self, ticker: str, trade_date: str, holding_days: int = 5
self, ticker: str, trade_date: str, holding_days: int = 5,
benchmark: str = "SPY",
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
Returns (raw_return, alpha_return, actual_holding_days) or
(None, None, None) if price data is unavailable (too recent, delisted,
or network error).
``benchmark`` is the index used as the alpha baseline (resolved by the
caller via ``_resolve_benchmark``). Returns ``(raw_return, alpha_return,
actual_holding_days)`` or ``(None, None, None)`` if price data is
unavailable (too recent, delisted, or network error).
"""
try:
start = datetime.strptime(trade_date, "%Y-%m-%d")
@@ -202,26 +228,26 @@ class TradingAgentsGraph:
end_str = end.strftime("%Y-%m-%d")
stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
spy = yf.Ticker("SPY").history(start=trade_date, end=end_str)
bench = yf.Ticker(benchmark).history(start=trade_date, end=end_str)
if len(stock) < 2 or len(spy) < 2:
if len(stock) < 2 or len(bench) < 2:
return None, None, None
actual_days = min(holding_days, len(stock) - 1, len(spy) - 1)
actual_days = min(holding_days, len(stock) - 1, len(bench) - 1)
raw = float(
(stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
/ stock["Close"].iloc[0]
)
spy_ret = float(
(spy["Close"].iloc[actual_days] - spy["Close"].iloc[0])
/ spy["Close"].iloc[0]
bench_ret = float(
(bench["Close"].iloc[actual_days] - bench["Close"].iloc[0])
/ bench["Close"].iloc[0]
)
alpha = raw - spy_ret
alpha = raw - bench_ret
return raw, alpha, actual_days
except Exception as e:
logger.warning(
"Could not resolve outcome for %s on %s (will retry next run): %s",
ticker, trade_date, e,
"Could not resolve outcome for %s on %s vs %s (will retry next run): %s",
ticker, trade_date, benchmark, e,
)
return None, None, None
@@ -239,15 +265,19 @@ class TradingAgentsGraph:
if not pending:
return
benchmark = self._resolve_benchmark(ticker)
updates = []
for entry in pending:
raw, alpha, days = self._fetch_returns(ticker, entry["date"])
raw, alpha, days = self._fetch_returns(
ticker, entry["date"], benchmark=benchmark,
)
if raw is None:
continue # price not available yet — try again next run
reflection = self.reflector.reflect_on_final_decision(
final_decision=entry.get("decision", ""),
raw_return=raw,
alpha_return=alpha,
benchmark_name=benchmark,
)
updates.append({
"ticker": ticker,
@@ -321,7 +351,11 @@ class TradingAgentsGraph:
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
# Streamed chunks are per-node deltas. Merge them so the returned
# state matches what graph.invoke() yields in the non-debug path.
final_state = {}
for chunk in trace:
final_state.update(chunk)
else:
final_state = self.graph.invoke(init_agent_state, **args)
@@ -378,8 +412,10 @@ class TradingAgentsGraph:
"final_trade_decision": final_state["final_trade_decision"],
}
# Save to file
directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
# Save to file. Reject ticker values that would escape the
# results directory when joined as a path component.
safe_ticker = safe_ticker_component(self.ticker)
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True)
log_path = directory / f"full_states_log_{trade_date}.json"

View File

@@ -0,0 +1,44 @@
"""Canonical provider -> API-key env-var mapping.
A single source of truth for which environment variable holds the API
key for each supported LLM provider. Used by the CLI's interactive key
prompt (cli/utils.ensure_api_key) and by anything else that needs to
ask "does this provider require a key, and which env var is it?".
When adding a new provider, register its env var here so the CLI flow
prompts for it automatically instead of failing on first API call.
"""
from __future__ import annotations
from typing import Optional
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"google": "GOOGLE_API_KEY",
"azure": "AZURE_OPENAI_API_KEY",
"xai": "XAI_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
# Dual-region providers each carry their own account; keys are not
# interchangeable between the international and China endpoints.
"qwen": "DASHSCOPE_API_KEY",
"qwen-cn": "DASHSCOPE_CN_API_KEY",
"glm": "ZHIPU_API_KEY",
"glm-cn": "ZHIPU_CN_API_KEY",
"minimax": "MINIMAX_API_KEY",
"minimax-cn": "MINIMAX_CN_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
# Local runtimes do not authenticate.
"ollama": None,
}
def get_api_key_env(provider: str) -> Optional[str]:
"""Return the env var name for `provider`'s API key, or None if not applicable.
Unknown providers also return None — callers should treat that as
"no key check possible" rather than as "no key required".
"""
return PROVIDER_API_KEY_ENV.get(provider.lower())

View File

@@ -0,0 +1,120 @@
"""Declarative per-model capability table for OpenAI-compatible providers.
This is the single place that knows which model IDs reject which API
parameters or require which structured-output method. The LLM client
subclasses consult ``get_capabilities(model_name)`` instead of hardcoding
model-name ``if`` ladders, so adding a new model (or a new provider quirk)
means editing this table — not the client code.
Pattern adapted from the per-model ``compat:`` flags DeepSeek themselves
publish in their integration guides (e.g. the Oh My Pi config schema
documents ``supportsToolChoice``, ``requiresReasoningContentForToolCalls``
as declarative per-model fields).
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Literal
StructuredMethod = Literal[
"function_calling", # uses tools; respects supports_tool_choice
"json_mode", # uses response_format={"type":"json_object"}
"json_schema", # uses response_format={"type":"json_schema",...}
"none", # no structured output available; caller falls back to free-text
]
@dataclass(frozen=True)
class ModelCapabilities:
"""What an OpenAI-compatible model accepts at the API level."""
supports_tool_choice: bool
supports_json_mode: bool
supports_json_schema: bool
preferred_structured_method: StructuredMethod
# DeepSeek thinking-mode models 400 if reasoning_content from prior
# assistant turns is not echoed back on the next request.
requires_reasoning_content_roundtrip: bool = False
# DeepSeek's thinking models accept the ``tools`` array but reject the
# ``tool_choice`` parameter (official Oh My Pi integration guide and the
# 400 response in issue #678). Their official tool-calling examples
# (api-docs.deepseek.com/guides/tool_calls) pass ``tools=[...]`` without
# ``tool_choice`` — we mirror that pattern by setting supports_tool_choice
# to False and letting the client suppress the kwarg.
_DEEPSEEK_THINKING = ModelCapabilities(
supports_tool_choice=False,
supports_json_mode=True,
supports_json_schema=False,
preferred_structured_method="function_calling",
requires_reasoning_content_roundtrip=True,
)
_DEEPSEEK_CHAT = ModelCapabilities(
supports_tool_choice=True,
supports_json_mode=True,
supports_json_schema=False,
preferred_structured_method="function_calling",
)
# MiniMax M2.x reasoning models accept the tools array, but their
# tool_choice parameter is restricted to the enum {"none", "auto"}
# (platform.minimax.io/docs/api-reference/text-post). Langchain's
# function_calling path sends tool_choice as a function-spec dict, which
# MiniMax 400s — same shape as the DeepSeek bug. supports_tool_choice=False
# makes the dispatch in NormalizedChatOpenAI suppress the kwarg; the schema
# still ships as a tool. json_mode response_format is only for
# MiniMax-Text-01, not M2.x.
_MINIMAX_THINKING = ModelCapabilities(
supports_tool_choice=False,
supports_json_mode=False,
supports_json_schema=False,
preferred_structured_method="function_calling",
)
_DEFAULT = ModelCapabilities(
supports_tool_choice=True,
supports_json_mode=True,
supports_json_schema=True,
preferred_structured_method="function_calling",
)
# Exact-ID matches take precedence over pattern matches.
_BY_ID: dict[str, ModelCapabilities] = {
"deepseek-chat": _DEEPSEEK_CHAT,
"deepseek-reasoner": _DEEPSEEK_THINKING,
"deepseek-v4-flash": _DEEPSEEK_THINKING,
"deepseek-v4-pro": _DEEPSEEK_THINKING,
# MiniMax — full official model lineup per
# platform.minimax.io/docs/api-reference/text-openai-api
"MiniMax-M2.7": _MINIMAX_THINKING,
"MiniMax-M2.7-highspeed": _MINIMAX_THINKING,
"MiniMax-M2.5": _MINIMAX_THINKING,
"MiniMax-M2.5-highspeed": _MINIMAX_THINKING,
"MiniMax-M2.1": _MINIMAX_THINKING,
"MiniMax-M2.1-highspeed": _MINIMAX_THINKING,
"MiniMax-M2": _MINIMAX_THINKING,
}
# Forward-compat patterns. New ``deepseek-v5-*`` / ``deepseek-reasoner-*``
# or ``MiniMax-M3*`` variants inherit the thinking-mode quirks automatically.
_BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [
(re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING),
(re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING),
(re.compile(r"^MiniMax-M\d"), _MINIMAX_THINKING),
]
def get_capabilities(model_name: str) -> ModelCapabilities:
"""Resolve capabilities by exact ID, then pattern, then default."""
if model_name in _BY_ID:
return _BY_ID[model_name]
for pattern, caps in _BY_PATTERN:
if pattern.match(model_name):
return caps
return _DEFAULT

View File

@@ -4,7 +4,11 @@ from .base_client import BaseLLMClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
"openai", "xai", "deepseek",
"qwen", "qwen-cn",
"glm", "glm-cn",
"minimax", "minimax-cn",
"ollama", "openrouter",
)

View File

@@ -8,108 +8,171 @@ ModelOption = Tuple[str, str]
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
# Shared model list for GLM via Z.AI (international) and BigModel (China).
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
_GLM_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
("GLM-4.5-Air - Lightweight, cost-efficient", "glm-4.5-air"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1 - Latest flagship, 204K ctx", "glm-5.1"),
("GLM-5 - Flagship, 204K ctx", "glm-5"),
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
("Custom model ID", "custom"),
],
}
# Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints.
# Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized).
#
# Only versioned IDs are exposed in the dropdown. The version-less aliases
# (qwen-plus, qwen-flash) are documented by Alibaba as auto-upgrading
# pointers ("backbone, latest, and snapshot ... have been upgraded to the
# Qwen3 series"), which means their behavior shifts when Alibaba rotates
# the backing model. Users who want a specific generation pick it
# explicitly; users who really want auto-latest can enter the alias via
# "Custom model ID".
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus - Flagship vision-language, agentic coding SOTA", "qwen3.6-plus"),
("Qwen 3.5 Plus - Previous-gen flagship", "qwen3.5-plus"),
("Qwen 3 Max - Specialized for agent programming + tool use", "qwen3-max"),
("Custom model ID", "custom"),
],
}
# Shared model list for MiniMax's global and CN endpoints (same IDs).
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
# All M2.x models share a 204,800-token context window.
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
"quick": [
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
("MiniMax-M2.1-highspeed - M2.1 highspeed, 204K ctx", "MiniMax-M2.1-highspeed"),
("Custom model ID", "custom"),
],
"deep": [
("MiniMax-M2.7 - Flagship, SOTA on coding/agent benchmarks, 204K ctx", "MiniMax-M2.7"),
("MiniMax-M2.7-highspeed - Same quality as M2.7, ~100 TPS", "MiniMax-M2.7-highspeed"),
("MiniMax-M2.5 - Previous-gen flagship, 204K ctx", "MiniMax-M2.5"),
("MiniMax-M2.1 - Earlier M2 line, 204K ctx", "MiniMax-M2.1"),
("MiniMax-M2 - Base M2, 204K ctx", "MiniMax-M2"),
("Custom model ID", "custom"),
],
}
MODEL_OPTIONS: ProviderModeOptions = {
"openai": {
"quick": [
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
],
"deep": [
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
("GPT-5.4 - Previous-gen frontier, 1M context, cost-effective", "gpt-5.4"),
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
("GPT-5.5 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.5-pro"),
],
},
"anthropic": {
"quick": [
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
("Claude Haiku 4.5 - Fastest with near-frontier intelligence", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - High-performance for agents and coding", "claude-sonnet-4-5"),
],
"deep": [
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
("Claude Opus 4.7 - Latest frontier, long-running agents and coding", "claude-opus-4-7"),
("Claude Opus 4.6 - Frontier intelligence, agents and coding", "claude-opus-4-6"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
},
"google": {
"quick": [
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
("Gemini 3.1 Flash Lite - Most cost-efficient (GA)", "gemini-3.1-flash-lite"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"deep": [
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 3.1 Pro - Reasoning-first, complex workflows (preview)", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
],
},
"xai": {
"quick": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4.20 (Non-Reasoning) - Latest, speed-optimized", "grok-4.20-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
],
"deep": [
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4.20 (Reasoning) - Latest frontier reasoning model", "grok-4.20-reasoning"),
("Grok 4 - Flagship (dated build)", "grok-4-0709"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4.20 - Auto-select reasoning behavior", "grok-4.20"),
],
},
"deepseek": {
"quick": [
("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
"deep": [
("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
},
"qwen": {
"quick": [
("Qwen 3.5 Flash", "qwen3.5-flash"),
("Qwen Plus", "qwen-plus"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus", "qwen3.6-plus"),
("Qwen 3.5 Plus", "qwen3.5-plus"),
("Qwen 3 Max", "qwen3-max"),
("Custom model ID", "custom"),
],
},
"glm": {
"quick": [
("GLM-4.7", "glm-4.7"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1", "glm-5.1"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
},
# Qwen: same model IDs across global (dashscope-intl) and China
# (dashscope) endpoints, so the two provider keys share one model list.
"qwen": _QWEN_MODELS,
"qwen-cn": _QWEN_MODELS,
# GLM: Z.AI (international) and BigModel (China) host the same model
# IDs; the two provider keys share one model list.
"glm": _GLM_MODELS,
"glm-cn": _GLM_MODELS,
# MiniMax: same model IDs across global (.io) and China (.com) regions,
# so the two provider keys share one model list.
"minimax": _MINIMAX_MODELS,
"minimax-cn": _MINIMAX_MODELS,
# OpenRouter: fetched dynamically. Azure: any deployed model name.
# Ollama display labels intentionally omit a "local" marker — the
# endpoint is now configurable via OLLAMA_BASE_URL, so the same labels
# apply whether the user runs ollama-serve on localhost or against a
# remote host. The actual resolved endpoint is surfaced separately by
# cli.utils.confirm_ollama_endpoint() right after provider selection.
# "Custom model ID" lets users pick any model they have pulled via
# `ollama pull` beyond the three suggested defaults.
"ollama": {
"quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("Qwen3:latest (8B)", "qwen3:latest"),
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
("Qwen3:latest (8B)", "qwen3:latest"),
("Custom model ID", "custom"),
],
},
}

View File

@@ -1,56 +1,176 @@
import os
from typing import Any, Optional
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from .api_key_env import get_api_key_env
from .base_client import BaseLLMClient, normalize_content
from .capabilities import get_capabilities
from .validators import validate_model
class NormalizedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with normalized content output.
"""ChatOpenAI with normalized content output and capability-aware binding.
The Responses API returns content as a list of typed blocks
(reasoning, text, etc.). This normalizes to string for consistent
downstream handling.
(reasoning, text, etc.). ``invoke`` normalizes to string for
consistent downstream handling.
``with_structured_output`` consults the per-model capability table
(``capabilities.get_capabilities``) to pick the method and to decide
whether ``tool_choice`` may be sent. Models that reject ``tool_choice``
(e.g. DeepSeek V4 and reasoner — per their official tool-calling
guide) still bind the schema as a tool, but no ``tool_choice``
parameter is sent.
Provider-specific quirks beyond structured-output (e.g. DeepSeek's
reasoning_content roundtrip) live in subclasses so this base class
stays small.
"""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
def with_structured_output(self, schema, *, method=None, **kwargs):
"""Wrap with structured output, defaulting to function_calling for OpenAI.
langchain-openai's Responses-API-parse path (the default for json_schema
when use_responses_api=True) calls response.model_dump(...) on the OpenAI
SDK's union-typed parsed response, which makes Pydantic emit ~20
PydanticSerializationUnexpectedValue warnings per call. The function-calling
path returns a plain tool-call shape that does not trigger that
serialization, so it is the cleaner choice for our combination of
use_responses_api=True + with_structured_output. Both paths use OpenAI's
strict mode and produce the same typed Pydantic instance.
"""
if method is None:
method = "function_calling"
caps = get_capabilities(self.model_name)
if caps.preferred_structured_method == "none":
raise NotImplementedError(
f"{self.model_name} has no structured-output method available; "
f"agent factories will fall back to free-text generation."
)
method = method or caps.preferred_structured_method
# When the model rejects tool_choice, suppress langchain's hardcoded
# value. The schema is still bound as a tool — exactly what
# DeepSeek's official tool-calling examples do.
if method == "function_calling" and not caps.supports_tool_choice:
kwargs.setdefault("tool_choice", None)
return super().with_structured_output(schema, method=method, **kwargs)
def _input_to_messages(input_: Any) -> list:
"""Normalise a langchain LLM input to a list of message objects.
Accepts a list of messages, a ``ChatPromptValue`` (from a
ChatPromptTemplate), or anything else (treated as no messages).
Used by providers that need to walk the outgoing message history;
in particular DeepSeek thinking-mode propagation must work for
both bare-list invocations and ChatPromptTemplate-driven ones, so
treating only ``list`` here would silently skip half the call sites.
"""
if isinstance(input_, list):
return input_
if hasattr(input_, "to_messages"):
return input_.to_messages()
return []
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
Thinking-mode round-trip is the only DeepSeek-specific behavior that
stays here. When DeepSeek's thinking models return a response with
``reasoning_content``, that field must be echoed back as part of the
assistant message on the next turn or the API fails with HTTP 400.
``_create_chat_result`` captures it on receive and
``_get_request_payload`` re-attaches it on send.
Tool-choice handling for V4 and reasoner — those models reject the
``tool_choice`` parameter — is handled by the capability dispatch in
``NormalizedChatOpenAI.with_structured_output``, not here.
"""
def _get_request_payload(self, input_, *, stop=None, **kwargs):
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
outgoing = payload.get("messages", [])
for message_dict, message in zip(outgoing, _input_to_messages(input_)):
if not isinstance(message, AIMessage):
continue
reasoning = message.additional_kwargs.get("reasoning_content")
if reasoning is not None:
message_dict["reasoning_content"] = reasoning
return payload
def _create_chat_result(self, response, generation_info=None):
chat_result = super()._create_chat_result(response, generation_info)
response_dict = (
response
if isinstance(response, dict)
else response.model_dump(
exclude={"choices": {"__all__": {"message": {"parsed"}}}}
)
)
for generation, choice in zip(
chat_result.generations, response_dict.get("choices", [])
):
reasoning = choice.get("message", {}).get("reasoning_content")
if reasoning is not None:
generation.message.additional_kwargs["reasoning_content"] = reasoning
return chat_result
class MinimaxChatOpenAI(NormalizedChatOpenAI):
"""MiniMax-specific overrides on top of the OpenAI-compatible client.
M2.x reasoning models embed ``<think>...</think>`` blocks directly in
``message.content`` by default, which would pollute saved reports.
Per platform.minimax.io/docs/api-reference/text-openai-api, setting
``reasoning_split=True`` in the request body redirects the thinking
block into ``reasoning_details`` so ``content`` stays clean.
Tool-choice handling for M2.x — those models accept only the string
enum ``{"none", "auto"}`` and reject langchain's function-spec dict —
is handled by the capability dispatch in
``NormalizedChatOpenAI.with_structured_output``, not here.
"""
def _get_request_payload(self, input_, *, stop=None, **kwargs):
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload.setdefault("reasoning_split", True)
return payload
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "reasoning_effort",
"api_key", "callbacks", "http_client", "http_async_client",
)
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
# (one canonical mapping consulted by both this client and the CLI's
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
# separate endpoints because international and China accounts cannot share
# credentials (#758).
_PROVIDER_BASE_URL = {
"xai": "https://api.x.ai/v1",
"deepseek": "https://api.deepseek.com",
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"glm": "https://api.z.ai/api/paas/v4/",
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
"minimax": "https://api.minimax.io/v1",
"minimax-cn": "https://api.minimaxi.com/v1",
"openrouter": "https://openrouter.ai/api/v1",
"ollama": "http://localhost:11434/v1",
}
def _resolve_provider_base_url(provider: str) -> Optional[str]:
"""Default base URL for ``provider``, with env-var overrides where defined.
Currently only Ollama supports an env-var override (``OLLAMA_BASE_URL``),
matching the convention in the broader Ollama tooling ecosystem so users
can point at a remote ollama-serve without editing code. The check is
call-time, not import-time, so tests that monkeypatch the env after
import behave correctly.
"""
if provider == "ollama":
env_url = os.environ.get("OLLAMA_BASE_URL")
if env_url:
return env_url
return _PROVIDER_BASE_URL.get(provider)
class OpenAIClient(BaseLLMClient):
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
@@ -75,14 +195,22 @@ class OpenAIClient(BaseLLMClient):
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model}
# Provider-specific base URL and auth
if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url
# Provider-specific base URL and auth. An explicit base_url on the
# client (e.g. a corporate proxy) takes precedence over the
# provider default so users can route through their own gateway.
if self.provider in _PROVIDER_BASE_URL:
llm_kwargs["base_url"] = self.base_url or _resolve_provider_base_url(self.provider)
api_key_env = get_api_key_env(self.provider)
if api_key_env:
api_key = os.environ.get(api_key_env)
if api_key:
llm_kwargs["api_key"] = api_key
else:
raise ValueError(
f"API key for provider '{self.provider}' is not set. "
f"Please set the {api_key_env} environment variable "
f"(e.g. add {api_key_env}=your_key to your .env file)."
)
else:
llm_kwargs["api_key"] = "ollama"
elif self.base_url:
@@ -98,7 +226,15 @@ class OpenAIClient(BaseLLMClient):
if self.provider == "openai":
llm_kwargs["use_responses_api"] = True
return NormalizedChatOpenAI(**llm_kwargs)
# Provider-specific quirks live in their own subclasses so the
# base NormalizedChatOpenAI stays free of provider branches.
if self.provider == "deepseek":
chat_cls = DeepSeekChatOpenAI
elif self.provider in ("minimax", "minimax-cn"):
chat_cls = MinimaxChatOpenAI
else:
chat_cls = NormalizedChatOpenAI
return chat_cls(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for the provider."""