mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-02 06:53:16 +03:00
Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e9e7b83c7 | ||
|
|
2c97bad45c | ||
|
|
7c37249f80 | ||
|
|
4016fd4efa | ||
|
|
bba147798f | ||
|
|
0fda24515f | ||
|
|
4cbd4b086f | ||
|
|
ebd2e12e67 | ||
|
|
f85f5d9f5d | ||
|
|
8e7654f0df | ||
|
|
872b063e69 | ||
|
|
6abc768c1d | ||
|
|
8536ccacdd | ||
|
|
fa4d01c23a | ||
|
|
b0f6058299 | ||
|
|
59d6b2152d | ||
|
|
10c136f49c | ||
|
|
4f965bf46a | ||
|
|
bdb9c29d44 | ||
|
|
bdc5fc62d3 | ||
|
|
78fb66aed1 | ||
|
|
7269f877c1 | ||
|
|
28d5cc661f | ||
|
|
7004dfe554 | ||
|
|
4641c03340 | ||
|
|
e75d17bc51 | ||
|
|
6cddd26d6e | ||
|
|
c61242a28c | ||
|
|
58e99421bd | ||
|
|
46e1b600b8 | ||
|
|
ae8c8aebe8 | ||
|
|
f3f58bdbdc | ||
|
|
e1113880a1 | ||
|
|
bd6a5b75b5 | ||
|
|
8793336dad | ||
|
|
047b38971c | ||
|
|
f5026009f9 |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
.git
|
||||||
|
.venv
|
||||||
|
.env
|
||||||
|
.claude
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
.DS_Store
|
||||||
|
__pycache__
|
||||||
|
*.egg-info
|
||||||
|
build
|
||||||
|
dist
|
||||||
|
results
|
||||||
|
eval_results
|
||||||
|
Dockerfile
|
||||||
|
docker-compose.yml
|
||||||
5
.env.enterprise.example
Normal file
5
.env.enterprise.example
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Azure OpenAI
|
||||||
|
AZURE_OPENAI_API_KEY=
|
||||||
|
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
|
||||||
|
AZURE_OPENAI_DEPLOYMENT_NAME=
|
||||||
|
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API
|
||||||
@@ -3,4 +3,7 @@ OPENAI_API_KEY=
|
|||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
XAI_API_KEY=
|
XAI_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
|
DASHSCOPE_API_KEY=
|
||||||
|
ZHIPU_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
|||||||
266
CHANGELOG.md
Normal file
266
CHANGELOG.md
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to TradingAgents are documented here.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
|
and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
Breaking changes within the 0.x line are called out explicitly.
|
||||||
|
|
||||||
|
## [0.2.4] — 2026-04-25
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Structured-output decision agents.** Research Manager, Trader, and Portfolio
|
||||||
|
Manager now use `llm.with_structured_output(Schema)` on their primary call
|
||||||
|
and return typed Pydantic instances. Each provider's native structured-output
|
||||||
|
mode is used (`json_schema` for OpenAI / xAI, `response_schema` for Gemini,
|
||||||
|
tool-use for Anthropic, function-calling for OpenAI-compatible providers).
|
||||||
|
Render helpers preserve the existing markdown shape so memory log, CLI
|
||||||
|
display, and saved reports keep working unchanged. (#434)
|
||||||
|
- **LangGraph checkpoint resume** — opt-in via `--checkpoint`. State is saved
|
||||||
|
after each node so crashed or interrupted runs resume from the last
|
||||||
|
successful step. Per-ticker SQLite databases under
|
||||||
|
`~/.tradingagents/cache/checkpoints/`. `--clear-checkpoints` resets them. (#594)
|
||||||
|
- **Persistent decision log** replacing the per-agent BM25 memory. Decisions
|
||||||
|
are stored automatically at the end of `propagate()`; the next same-ticker
|
||||||
|
run resolves prior pending entries with realised return, alpha vs SPY, and
|
||||||
|
a one-paragraph reflection. Override path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
|
||||||
|
Optional `memory_log_max_entries` config caps resolved entries; pending
|
||||||
|
entries are never pruned. (#578, #563, #564, #579)
|
||||||
|
- **DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), and Azure OpenAI**
|
||||||
|
providers, plus dynamic OpenRouter model selection.
|
||||||
|
- **Docker support** — multi-stage build with separate dev and runtime images.
|
||||||
|
- **`scripts/smoke_structured_output.py`** — diagnostic that exercises the
|
||||||
|
three structured-output agents against any provider so contributors can
|
||||||
|
verify their setup with one command.
|
||||||
|
- **5-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell) used
|
||||||
|
consistently by Research Manager, Portfolio Manager, signal processor, and
|
||||||
|
the memory log; Trader keeps 3-tier (Buy / Hold / Sell) since transaction
|
||||||
|
direction is naturally ternary.
|
||||||
|
- **Pytest fixtures** — lazy LLM client imports plus placeholder API keys so
|
||||||
|
the test suite runs cleanly without credentials. (#588)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **`backend_url` default is now `None`** rather than the OpenAI URL. Each
|
||||||
|
provider client falls back to its native default. The previous default
|
||||||
|
leaked the OpenAI URL into non-OpenAI clients (e.g. Gemini), producing
|
||||||
|
malformed request URLs for Python users who switched providers without
|
||||||
|
overriding `backend_url`. The CLI flow is unaffected.
|
||||||
|
- All file I/O passes explicit `encoding="utf-8"` so Windows users no longer
|
||||||
|
hit `UnicodeEncodeError` with the cp1252 default. (#543, #550, #576)
|
||||||
|
- Cache and log directories moved to `~/.tradingagents/` to resolve Docker
|
||||||
|
permission issues. (#519)
|
||||||
|
- `SignalProcessor` reads the rating from the Portfolio Manager's rendered
|
||||||
|
markdown via a deterministic heuristic — no extra LLM call.
|
||||||
|
- OpenAI structured-output calls default to `method="function_calling"` to
|
||||||
|
avoid noisy `PydanticSerializationUnexpectedValue` warnings emitted by
|
||||||
|
langchain-openai's Responses-API parse path. Same typed result, no warnings.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Empty memory no longer triggers fabricated past-lessons in agent prompts;
|
||||||
|
the memory-log redesign makes this structurally impossible since only the
|
||||||
|
Portfolio Manager consults memory and only when entries exist. (#572)
|
||||||
|
- Tool-call logging processes every chunk message, not just the last one, and
|
||||||
|
memory score normalization handles empty score arrays. (#534, #531)
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `FinancialSituationMemory` (the per-agent BM25 system) and the dead
|
||||||
|
`reflect_and_remember()` plumbing; subsumed by the persistent decision log.
|
||||||
|
- Hardcoded Google endpoint that caused 404 when `langchain-google-genai`
|
||||||
|
changed its API path. (#493, #496)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
Thanks to everyone who shaped this release through code, design, and reports:
|
||||||
|
|
||||||
|
- [@claytonbrown](https://github.com/claytonbrown) — checkpoint resume (#594), test fixtures (#588), design feedback on cost tracking (#582) and structured validation (#583)
|
||||||
|
- [@Bcardo](https://github.com/Bcardo) — memory-log redesign (#579), empty-memory hallucination report (#572), encoding fix proposal (#570)
|
||||||
|
- [@voidborne-d](https://github.com/voidborne-d) — memory persistence design (#564), portfolio manager state fix (#503)
|
||||||
|
- [@mannubaveja007](https://github.com/mannubaveja007) — structured-output feature request (#434)
|
||||||
|
- [@kelder66](https://github.com/kelder66) — RAM-only memory issue (#563)
|
||||||
|
- [@Gujiassh](https://github.com/Gujiassh) — tool-call logging fix (#534), test stub PR (#533)
|
||||||
|
- [@iuyup](https://github.com/iuyup) — memory score normalization fix (#531)
|
||||||
|
- [@kaihg](https://github.com/kaihg) — Google base_url fix (#496)
|
||||||
|
- [@32ryh98yfe](https://github.com/32ryh98yfe) — Gemini 404 report (#493)
|
||||||
|
- [@uppb](https://github.com/uppb) — OpenRouter dynamic model selection (#482)
|
||||||
|
- [@guoz14](https://github.com/guoz14) — OpenRouter limited-model report (#337)
|
||||||
|
- [@samchenku](https://github.com/samchenku) — indicator name normalization (#490)
|
||||||
|
- [@JasonOA888](https://github.com/JasonOA888) — y_finance pandas import fix (#488)
|
||||||
|
- [@tiffanychum](https://github.com/tiffanychum) — stale import cleanup (#499)
|
||||||
|
- [@zaizou](https://github.com/zaizou) — Docker permission issue (#519)
|
||||||
|
- [@Stosman123](https://github.com/Stosman123), [@mauropuga](https://github.com/mauropuga), [@hotwind2015](https://github.com/hotwind2015) — Windows encoding bug reports (#543, #550, #576)
|
||||||
|
- [@nnishad](https://github.com/nnishad), [@atharvajoshi01](https://github.com/atharvajoshi01) — encoding fix proposals (#568, #549)
|
||||||
|
|
||||||
|
## [0.2.3] — 2026-03-29
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Multi-language output** for analyst reports and final decisions, with a
|
||||||
|
CLI selector. Internal agent debate stays in English for reasoning quality. (#472)
|
||||||
|
- **GPT-5.4 family models** in the default catalog, with deep/quick model split.
|
||||||
|
- **Unified model catalog** as a single source of truth for CLI options and
|
||||||
|
provider validation.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- `base_url` is forwarded to Google and Anthropic clients so corporate proxies
|
||||||
|
work consistently across providers. (#427)
|
||||||
|
- Standardised the Google `api_key` parameter to the unified `api_key` form.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Backtesting fetchers no longer leak look-ahead data when `curr_date` is in
|
||||||
|
the middle of a fetched window. (#475)
|
||||||
|
- Invalid indicator names from the LLM are caught at the tool boundary instead
|
||||||
|
of crashing the run. (#429)
|
||||||
|
- yfinance news fetchers respect the same exponential-backoff retry as price
|
||||||
|
fetchers. (#445)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@ahmedk20](https://github.com/ahmedk20) — multi-language output (#472)
|
||||||
|
- [@CadeYu](https://github.com/CadeYu) — model catalog typing (#464)
|
||||||
|
- [@javierdejesusda](https://github.com/javierdejesusda) — unified Google API key parameter (#453)
|
||||||
|
- [@voidborne-d](https://github.com/voidborne-d) — yfinance news retry (#445)
|
||||||
|
- [@kostakost2](https://github.com/kostakost2) — look-ahead bias report (#475)
|
||||||
|
- [@lu-zhengda](https://github.com/lu-zhengda) — proxy/base_url support request (#427)
|
||||||
|
- [@VamsiKrishna2021](https://github.com/VamsiKrishna2021) — invalid indicator crash report (#429)
|
||||||
|
|
||||||
|
## [0.2.2] — 2026-03-22
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Five-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell)
|
||||||
|
introduced for the Portfolio Manager.
|
||||||
|
- **Anthropic effort level** support for Claude models.
|
||||||
|
- **OpenAI Responses API** path for native OpenAI models.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- `risk_manager` renamed to `portfolio_manager` to match the role description
|
||||||
|
shown in the CLI display.
|
||||||
|
- Exchange-qualified tickers (e.g. `7203.T`, `BRK.B`) preserved across all
|
||||||
|
agent prompts and tool calls.
|
||||||
|
- Process-level UTF-8 default attempted for cross-platform consistency
|
||||||
|
(note: this approach did not actually take effect; replaced in v0.2.4 with
|
||||||
|
explicit per-call `encoding="utf-8"` arguments).
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- yfinance rate-limit errors are retried with exponential backoff. (#426)
|
||||||
|
- HTTP client SSL customisation is supported for environments that need
|
||||||
|
custom certificate bundles. (#379)
|
||||||
|
- Report-section writes handle list-of-string content gracefully.
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@CadeYu](https://github.com/CadeYu) — exchange-qualified ticker preservation (#413)
|
||||||
|
- [@yang1002378395-cmyk](https://github.com/yang1002378395-cmyk) — HTTP client SSL customisation (#379)
|
||||||
|
|
||||||
|
## [0.2.1] — 2026-03-15
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Patched `langchain-core` vulnerability (LangGrinch). (#335)
|
||||||
|
- Removed `chainlit` dependency affected by CVE-2026-22218.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- `pyproject.toml` build-system configuration; the project now installs via
|
||||||
|
modern packaging tooling.
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `setup.py` — dependencies consolidated to `pyproject.toml`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Risk manager reads the correct fundamental report source. (#341)
|
||||||
|
- All `open()` calls receive an explicit UTF-8 encoding (initial pass).
|
||||||
|
- `get_indicators` tool handles comma-separated indicator names from the LLM. (#368)
|
||||||
|
- `Propagation` initialises every debate-state field so risk debaters never
|
||||||
|
see missing keys.
|
||||||
|
- Stock data parsing tolerates malformed CSVs and NaN values.
|
||||||
|
- Conditional debate logic respects the configured round count. (#361)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@RinZ27](https://github.com/RinZ27) — `langchain-core` security patch (#335)
|
||||||
|
- [@Ljx-007](https://github.com/Ljx-007) — risk manager fundamental-report fix (#341)
|
||||||
|
- [@makk9](https://github.com/makk9) — debate-rounds config issue (#361)
|
||||||
|
|
||||||
|
## [0.2.0] — 2026-02-04
|
||||||
|
|
||||||
|
This is the largest release since the initial public version. The framework
|
||||||
|
moved from single-provider to a multi-provider architecture and grew several
|
||||||
|
production-ready surfaces.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Multi-provider LLM support** (OpenAI, Google, Anthropic, xAI, OpenRouter,
|
||||||
|
Ollama) via a factory pattern, with provider-specific thinking configurations.
|
||||||
|
- **Alpha Vantage** integration as a configurable primary data provider, with
|
||||||
|
yfinance as a community-stability fallback.
|
||||||
|
- **Footer statistics** in the CLI: real-time tracking of LLM calls, tool
|
||||||
|
calls, and token usage via LangChain callbacks.
|
||||||
|
- **Post-analysis report saving** — the framework writes per-section markdown
|
||||||
|
files (analyst reports, debate transcripts, final decision) when a run
|
||||||
|
completes.
|
||||||
|
- **Announcements panel** — fetches updates from `api.tauric.ai/v1/announcements`
|
||||||
|
for the CLI welcome screen.
|
||||||
|
- **Tool fallbacks** so a single vendor outage does not stop the pipeline.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Risky / Safe risk debaters renamed to **Aggressive / Conservative** for
|
||||||
|
consistency with the displayed agent labels.
|
||||||
|
- Default data vendor switched to balance reliability and quota across
|
||||||
|
community deployments.
|
||||||
|
- Ollama and OpenRouter model lists updated; default endpoints clarified.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Analyst status tracking and message deduplication in the live display.
|
||||||
|
- Infinite-loop guard in the agent loop; reflection and logging hardened.
|
||||||
|
- Various data-vendor implementation bugs and tool-signature mismatches.
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
This release is the first with substantial outside contributions; many community
|
||||||
|
PRs from late 2025 also landed here.
|
||||||
|
|
||||||
|
- [@luohy15](https://github.com/luohy15) — Alpha Vantage data-vendor integration (#235)
|
||||||
|
- [@EdwardoSunny](https://github.com/EdwardoSunny) — yfinance fetching optimisations (#245)
|
||||||
|
- [@Mirza-Samad-Ahmed-Baig](https://github.com/Mirza-Samad-Ahmed-Baig) — infinite-loop guard, reflection, and logging fixes (#89)
|
||||||
|
- [@ZeroAct](https://github.com/ZeroAct) — saved results path support (#29)
|
||||||
|
- [@Zhongyi-Lu](https://github.com/Zhongyi-Lu) — `.env` gitignore (#49)
|
||||||
|
- [@csoboy](https://github.com/csoboy) — local Ollama setup (#53)
|
||||||
|
- [@chauhang](https://github.com/chauhang) — initial Docker support attempt (#47, later reverted; the merged Docker support shipped in v0.2.4)
|
||||||
|
|
||||||
|
## [0.1.1] — 2025-06-07
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- Static site assets that had been bundled with v0.1.0; the public site now
|
||||||
|
lives separately.
|
||||||
|
|
||||||
|
## [0.1.0] — 2025-06-05
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Initial public release** of the TradingAgents multi-agent trading
|
||||||
|
framework: market / sentiment / news / fundamentals analysts; bull and bear
|
||||||
|
researchers; trader; aggressive, conservative, and neutral risk debaters;
|
||||||
|
portfolio manager. LangGraph orchestration, yfinance data, per-agent
|
||||||
|
BM25 memory, single-provider OpenAI integration, interactive CLI.
|
||||||
|
|
||||||
|
[0.2.4]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.3...v0.2.4
|
||||||
|
[0.2.3]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.2...v0.2.3
|
||||||
|
[0.2.2]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.1...v0.2.2
|
||||||
|
[0.2.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.0...v0.2.1
|
||||||
|
[0.2.0]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.1...v0.2.0
|
||||||
|
[0.1.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.0...v0.1.1
|
||||||
|
[0.1.0]: https://github.com/TauricResearch/TradingAgents/releases/tag/v0.1.0
|
||||||
27
Dockerfile
Normal file
27
Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||||
|
|
||||||
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
COPY . .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
COPY --from=builder /opt/venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
RUN useradd --create-home appuser
|
||||||
|
USER appuser
|
||||||
|
WORKDIR /home/appuser/app
|
||||||
|
|
||||||
|
COPY --from=builder --chown=appuser:appuser /build .
|
||||||
|
|
||||||
|
ENTRYPOINT ["tradingagents"]
|
||||||
58
README.md
58
README.md
@@ -28,6 +28,8 @@
|
|||||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||||
|
|
||||||
## News
|
## News
|
||||||
|
- [2026-04] **TradingAgents v0.2.4** released with structured-output agents (Research Manager, Trader, Portfolio Manager), LangGraph checkpoint resume, persistent decision log, DeepSeek/Qwen/GLM/Azure provider support, Docker, and a Windows UTF-8 encoding fix. See [CHANGELOG.md](CHANGELOG.md) for the full list.
|
||||||
|
- [2026-03] **TradingAgents v0.2.3** released with multi-language support, GPT-5.4 family models, unified model catalog, backtesting date fidelity, and proxy support.
|
||||||
- [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability.
|
- [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability.
|
||||||
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
|
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
|
||||||
- [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon.
|
- [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon.
|
||||||
@@ -117,6 +119,19 @@ Install the package and its dependencies:
|
|||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
Alternatively, run with Docker:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env # add your API keys
|
||||||
|
docker compose run --rm tradingagents
|
||||||
|
```
|
||||||
|
|
||||||
|
For local models with Ollama:
|
||||||
|
```bash
|
||||||
|
docker compose --profile ollama run --rm tradingagents-ollama
|
||||||
|
```
|
||||||
|
|
||||||
### Required APIs
|
### Required APIs
|
||||||
|
|
||||||
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
||||||
@@ -126,10 +141,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
|
|||||||
export GOOGLE_API_KEY=... # Google (Gemini)
|
export GOOGLE_API_KEY=... # Google (Gemini)
|
||||||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||||
export XAI_API_KEY=... # xAI (Grok)
|
export XAI_API_KEY=... # xAI (Grok)
|
||||||
|
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||||
|
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
|
||||||
|
export ZHIPU_API_KEY=... # GLM (Zhipu)
|
||||||
export OPENROUTER_API_KEY=... # OpenRouter
|
export OPENROUTER_API_KEY=... # OpenRouter
|
||||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||||
|
|
||||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||||
|
|
||||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||||
@@ -164,7 +184,7 @@ An interface will appear showing results as they load, letting you track the age
|
|||||||
|
|
||||||
### Implementation Details
|
### Implementation Details
|
||||||
|
|
||||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
|
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
|
||||||
|
|
||||||
### Python Usage
|
### Python Usage
|
||||||
|
|
||||||
@@ -188,9 +208,9 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama
|
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure
|
||||||
config["deep_think_llm"] = "gpt-5.2" # Model for complex reasoning
|
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
|
||||||
config["quick_think_llm"] = "gpt-5-mini" # Model for quick tasks
|
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
||||||
config["max_debate_rounds"] = 2
|
config["max_debate_rounds"] = 2
|
||||||
|
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
@@ -200,10 +220,40 @@ print(decision)
|
|||||||
|
|
||||||
See `tradingagents/default_config.py` for all configuration options.
|
See `tradingagents/default_config.py` for all configuration options.
|
||||||
|
|
||||||
|
## Persistence and Recovery
|
||||||
|
|
||||||
|
TradingAgents persists two kinds of state across runs.
|
||||||
|
|
||||||
|
### Decision log
|
||||||
|
|
||||||
|
The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't.
|
||||||
|
|
||||||
|
Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
|
||||||
|
|
||||||
|
### Checkpoint resume
|
||||||
|
|
||||||
|
Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion.
|
||||||
|
|
||||||
|
Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tradingagents analyze --checkpoint # enable for this run
|
||||||
|
tradingagents analyze --clear-checkpoints # reset before running
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = DEFAULT_CONFIG.copy()
|
||||||
|
config["checkpoint_enabled"] = True
|
||||||
|
ta = TradingAgentsGraph(config=config)
|
||||||
|
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
|
```
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||||
|
|
||||||
|
Past contributions, including code, design feedback, and bug reports, are credited per release in [`CHANGELOG.md`](CHANGELOG.md).
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||||
|
|||||||
124
cli/main.py
124
cli/main.py
@@ -6,8 +6,9 @@ from functools import wraps
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables from .env file
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
load_dotenv(".env.enterprise", override=False)
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.spinner import Spinner
|
from rich.spinner import Spinner
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
@@ -79,7 +80,7 @@ class MessageBuffer:
|
|||||||
self.current_agent = None
|
self.current_agent = None
|
||||||
self.report_sections = {}
|
self.report_sections = {}
|
||||||
self.selected_analysts = []
|
self.selected_analysts = []
|
||||||
self._last_message_id = None
|
self._processed_message_ids = set()
|
||||||
|
|
||||||
def init_for_analysis(self, selected_analysts):
|
def init_for_analysis(self, selected_analysts):
|
||||||
"""Initialize agent status and report sections based on selected analysts.
|
"""Initialize agent status and report sections based on selected analysts.
|
||||||
@@ -114,7 +115,7 @@ class MessageBuffer:
|
|||||||
self.current_agent = None
|
self.current_agent = None
|
||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
self.tool_calls.clear()
|
self.tool_calls.clear()
|
||||||
self._last_message_id = None
|
self._processed_message_ids.clear()
|
||||||
|
|
||||||
def get_completed_reports_count(self):
|
def get_completed_reports_count(self):
|
||||||
"""Count reports that are finalized (their finalizing agent is completed).
|
"""Count reports that are finalized (their finalizing agent is completed).
|
||||||
@@ -462,7 +463,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
|||||||
def get_user_selections():
|
def get_user_selections():
|
||||||
"""Get all user selections before starting the analysis display."""
|
"""Get all user selections before starting the analysis display."""
|
||||||
# Display ASCII art welcome message
|
# Display ASCII art welcome message
|
||||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r") as f:
|
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
|
||||||
welcome_ascii = f.read()
|
welcome_ascii = f.read()
|
||||||
|
|
||||||
# Create welcome box content
|
# Create welcome box content
|
||||||
@@ -519,10 +520,19 @@ def get_user_selections():
|
|||||||
)
|
)
|
||||||
analysis_date = get_analysis_date()
|
analysis_date = get_analysis_date()
|
||||||
|
|
||||||
# Step 3: Select analysts
|
# Step 3: Output language
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
|
"Step 3: Output Language",
|
||||||
|
"Select the language for analyst reports and final decision"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output_language = ask_output_language()
|
||||||
|
|
||||||
|
# Step 4: Select analysts
|
||||||
|
console.print(
|
||||||
|
create_question_box(
|
||||||
|
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_analysts = select_analysts()
|
selected_analysts = select_analysts()
|
||||||
@@ -530,32 +540,32 @@ def get_user_selections():
|
|||||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Research depth
|
# Step 5: Research depth
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 4: Research Depth", "Select your research depth level"
|
"Step 5: Research Depth", "Select your research depth level"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_research_depth = select_research_depth()
|
selected_research_depth = select_research_depth()
|
||||||
|
|
||||||
# Step 5: OpenAI backend
|
# Step 6: LLM Provider
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
"Step 6: LLM Provider", "Select your LLM provider"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_llm_provider, backend_url = select_llm_provider()
|
selected_llm_provider, backend_url = select_llm_provider()
|
||||||
|
|
||||||
# Step 6: Thinking agents
|
# Step 7: Thinking agents
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
"Step 7: Thinking Agents", "Select your thinking agents for analysis"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||||
|
|
||||||
# Step 7: Provider-specific thinking configuration
|
# Step 8: Provider-specific thinking configuration
|
||||||
thinking_level = None
|
thinking_level = None
|
||||||
reasoning_effort = None
|
reasoning_effort = None
|
||||||
anthropic_effort = None
|
anthropic_effort = None
|
||||||
@@ -564,7 +574,7 @@ def get_user_selections():
|
|||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 7: Thinking Mode",
|
"Step 8: Thinking Mode",
|
||||||
"Configure Gemini thinking mode"
|
"Configure Gemini thinking mode"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -572,7 +582,7 @@ def get_user_selections():
|
|||||||
elif provider_lower == "openai":
|
elif provider_lower == "openai":
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 7: Reasoning Effort",
|
"Step 8: Reasoning Effort",
|
||||||
"Configure OpenAI reasoning effort level"
|
"Configure OpenAI reasoning effort level"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -580,7 +590,7 @@ def get_user_selections():
|
|||||||
elif provider_lower == "anthropic":
|
elif provider_lower == "anthropic":
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 7: Effort Level",
|
"Step 8: Effort Level",
|
||||||
"Configure Claude effort level"
|
"Configure Claude effort level"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -598,6 +608,7 @@ def get_user_selections():
|
|||||||
"google_thinking_level": thinking_level,
|
"google_thinking_level": thinking_level,
|
||||||
"openai_reasoning_effort": reasoning_effort,
|
"openai_reasoning_effort": reasoning_effort,
|
||||||
"anthropic_effort": anthropic_effort,
|
"anthropic_effort": anthropic_effort,
|
||||||
|
"output_language": output_language,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -635,19 +646,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||||||
analyst_parts = []
|
analyst_parts = []
|
||||||
if final_state.get("market_report"):
|
if final_state.get("market_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "market.md").write_text(final_state["market_report"])
|
(analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Market Analyst", final_state["market_report"]))
|
analyst_parts.append(("Market Analyst", final_state["market_report"]))
|
||||||
if final_state.get("sentiment_report"):
|
if final_state.get("sentiment_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"])
|
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
|
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
|
||||||
if final_state.get("news_report"):
|
if final_state.get("news_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "news.md").write_text(final_state["news_report"])
|
(analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("News Analyst", final_state["news_report"]))
|
analyst_parts.append(("News Analyst", final_state["news_report"]))
|
||||||
if final_state.get("fundamentals_report"):
|
if final_state.get("fundamentals_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
|
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
|
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
|
||||||
if analyst_parts:
|
if analyst_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
|
||||||
@@ -660,15 +671,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||||||
research_parts = []
|
research_parts = []
|
||||||
if debate.get("bull_history"):
|
if debate.get("bull_history"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "bull.md").write_text(debate["bull_history"])
|
(research_dir / "bull.md").write_text(debate["bull_history"], encoding="utf-8")
|
||||||
research_parts.append(("Bull Researcher", debate["bull_history"]))
|
research_parts.append(("Bull Researcher", debate["bull_history"]))
|
||||||
if debate.get("bear_history"):
|
if debate.get("bear_history"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "bear.md").write_text(debate["bear_history"])
|
(research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8")
|
||||||
research_parts.append(("Bear Researcher", debate["bear_history"]))
|
research_parts.append(("Bear Researcher", debate["bear_history"]))
|
||||||
if debate.get("judge_decision"):
|
if debate.get("judge_decision"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "manager.md").write_text(debate["judge_decision"])
|
(research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8")
|
||||||
research_parts.append(("Research Manager", debate["judge_decision"]))
|
research_parts.append(("Research Manager", debate["judge_decision"]))
|
||||||
if research_parts:
|
if research_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
||||||
@@ -678,7 +689,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||||||
if final_state.get("trader_investment_plan"):
|
if final_state.get("trader_investment_plan"):
|
||||||
trading_dir = save_path / "3_trading"
|
trading_dir = save_path / "3_trading"
|
||||||
trading_dir.mkdir(exist_ok=True)
|
trading_dir.mkdir(exist_ok=True)
|
||||||
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
|
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8")
|
||||||
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
|
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
|
||||||
|
|
||||||
# 4. Risk Management
|
# 4. Risk Management
|
||||||
@@ -688,15 +699,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||||||
risk_parts = []
|
risk_parts = []
|
||||||
if risk.get("aggressive_history"):
|
if risk.get("aggressive_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"])
|
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
|
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
|
||||||
if risk.get("conservative_history"):
|
if risk.get("conservative_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "conservative.md").write_text(risk["conservative_history"])
|
(risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
|
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
|
||||||
if risk.get("neutral_history"):
|
if risk.get("neutral_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
|
(risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
||||||
if risk_parts:
|
if risk_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
||||||
@@ -706,12 +717,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||||||
if risk.get("judge_decision"):
|
if risk.get("judge_decision"):
|
||||||
portfolio_dir = save_path / "5_portfolio"
|
portfolio_dir = save_path / "5_portfolio"
|
||||||
portfolio_dir.mkdir(exist_ok=True)
|
portfolio_dir.mkdir(exist_ok=True)
|
||||||
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
|
(portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8")
|
||||||
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
|
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
|
||||||
|
|
||||||
# Write consolidated report
|
# Write consolidated report
|
||||||
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
|
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
|
||||||
return save_path / "complete_report.md"
|
return save_path / "complete_report.md"
|
||||||
|
|
||||||
|
|
||||||
@@ -915,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
|
|||||||
return result[:max_length - 3] + "..."
|
return result[:max_length - 3] + "..."
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def run_analysis():
|
def run_analysis(checkpoint: bool = False):
|
||||||
# First get all user selections
|
# First get all user selections
|
||||||
selections = get_user_selections()
|
selections = get_user_selections()
|
||||||
|
|
||||||
@@ -931,6 +942,8 @@ def run_analysis():
|
|||||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||||
|
config["output_language"] = selections.get("output_language", "English")
|
||||||
|
config["checkpoint_enabled"] = checkpoint
|
||||||
|
|
||||||
# Create stats callback handler for tracking LLM/tool calls
|
# Create stats callback handler for tracking LLM/tool calls
|
||||||
stats_handler = StatsCallbackHandler()
|
stats_handler = StatsCallbackHandler()
|
||||||
@@ -968,7 +981,7 @@ def run_analysis():
|
|||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, message_type, content = obj.messages[-1]
|
timestamp, message_type, content = obj.messages[-1]
|
||||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
content = content.replace("\n", " ") # Replace newlines with spaces
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -979,7 +992,7 @@ def run_analysis():
|
|||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -993,7 +1006,7 @@ def run_analysis():
|
|||||||
if content:
|
if content:
|
||||||
file_name = f"{section_name}.md"
|
file_name = f"{section_name}.md"
|
||||||
text = "\n".join(str(item) for item in content) if isinstance(content, list) else content
|
text = "\n".join(str(item) for item in content) if isinstance(content, list) else content
|
||||||
with open(report_dir / file_name, "w") as f:
|
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||||
f.write(text)
|
f.write(text)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -1041,26 +1054,22 @@ def run_analysis():
|
|||||||
# Stream the analysis
|
# Stream the analysis
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||||
# Process messages if present (skip duplicates via message ID)
|
# Process all messages in chunk, deduplicating by message ID
|
||||||
if len(chunk["messages"]) > 0:
|
for message in chunk.get("messages", []):
|
||||||
last_message = chunk["messages"][-1]
|
msg_id = getattr(message, "id", None)
|
||||||
msg_id = getattr(last_message, "id", None)
|
if msg_id is not None:
|
||||||
|
if msg_id in message_buffer._processed_message_ids:
|
||||||
|
continue
|
||||||
|
message_buffer._processed_message_ids.add(msg_id)
|
||||||
|
|
||||||
if msg_id != message_buffer._last_message_id:
|
msg_type, content = classify_message_type(message)
|
||||||
message_buffer._last_message_id = msg_id
|
|
||||||
|
|
||||||
# Add message to buffer
|
|
||||||
msg_type, content = classify_message_type(last_message)
|
|
||||||
if content and content.strip():
|
if content and content.strip():
|
||||||
message_buffer.add_message(msg_type, content)
|
message_buffer.add_message(msg_type, content)
|
||||||
|
|
||||||
# Handle tool calls
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
for tool_call in message.tool_calls:
|
||||||
for tool_call in last_message.tool_calls:
|
|
||||||
if isinstance(tool_call, dict):
|
if isinstance(tool_call, dict):
|
||||||
message_buffer.add_tool_call(
|
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||||
tool_call["name"], tool_call["args"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
|
|
||||||
@@ -1189,8 +1198,23 @@ def run_analysis():
|
|||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def analyze():
|
def analyze(
|
||||||
run_analysis()
|
checkpoint: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--checkpoint",
|
||||||
|
help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
|
||||||
|
),
|
||||||
|
clear_checkpoints: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--clear-checkpoints",
|
||||||
|
help="Delete all saved checkpoints before running (force fresh start).",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
if clear_checkpoints:
|
||||||
|
from tradingagents.graph.checkpointer import clear_all_checkpoints
|
||||||
|
n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
|
||||||
|
console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
|
||||||
|
run_analysis(checkpoint=checkpoint)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
244
cli/utils.py
244
cli/utils.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -133,51 +134,70 @@ def select_research_depth() -> int:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def select_shallow_thinking_agent(provider) -> str:
|
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
||||||
"""Select shallow thinking llm engine using an interactive selection."""
|
"""Fetch available models from the OpenRouter API."""
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
|
||||||
|
resp.raise_for_status()
|
||||||
|
models = resp.json().get("data", [])
|
||||||
|
return [(m.get("name") or m["id"], m["id"]) for m in models]
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
|
||||||
|
return []
|
||||||
|
|
||||||
# Define shallow thinking llm engine options with their corresponding model names
|
|
||||||
# Ordering: medium → light → heavy (balanced first for quick tasks)
|
def select_openrouter_model() -> str:
|
||||||
# Within same tier, newer models first
|
"""Select an OpenRouter model from the newest available, or enter a custom ID."""
|
||||||
SHALLOW_AGENT_OPTIONS = {
|
models = _fetch_openrouter_models()
|
||||||
"openai": [
|
|
||||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
|
||||||
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
|
choices.append(questionary.Choice("Custom model ID", value="custom"))
|
||||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
|
||||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
|
||||||
],
|
|
||||||
"anthropic": [
|
|
||||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
|
||||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
|
||||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
|
||||||
],
|
|
||||||
"google": [
|
|
||||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
|
||||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
|
||||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
|
||||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
|
||||||
],
|
|
||||||
"xai": [
|
|
||||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
|
||||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
|
||||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
|
||||||
],
|
|
||||||
"openrouter": [
|
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
|
||||||
],
|
|
||||||
"ollama": [
|
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select Your [Quick-Thinking LLM Engine]:",
|
"Select OpenRouter Model (latest available):",
|
||||||
|
choices=choices,
|
||||||
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:magenta noinherit"),
|
||||||
|
("highlighted", "fg:magenta noinherit"),
|
||||||
|
("pointer", "fg:magenta noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice is None or choice == "custom":
|
||||||
|
return questionary.text(
|
||||||
|
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
return choice
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_custom_model_id() -> str:
|
||||||
|
"""Prompt user to type a custom model ID."""
|
||||||
|
return questionary.text(
|
||||||
|
"Enter model ID:",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _select_model(provider: str, mode: str) -> str:
|
||||||
|
"""Select a model for the given provider and mode (quick/deep)."""
|
||||||
|
if provider.lower() == "openrouter":
|
||||||
|
return select_openrouter_model()
|
||||||
|
|
||||||
|
if provider.lower() == "azure":
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter Azure deployment name ({mode}-thinking):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=value)
|
questionary.Choice(display, value=value)
|
||||||
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
for display, value in get_model_options(provider, mode)
|
||||||
],
|
],
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
@@ -190,95 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print(
|
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
|
||||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
|
||||||
)
|
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
if choice == "custom":
|
||||||
|
return _prompt_custom_model_id()
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
|
def select_shallow_thinking_agent(provider) -> str:
|
||||||
|
"""Select shallow thinking llm engine using an interactive selection."""
|
||||||
|
return _select_model(provider, "quick")
|
||||||
|
|
||||||
|
|
||||||
def select_deep_thinking_agent(provider) -> str:
|
def select_deep_thinking_agent(provider) -> str:
|
||||||
"""Select deep thinking llm engine using an interactive selection."""
|
"""Select deep thinking llm engine using an interactive selection."""
|
||||||
|
return _select_model(provider, "deep")
|
||||||
|
|
||||||
# Define deep thinking llm engine options with their corresponding model names
|
def select_llm_provider() -> tuple[str, str | None]:
|
||||||
# Ordering: heavy → medium → light (most capable first for deep tasks)
|
"""Select the LLM provider and its API endpoint."""
|
||||||
# Within same tier, newer models first
|
# (display_name, provider_key, base_url)
|
||||||
DEEP_AGENT_OPTIONS = {
|
PROVIDERS = [
|
||||||
"openai": [
|
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
("Google", "google", None),
|
||||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
("xAI", "xai", "https://api.x.ai/v1"),
|
||||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
("DeepSeek", "deepseek", "https://api.deepseek.com"),
|
||||||
],
|
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||||
"anthropic": [
|
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
("Azure OpenAI", "azure", None),
|
||||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
|
||||||
],
|
|
||||||
"google": [
|
|
||||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
|
||||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
|
||||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
|
||||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
|
||||||
],
|
|
||||||
"xai": [
|
|
||||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
|
||||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
|
||||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
|
||||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
|
||||||
],
|
|
||||||
"openrouter": [
|
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
|
||||||
],
|
|
||||||
"ollama": [
|
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
choice = questionary.select(
|
|
||||||
"Select Your [Deep-Thinking LLM Engine]:",
|
|
||||||
choices=[
|
|
||||||
questionary.Choice(display, value=value)
|
|
||||||
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
|
||||||
],
|
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
|
||||||
style=questionary.Style(
|
|
||||||
[
|
|
||||||
("selected", "fg:magenta noinherit"),
|
|
||||||
("highlighted", "fg:magenta noinherit"),
|
|
||||||
("pointer", "fg:magenta noinherit"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
).ask()
|
|
||||||
|
|
||||||
if choice is None:
|
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return choice
|
|
||||||
|
|
||||||
def select_llm_provider() -> tuple[str, str]:
|
|
||||||
"""Select the OpenAI api url using interactive selection."""
|
|
||||||
# Define OpenAI api options with their corresponding endpoints
|
|
||||||
BASE_URLS = [
|
|
||||||
("OpenAI", "https://api.openai.com/v1"),
|
|
||||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
|
||||||
("Anthropic", "https://api.anthropic.com/"),
|
|
||||||
("xAI", "https://api.x.ai/v1"),
|
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select your LLM Provider:",
|
"Select your LLM Provider:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=(display, value))
|
questionary.Choice(display, value=(provider_key, url))
|
||||||
for display, value in BASE_URLS
|
for display, provider_key, url in PROVIDERS
|
||||||
],
|
],
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
@@ -291,13 +261,11 @@ def select_llm_provider() -> tuple[str, str]:
|
|||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
display_name, url = choice
|
provider, url = choice
|
||||||
print(f"You selected: {display_name}\tURL: {url}")
|
return provider, url
|
||||||
|
|
||||||
return display_name, url
|
|
||||||
|
|
||||||
|
|
||||||
def ask_openai_reasoning_effort() -> str:
|
def ask_openai_reasoning_effort() -> str:
|
||||||
@@ -356,3 +324,37 @@ def ask_gemini_thinking_config() -> str | None:
|
|||||||
("pointer", "fg:green noinherit"),
|
("pointer", "fg:green noinherit"),
|
||||||
]),
|
]),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_output_language() -> str:
|
||||||
|
"""Ask for report output language."""
|
||||||
|
choice = questionary.select(
|
||||||
|
"Select Output Language:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("English (default)", "English"),
|
||||||
|
questionary.Choice("Chinese (中文)", "Chinese"),
|
||||||
|
questionary.Choice("Japanese (日本語)", "Japanese"),
|
||||||
|
questionary.Choice("Korean (한국어)", "Korean"),
|
||||||
|
questionary.Choice("Hindi (हिन्दी)", "Hindi"),
|
||||||
|
questionary.Choice("Spanish (Español)", "Spanish"),
|
||||||
|
questionary.Choice("Portuguese (Português)", "Portuguese"),
|
||||||
|
questionary.Choice("French (Français)", "French"),
|
||||||
|
questionary.Choice("German (Deutsch)", "German"),
|
||||||
|
questionary.Choice("Arabic (العربية)", "Arabic"),
|
||||||
|
questionary.Choice("Russian (Русский)", "Russian"),
|
||||||
|
questionary.Choice("Custom language", "custom"),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:yellow noinherit"),
|
||||||
|
("highlighted", "fg:yellow noinherit"),
|
||||||
|
("pointer", "fg:yellow noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice == "custom":
|
||||||
|
return questionary.text(
|
||||||
|
"Enter language name (e.g. Turkish, Vietnamese, Thai, Indonesian):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a language name.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
return choice
|
||||||
|
|||||||
35
docker-compose.yml
Normal file
35
docker-compose.yml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
services:
|
||||||
|
tradingagents:
|
||||||
|
build: .
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
|
tty: true
|
||||||
|
stdin_open: true
|
||||||
|
|
||||||
|
ollama:
|
||||||
|
image: ollama/ollama:latest
|
||||||
|
volumes:
|
||||||
|
- ollama_data:/root/.ollama
|
||||||
|
profiles:
|
||||||
|
- ollama
|
||||||
|
|
||||||
|
tradingagents-ollama:
|
||||||
|
build: .
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
- LLM_PROVIDER=ollama
|
||||||
|
volumes:
|
||||||
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
|
depends_on:
|
||||||
|
- ollama
|
||||||
|
tty: true
|
||||||
|
stdin_open: true
|
||||||
|
profiles:
|
||||||
|
- ollama
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
tradingagents_data:
|
||||||
|
ollama_data:
|
||||||
4
main.py
4
main.py
@@ -8,8 +8,8 @@ load_dotenv()
|
|||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-5-mini" # Use a different model
|
config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||||
config["quick_think_llm"] = "gpt-5-mini" # Use a different model
|
config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
|
|
||||||
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "tradingagents"
|
name = "tradingagents"
|
||||||
version = "0.2.2"
|
version = "0.2.4"
|
||||||
description = "TradingAgents: Multi-Agents LLM Financial Trading Framework"
|
description = "TradingAgents: Multi-Agents LLM Financial Trading Framework"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
@@ -13,14 +13,14 @@ dependencies = [
|
|||||||
"backtrader>=1.9.78.123",
|
"backtrader>=1.9.78.123",
|
||||||
"langchain-anthropic>=0.3.15",
|
"langchain-anthropic>=0.3.15",
|
||||||
"langchain-experimental>=0.3.4",
|
"langchain-experimental>=0.3.4",
|
||||||
"langchain-google-genai>=2.1.5",
|
"langchain-google-genai>=4.0.0",
|
||||||
"langchain-openai>=0.3.23",
|
"langchain-openai>=0.3.23",
|
||||||
"langgraph>=0.4.8",
|
"langgraph>=0.4.8",
|
||||||
|
"langgraph-checkpoint-sqlite>=2.0.0",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
"parsel>=1.10.0",
|
"parsel>=1.10.0",
|
||||||
"pytz>=2025.2",
|
"pytz>=2025.2",
|
||||||
"questionary>=2.1.0",
|
"questionary>=2.1.0",
|
||||||
"rank-bm25>=0.2.2",
|
|
||||||
"redis>=6.2.0",
|
"redis>=6.2.0",
|
||||||
"requests>=2.32.4",
|
"requests>=2.32.4",
|
||||||
"rich>=14.0.0",
|
"rich>=14.0.0",
|
||||||
@@ -40,3 +40,15 @@ include = ["tradingagents*", "cli*"]
|
|||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
cli = ["static/*"]
|
cli = ["static/*"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-ra --strict-markers"
|
||||||
|
markers = [
|
||||||
|
"unit: fast isolated unit tests",
|
||||||
|
"integration: tests requiring external services",
|
||||||
|
"smoke: quick sanity-check tests",
|
||||||
|
]
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
]
|
||||||
|
|||||||
176
scripts/smoke_structured_output.py
Normal file
176
scripts/smoke_structured_output.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""End-to-end smoke for structured-output agents against a real LLM provider.
|
||||||
|
|
||||||
|
Runs the three decision-making agents (Research Manager, Trader, Portfolio
|
||||||
|
Manager) directly with their structured-output bindings and prints the
|
||||||
|
typed Pydantic instance + the rendered markdown for each. Use this to
|
||||||
|
verify a provider's native structured-output mode (json_schema for
|
||||||
|
OpenAI / xAI / DeepSeek / Qwen / GLM, response_schema for Gemini, tool-use
|
||||||
|
for Anthropic) returns clean instances on the schemas we ship.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
OPENAI_API_KEY=... python scripts/smoke_structured_output.py openai
|
||||||
|
GOOGLE_API_KEY=... python scripts/smoke_structured_output.py google
|
||||||
|
ANTHROPIC_API_KEY=... python scripts/smoke_structured_output.py anthropic
|
||||||
|
DEEPSEEK_API_KEY=... python scripts/smoke_structured_output.py deepseek
|
||||||
|
|
||||||
|
The script does NOT call propagate(), to keep the surface tight and the
|
||||||
|
cost low — it exercises only the three structured-output calls we just
|
||||||
|
added, plus the heuristic SignalProcessor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_DEFAULTS = {
|
||||||
|
"openai": ("gpt-5.4-mini", None),
|
||||||
|
"google": ("gemini-2.5-flash", None),
|
||||||
|
"anthropic": ("claude-sonnet-4-6", None),
|
||||||
|
"deepseek": ("deepseek-chat", None),
|
||||||
|
"qwen": ("qwen-plus", None),
|
||||||
|
"glm": ("glm-5", None),
|
||||||
|
"xai": ("grok-4", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal but realistic state for the three agents.
|
||||||
|
DEBATE_HISTORY = """
|
||||||
|
Bull Analyst: NVDA's data-center revenue grew 60% YoY last quarter, driven by
|
||||||
|
Blackwell ramp; sovereign AI deals with multiple governments add a $40B+
|
||||||
|
multi-year tailwind. Margins remain above peer average.
|
||||||
|
|
||||||
|
Bear Analyst: Concentration risk is real — top three customers are >40% of
|
||||||
|
revenue. Any pause in hyperscaler capex would compress the multiple. China
|
||||||
|
export restrictions still cap a meaningful portion of demand.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rm_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"history": DEBATE_HISTORY,
|
||||||
|
"bull_history": "Bull Analyst: NVDA's data-center revenue grew 60% YoY...",
|
||||||
|
"bear_history": "Bear Analyst: Concentration risk is real...",
|
||||||
|
"current_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trader_state(investment_plan: str):
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_plan": investment_plan,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pm_state(investment_plan: str, trader_plan: str):
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"past_context": "",
|
||||||
|
"risk_debate_state": {
|
||||||
|
"history": "Aggressive: lean in. Conservative: trim. Neutral: balanced sizing.",
|
||||||
|
"aggressive_history": "Aggressive: ...",
|
||||||
|
"conservative_history": "Conservative: ...",
|
||||||
|
"neutral_history": "Neutral: ...",
|
||||||
|
"judge_decision": "",
|
||||||
|
"current_aggressive_response": "",
|
||||||
|
"current_conservative_response": "",
|
||||||
|
"current_neutral_response": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
"market_report": "Market report.",
|
||||||
|
"sentiment_report": "Sentiment report.",
|
||||||
|
"news_report": "News report.",
|
||||||
|
"fundamentals_report": "Fundamentals report.",
|
||||||
|
"investment_plan": investment_plan,
|
||||||
|
"trader_investment_plan": trader_plan,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _print_section(title: str, content: str) -> None:
|
||||||
|
bar = "=" * 70
|
||||||
|
print(f"\n{bar}\n{title}\n{bar}\n{content}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("provider", choices=list(PROVIDER_DEFAULTS.keys()))
|
||||||
|
parser.add_argument("--deep-model", default=None, help="Override deep_think_llm")
|
||||||
|
parser.add_argument("--quick-model", default=None, help="Override quick_think_llm")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
default_model, _ = PROVIDER_DEFAULTS[args.provider]
|
||||||
|
deep_model = args.deep_model or default_model
|
||||||
|
quick_model = args.quick_model or default_model
|
||||||
|
|
||||||
|
print(f"Provider: {args.provider}")
|
||||||
|
print(f"Deep model: {deep_model}")
|
||||||
|
print(f"Quick model: {quick_model}")
|
||||||
|
|
||||||
|
# Build the LLM clients via the framework's factory.
|
||||||
|
deep_client = create_llm_client(provider=args.provider, model=deep_model)
|
||||||
|
quick_client = create_llm_client(provider=args.provider, model=quick_model)
|
||||||
|
deep_llm = deep_client.get_llm()
|
||||||
|
quick_llm = quick_client.get_llm()
|
||||||
|
|
||||||
|
# 1) Research Manager
|
||||||
|
rm = create_research_manager(deep_llm)
|
||||||
|
rm_result = rm(_make_rm_state())
|
||||||
|
investment_plan = rm_result["investment_plan"]
|
||||||
|
_print_section("[1] Research Manager — investment_plan", investment_plan)
|
||||||
|
|
||||||
|
# 2) Trader (consumes RM's plan)
|
||||||
|
trader = create_trader(quick_llm)
|
||||||
|
trader_result = trader(_make_trader_state(investment_plan))
|
||||||
|
trader_plan = trader_result["trader_investment_plan"]
|
||||||
|
_print_section("[2] Trader — trader_investment_plan", trader_plan)
|
||||||
|
|
||||||
|
# 3) Portfolio Manager (consumes both)
|
||||||
|
pm = create_portfolio_manager(deep_llm)
|
||||||
|
pm_result = pm(_make_pm_state(investment_plan, trader_plan))
|
||||||
|
final_decision = pm_result["final_trade_decision"]
|
||||||
|
_print_section("[3] Portfolio Manager — final_trade_decision", final_decision)
|
||||||
|
|
||||||
|
# 4) SignalProcessor extracts the rating with zero LLM calls.
|
||||||
|
sp = SignalProcessor()
|
||||||
|
rating = sp.process_signal(final_decision)
|
||||||
|
_print_section("[4] SignalProcessor → rating", rating)
|
||||||
|
|
||||||
|
# 5) Lightweight checks: each rendered output should carry the expected
|
||||||
|
# section headers so downstream consumers (memory log, CLI display,
|
||||||
|
# saved reports) keep working.
|
||||||
|
checks = [
|
||||||
|
("Research Manager", investment_plan, ["**Recommendation**:"]),
|
||||||
|
("Trader", trader_plan, ["**Action**:", "FINAL TRANSACTION PROPOSAL:"]),
|
||||||
|
("Portfolio Manager", final_decision, ["**Rating**:", "**Executive Summary**:", "**Investment Thesis**:"]),
|
||||||
|
]
|
||||||
|
print("\n" + "=" * 70 + "\nStructure checks\n" + "=" * 70)
|
||||||
|
failures = 0
|
||||||
|
for name, text, required in checks:
|
||||||
|
for marker in required:
|
||||||
|
ok = marker in text
|
||||||
|
print(f" {'PASS' if ok else 'FAIL'} {name}: contains {marker!r}")
|
||||||
|
failures += int(not ok)
|
||||||
|
|
||||||
|
print()
|
||||||
|
if failures:
|
||||||
|
print(f"Smoke FAILED: {failures} structure check(s) missing.")
|
||||||
|
return 1
|
||||||
|
print("Smoke PASSED: structured output → rendered markdown chain works for", args.provider)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
42
tests/conftest.py
Normal file
42
tests/conftest.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Shared pytest fixtures that prevent CI hangs when API keys are absent."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
for marker in ("unit", "integration", "smoke"):
|
||||||
|
config.addinivalue_line("markers", f"{marker}: {marker}-level tests")
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_ENV_VARS = (
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"GOOGLE_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"XAI_API_KEY",
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"DASHSCOPE_API_KEY",
|
||||||
|
"ZHIPU_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"AZURE_OPENAI_API_KEY",
|
||||||
|
"ALPHA_VANTAGE_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _dummy_api_keys(monkeypatch):
|
||||||
|
for env_var in _API_KEY_ENV_VARS:
|
||||||
|
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_llm_client():
|
||||||
|
client = MagicMock()
|
||||||
|
client.get_llm.return_value = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"tradingagents.llm_clients.factory.create_llm_client",
|
||||||
|
return_value=client,
|
||||||
|
):
|
||||||
|
yield client
|
||||||
147
tests/test_checkpoint_resume.py
Normal file
147
tests/test_checkpoint_resume.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
from langgraph.graph import END, StateGraph
|
||||||
|
|
||||||
|
from tradingagents.graph.checkpointer import (
|
||||||
|
checkpoint_step,
|
||||||
|
clear_checkpoint,
|
||||||
|
get_checkpointer,
|
||||||
|
has_checkpoint,
|
||||||
|
thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mutable flag to simulate crash on first run
|
||||||
|
_should_crash = False
|
||||||
|
|
||||||
|
|
||||||
|
class _SimpleState(TypedDict):
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
def _node_a(state: _SimpleState) -> dict:
|
||||||
|
return {"count": state["count"] + 1}
|
||||||
|
|
||||||
|
|
||||||
|
def _node_b(state: _SimpleState) -> dict:
|
||||||
|
if _should_crash:
|
||||||
|
raise RuntimeError("simulated mid-analysis crash")
|
||||||
|
return {"count": state["count"] + 10}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_graph() -> StateGraph:
|
||||||
|
builder = StateGraph(_SimpleState)
|
||||||
|
builder.add_node("analyst", _node_a)
|
||||||
|
builder.add_node("trader", _node_b)
|
||||||
|
builder.set_entry_point("analyst")
|
||||||
|
builder.add_edge("analyst", "trader")
|
||||||
|
builder.add_edge("trader", END)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointResume(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdir = tempfile.mkdtemp()
|
||||||
|
self.ticker = "TEST"
|
||||||
|
self.date = "2026-04-20"
|
||||||
|
|
||||||
|
def test_crash_and_resume(self):
|
||||||
|
"""Crash at 'trader' node, then resume from checkpoint."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
tid = thread_id(self.ticker, self.date)
|
||||||
|
cfg = {"configurable": {"thread_id": tid}}
|
||||||
|
|
||||||
|
# Run 1: crash at trader node
|
||||||
|
_should_crash = True
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
# Checkpoint should exist at step 1 (analyst completed)
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
|
||||||
|
self.assertEqual(step, 1)
|
||||||
|
|
||||||
|
# Run 2: resume — trader succeeds this time
|
||||||
|
_should_crash = False
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke(None, config=cfg)
|
||||||
|
|
||||||
|
# analyst added 1, trader added 10 → 11
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
def test_clear_checkpoint_allows_fresh_start(self):
|
||||||
|
"""After clearing, the graph starts from scratch."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
tid = thread_id(self.ticker, self.date)
|
||||||
|
cfg = {"configurable": {"thread_id": tid}}
|
||||||
|
|
||||||
|
# Create a checkpoint by crashing
|
||||||
|
_should_crash = True
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# Clear it
|
||||||
|
clear_checkpoint(self.tmpdir, self.ticker, self.date)
|
||||||
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# Fresh run succeeds from scratch
|
||||||
|
_should_crash = False
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_date_starts_fresh(self):
|
||||||
|
"""A different date must NOT resume from an existing checkpoint."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
date2 = "2026-04-21"
|
||||||
|
|
||||||
|
# Run with date1 — crash to leave a checkpoint
|
||||||
|
_should_crash = True
|
||||||
|
tid1 = thread_id(self.ticker, self.date)
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
|
||||||
|
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# date2 should have no checkpoint
|
||||||
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
|
||||||
|
|
||||||
|
# Run with date2 — should start fresh and succeed
|
||||||
|
_should_crash = False
|
||||||
|
tid2 = thread_id(self.ticker, date2)
|
||||||
|
self.assertNotEqual(tid1, tid2)
|
||||||
|
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
|
||||||
|
|
||||||
|
# Fresh run: analyst +1, trader +10 = 11
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
# Original date checkpoint still exists (untouched)
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
169
tests/test_deepseek_reasoning.py
Normal file
169
tests/test_deepseek_reasoning.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
|
||||||
|
|
||||||
|
Two pieces verified:
|
||||||
|
|
||||||
|
1. ``reasoning_content`` is captured on receive into the AIMessage's
|
||||||
|
``additional_kwargs`` and re-attached on send so DeepSeek's API
|
||||||
|
sees the same value across turns.
|
||||||
|
2. ``with_structured_output`` raises NotImplementedError for
|
||||||
|
``deepseek-reasoner`` so the agent factories' free-text fallback
|
||||||
|
handles the request instead of failing at runtime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.openai_client import (
|
||||||
|
DeepSeekChatOpenAI,
|
||||||
|
NormalizedChatOpenAI,
|
||||||
|
_input_to_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||||
|
# (Gemini bot review note: non-list inputs must also work)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestInputToMessages:
|
||||||
|
def test_list_input_returned_as_is(self):
|
||||||
|
msgs = [HumanMessage(content="hi")]
|
||||||
|
assert _input_to_messages(msgs) is msgs
|
||||||
|
|
||||||
|
def test_chat_prompt_value_unwrapped(self):
|
||||||
|
msgs = [HumanMessage(content="hi")]
|
||||||
|
prompt_value = ChatPromptValue(messages=msgs)
|
||||||
|
assert _input_to_messages(prompt_value) == msgs
|
||||||
|
|
||||||
|
def test_string_input_yields_empty_list(self):
|
||||||
|
# A bare string isn't a message-bearing input; the caller's normal
|
||||||
|
# langchain conversion happens upstream of _get_request_payload.
|
||||||
|
assert _input_to_messages("hello") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reasoning content propagation across turns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestDeepSeekReasoningContent:
|
||||||
|
def _client(self):
|
||||||
|
os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
|
||||||
|
return DeepSeekChatOpenAI(
|
||||||
|
model="deepseek-v4-flash",
|
||||||
|
api_key="placeholder",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_capture_on_receive(self):
|
||||||
|
"""When the response carries reasoning_content, it lands on the
|
||||||
|
AIMessage's additional_kwargs so the next turn can echo it back."""
|
||||||
|
client = self._client()
|
||||||
|
result = client._create_chat_result(
|
||||||
|
{
|
||||||
|
"model": "deepseek-v4-flash",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Plan: buy NVDA.",
|
||||||
|
"reasoning_content": "Step 1: trend is up. Step 2: ...",
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ai = result.generations[0].message
|
||||||
|
assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."
|
||||||
|
|
||||||
|
def test_propagate_on_send(self):
|
||||||
|
"""When an outgoing AIMessage carries reasoning_content, the request
|
||||||
|
payload echoes it on the corresponding message dict."""
|
||||||
|
client = self._client()
|
||||||
|
prior = AIMessage(
|
||||||
|
content="Plan",
|
||||||
|
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||||
|
)
|
||||||
|
new_user = HumanMessage(content="Refine.")
|
||||||
|
payload = client._get_request_payload([prior, new_user])
|
||||||
|
# Find the assistant message in the payload
|
||||||
|
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||||
|
assert assistant_dicts, "assistant message missing from outgoing payload"
|
||||||
|
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||||
|
|
||||||
|
def test_propagate_through_chat_prompt_value(self):
|
||||||
|
"""Gemini bot review note: non-list inputs (ChatPromptValue) must
|
||||||
|
also propagate reasoning_content."""
|
||||||
|
client = self._client()
|
||||||
|
prior = AIMessage(
|
||||||
|
content="Plan",
|
||||||
|
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||||
|
)
|
||||||
|
prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
|
||||||
|
payload = client._get_request_payload(prompt_value)
|
||||||
|
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||||
|
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# deepseek-reasoner: structured output unavailable, falls through to free-text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestDeepSeekReasonerStructuredOutput:
|
||||||
|
def test_with_structured_output_raises_for_reasoner(self):
|
||||||
|
client = DeepSeekChatOpenAI(
|
||||||
|
model="deepseek-reasoner",
|
||||||
|
api_key="placeholder",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class _Sample(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
client.with_structured_output(_Sample)
|
||||||
|
|
||||||
|
def test_with_structured_output_works_for_v4(self):
|
||||||
|
"""V4 models (non-reasoner) accept tool_choice; structured output works."""
|
||||||
|
client = DeepSeekChatOpenAI(
|
||||||
|
model="deepseek-v4-flash",
|
||||||
|
api_key="placeholder",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class _Sample(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
# Should return a Runnable, not raise. (The actual API call would
|
||||||
|
# require a real key; we only assert binding succeeds.)
|
||||||
|
wrapped = client.with_structured_output(_Sample)
|
||||||
|
assert wrapped is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestBaseClassIsolation:
|
||||||
|
def test_normalized_does_not_propagate_reasoning_content(self):
|
||||||
|
"""The general-purpose NormalizedChatOpenAI must not carry
|
||||||
|
DeepSeek-specific behaviour. Only the subclass does."""
|
||||||
|
assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
|
||||||
|
NormalizedChatOpenAI._get_request_payload
|
||||||
|
is NormalizedChatOpenAI.__bases__[0]._get_request_payload
|
||||||
|
)
|
||||||
31
tests/test_google_api_key.py
Normal file
31
tests/test_google_api_key.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.google_client import GoogleClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
||||||
|
"""Verify GoogleClient accepts unified api_key parameter."""
|
||||||
|
|
||||||
|
@patch("tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI")
|
||||||
|
def test_api_key_handling(self, mock_chat):
|
||||||
|
test_cases = [
|
||||||
|
("unified api_key is mapped", {"api_key": "test-key-123"}, "test-key-123"),
|
||||||
|
("legacy google_api_key still works", {"google_api_key": "legacy-key-456"}, "legacy-key-456"),
|
||||||
|
("unified api_key takes precedence", {"api_key": "unified", "google_api_key": "legacy"}, "unified"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for msg, kwargs, expected_key in test_cases:
|
||||||
|
with self.subTest(msg=msg):
|
||||||
|
mock_chat.reset_mock()
|
||||||
|
client = GoogleClient("gemini-2.5-flash", **kwargs)
|
||||||
|
client.get_llm()
|
||||||
|
call_kwargs = mock_chat.call_args[1]
|
||||||
|
self.assertEqual(call_kwargs.get("google_api_key"), expected_key)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
773
tests/test_memory_log.py
Normal file
773
tests/test_memory_log.py
Normal file
@@ -0,0 +1,773 @@
|
|||||||
|
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||||
|
from tradingagents.graph.reflection import Reflector
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
from tradingagents.graph.propagation import Propagator
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
|
|
||||||
|
_SEP = TradingMemoryLog._SEPARATOR
|
||||||
|
|
||||||
|
DECISION_BUY = "Rating: Buy\nEnter at $189-192, 6% portfolio cap."
|
||||||
|
DECISION_OVERWEIGHT = (
|
||||||
|
"Rating: Overweight\n"
|
||||||
|
"Executive Summary: Moderate position, await confirmation.\n"
|
||||||
|
"Investment Thesis: Strong fundamentals but near-term headwinds."
|
||||||
|
)
|
||||||
|
DECISION_SELL = "Rating: Sell\nExit position immediately."
|
||||||
|
DECISION_NO_RATING = (
|
||||||
|
"Executive Summary: Complex situation with multiple competing factors.\n"
|
||||||
|
"Investment Thesis: No clear directional signal at this time."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_log(tmp_path, filename="trading_memory.md"):
|
||||||
|
config = {"memory_log_path": str(tmp_path / filename)}
|
||||||
|
return TradingMemoryLog(config)
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_completed(tmp_path, ticker, date, decision_text, reflection_text, filename="trading_memory.md"):
|
||||||
|
"""Write a completed entry directly to file, bypassing the API."""
|
||||||
|
entry = (
|
||||||
|
f"[{date} | {ticker} | Buy | +1.0% | +0.5% | 5d]\n\n"
|
||||||
|
f"DECISION:\n{decision_text}\n\n"
|
||||||
|
f"REFLECTION:\n{reflection_text}"
|
||||||
|
+ _SEP
|
||||||
|
)
|
||||||
|
with open(tmp_path / filename, "a", encoding="utf-8") as f:
|
||||||
|
f.write(entry)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_entry(log, ticker, date, decision, reflection="Good call."):
|
||||||
|
"""Store a decision then immediately resolve it via the API."""
|
||||||
|
log.store_decision(ticker, date, decision)
|
||||||
|
log.update_with_outcome(ticker, date, 0.05, 0.02, 5, reflection)
|
||||||
|
|
||||||
|
|
||||||
|
def _price_df(prices):
|
||||||
|
"""Minimal DataFrame matching yfinance .history() output shape."""
|
||||||
|
return pd.DataFrame({"Close": prices})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pm_state(past_context=""):
|
||||||
|
"""Minimal AgentState dict for portfolio_manager_node."""
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"past_context": past_context,
|
||||||
|
"risk_debate_state": {
|
||||||
|
"history": "Risk debate history.",
|
||||||
|
"aggressive_history": "",
|
||||||
|
"conservative_history": "",
|
||||||
|
"neutral_history": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"current_aggressive_response": "",
|
||||||
|
"current_conservative_response": "",
|
||||||
|
"current_neutral_response": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
"market_report": "Market report.",
|
||||||
|
"sentiment_report": "Sentiment report.",
|
||||||
|
"news_report": "News report.",
|
||||||
|
"fundamentals_report": "Fundamentals report.",
|
||||||
|
"investment_plan": "Research plan.",
|
||||||
|
"trader_investment_plan": "Trader plan.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_pm_llm(captured: dict, decision: PortfolioDecision | None = None):
|
||||||
|
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||||
|
prompt and returns a real PortfolioDecision (so render_pm_decision works).
|
||||||
|
"""
|
||||||
|
if decision is None:
|
||||||
|
decision = PortfolioDecision(
|
||||||
|
rating=PortfolioRating.HOLD,
|
||||||
|
executive_summary="Hold the position; await catalyst.",
|
||||||
|
investment_thesis="Balanced view; neither side carried the debate.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or decision
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core: storage and read path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTradingMemoryLogCore:
|
||||||
|
|
||||||
|
def test_store_creates_file(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert not (tmp_path / "trading_memory.md").exists()
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert (tmp_path / "trading_memory.md").exists()
|
||||||
|
|
||||||
|
def test_store_appends_not_overwrites(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 2
|
||||||
|
assert entries[0]["ticker"] == "NVDA"
|
||||||
|
assert entries[1]["ticker"] == "AAPL"
|
||||||
|
|
||||||
|
def test_store_decision_idempotent(self, tmp_path):
|
||||||
|
"""Calling store_decision twice with same (ticker, date) stores only one entry."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert len(log.load_entries()) == 1
|
||||||
|
|
||||||
|
def test_batch_update_resolves_multiple_entries(self, tmp_path):
|
||||||
|
"""batch_update_with_outcomes resolves multiple pending entries in one write."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-01-12", DECISION_SELL)
|
||||||
|
|
||||||
|
updates = [
|
||||||
|
{"ticker": "NVDA", "trade_date": "2026-01-05",
|
||||||
|
"raw_return": 0.05, "alpha_return": 0.02, "holding_days": 5,
|
||||||
|
"reflection": "First correct."},
|
||||||
|
{"ticker": "NVDA", "trade_date": "2026-01-12",
|
||||||
|
"raw_return": -0.03, "alpha_return": -0.01, "holding_days": 5,
|
||||||
|
"reflection": "Second correct."},
|
||||||
|
]
|
||||||
|
log.batch_update_with_outcomes(updates)
|
||||||
|
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 2
|
||||||
|
assert all(not e["pending"] for e in entries)
|
||||||
|
assert entries[0]["reflection"] == "First correct."
|
||||||
|
assert entries[1]["reflection"] == "Second correct."
|
||||||
|
|
||||||
|
def test_pending_tag_format(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | pending]" in text
|
||||||
|
|
||||||
|
# Rating parsing
|
||||||
|
|
||||||
|
def test_rating_parsed_buy(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
def test_rating_parsed_overweight(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Overweight"
|
||||||
|
|
||||||
|
def test_rating_fallback_hold(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Hold"
|
||||||
|
|
||||||
|
def test_rating_priority_over_prose(self, tmp_path):
|
||||||
|
"""'Rating: X' label wins even when an opposing rating word appears earlier in prose."""
|
||||||
|
decision = (
|
||||||
|
"The sell thesis is weak. The hold case is marginal.\n\n"
|
||||||
|
"Rating: Buy\n\n"
|
||||||
|
"Executive Summary: Strong fundamentals support the position."
|
||||||
|
)
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
# Delimiter robustness
|
||||||
|
|
||||||
|
def test_decision_with_markdown_separator(self, tmp_path):
|
||||||
|
"""LLM decision containing '---' must not corrupt the entry."""
|
||||||
|
decision = "Rating: Buy\n\n---\n\nRisk: elevated volatility."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert "Risk: elevated volatility" in entries[0]["decision"]
|
||||||
|
|
||||||
|
# load_entries
|
||||||
|
|
||||||
|
def test_load_entries_empty_file(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert log.load_entries() == []
|
||||||
|
|
||||||
|
def test_load_entries_single(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["date"] == "2026-01-10"
|
||||||
|
assert e["ticker"] == "NVDA"
|
||||||
|
assert e["rating"] == "Buy"
|
||||||
|
assert e["pending"] is True
|
||||||
|
assert e["raw"] is None
|
||||||
|
|
||||||
|
def test_load_entries_multiple(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
assert [e["ticker"] for e in entries] == ["NVDA", "AAPL", "MSFT"]
|
||||||
|
|
||||||
|
def test_decision_content_preserved(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries()[0]["decision"] == DECISION_BUY.strip()
|
||||||
|
|
||||||
|
# get_pending_entries
|
||||||
|
|
||||||
|
def test_get_pending_returns_pending_only(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA.", "Correct.")
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
pending = log.get_pending_entries()
|
||||||
|
assert len(pending) == 1
|
||||||
|
assert pending[0]["ticker"] == "NVDA"
|
||||||
|
assert pending[0]["date"] == "2026-01-10"
|
||||||
|
|
||||||
|
# get_past_context
|
||||||
|
|
||||||
|
def test_get_past_context_empty(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
def test_get_past_context_pending_excluded(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
def test_get_past_context_same_ticker(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA — AI capex thesis intact.", "Directionally correct.")
|
||||||
|
ctx = log.get_past_context("NVDA")
|
||||||
|
assert "Past analyses of NVDA" in ctx
|
||||||
|
assert "Buy NVDA" in ctx
|
||||||
|
|
||||||
|
def test_get_past_context_cross_ticker(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "AAPL", "2026-01-05", "Buy AAPL — Services growth.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA")
|
||||||
|
assert "Recent cross-ticker lessons" in ctx
|
||||||
|
assert "Past analyses of NVDA" not in ctx
|
||||||
|
|
||||||
|
def test_n_same_limit_respected(self, tmp_path):
|
||||||
|
"""Only the n_same most recent same-ticker entries are included."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(6):
|
||||||
|
_seed_completed(tmp_path, "NVDA", f"2026-01-{i+1:02d}", f"Buy entry {i}.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA", n_same=5)
|
||||||
|
assert "Buy entry 0" not in ctx
|
||||||
|
assert "Buy entry 5" in ctx
|
||||||
|
|
||||||
|
def test_n_cross_limit_respected(self, tmp_path):
|
||||||
|
"""Only the n_cross most recent cross-ticker entries are included."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i, ticker in enumerate(["AAPL", "MSFT", "GOOG", "META"]):
|
||||||
|
_seed_completed(tmp_path, ticker, f"2026-01-{i+1:02d}", f"Buy {ticker}.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA", n_cross=3)
|
||||||
|
assert "AAPL" not in ctx
|
||||||
|
assert "META" in ctx
|
||||||
|
|
||||||
|
# No-op when config is None
|
||||||
|
|
||||||
|
def test_no_log_path_is_noop(self):
|
||||||
|
log = TradingMemoryLog(config=None)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries() == []
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
# Rotation: opt-in cap on resolved entries
|
||||||
|
|
||||||
|
def test_rotation_disabled_by_default(self, tmp_path):
|
||||||
|
"""Without max_entries, all resolved entries are kept."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(7):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
assert len(log.load_entries()) == 7
|
||||||
|
|
||||||
|
def test_rotation_prunes_oldest_resolved(self, tmp_path):
|
||||||
|
"""When max_entries is set and exceeded, oldest resolved entries are pruned."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 3,
|
||||||
|
})
|
||||||
|
# Resolve 5 entries; rotation should keep only the 3 most recent.
|
||||||
|
for i in range(5):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
# Confirm the OLDEST were dropped, not the newest.
|
||||||
|
dates = [e["date"] for e in entries]
|
||||||
|
assert dates == ["2026-01-03", "2026-01-04", "2026-01-05"]
|
||||||
|
|
||||||
|
def test_rotation_never_prunes_pending(self, tmp_path):
|
||||||
|
"""Pending entries (unresolved) are kept regardless of the cap."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 2,
|
||||||
|
})
|
||||||
|
# 3 resolved + 2 pending. With cap=2, only 2 resolved survive; both pending stay.
|
||||||
|
for i in range(3):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Resolved {i}.")
|
||||||
|
log.store_decision("NVDA", "2026-02-01", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-02-02", DECISION_OVERWEIGHT)
|
||||||
|
# Trigger rotation by resolving one more entry — pending entries must stay.
|
||||||
|
_resolve_entry(log, "NVDA", "2026-01-04", DECISION_BUY, "Resolved 3.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
pending = [e for e in entries if e["pending"]]
|
||||||
|
resolved = [e for e in entries if not e["pending"]]
|
||||||
|
assert len(pending) == 2, "pending entries must never be pruned"
|
||||||
|
assert len(resolved) == 2, f"expected 2 resolved after rotation, got {len(resolved)}"
|
||||||
|
|
||||||
|
def test_rotation_under_cap_is_noop(self, tmp_path):
|
||||||
|
"""No rotation when resolved count <= max_entries."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 10,
|
||||||
|
})
|
||||||
|
for i in range(3):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
assert len(log.load_entries()) == 3
|
||||||
|
|
||||||
|
# Rating parsing: markdown bold and numbered list formats
|
||||||
|
|
||||||
|
def test_rating_parsed_from_bold_markdown(self, tmp_path):
|
||||||
|
"""**Rating**: Buy — markdown bold around the label must not prevent parsing."""
|
||||||
|
decision = "**Rating**: Buy\nEnter at $190."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
def test_rating_parsed_from_bold_value(self, tmp_path):
|
||||||
|
"""Rating: **Sell** — markdown bold around the value must not prevent parsing."""
|
||||||
|
decision = "Rating: **Sell**\nExit immediately."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Sell"
|
||||||
|
|
||||||
|
def test_rating_label_wins_over_prose_with_markdown(self, tmp_path):
|
||||||
|
"""Rating: **Sell** must win even when prose contains a conflicting rating word."""
|
||||||
|
decision = (
|
||||||
|
"The buy thesis is weakened by guidance.\n"
|
||||||
|
"Rating: **Sell**\n"
|
||||||
|
"Exit before earnings."
|
||||||
|
)
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Sell"
|
||||||
|
|
||||||
|
def test_rating_parsed_from_numbered_list(self, tmp_path):
|
||||||
|
"""1. Rating: Buy — numbered list prefix must not prevent parsing."""
|
||||||
|
decision = "1. Rating: Buy\nEnter at $190."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deferred reflection: update_with_outcome, Reflector, _fetch_returns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDeferredReflection:
|
||||||
|
|
||||||
|
# update_with_outcome
|
||||||
|
|
||||||
|
def test_update_replaces_pending_tag(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | pending]" not in text
|
||||||
|
assert "+4.2%" in text
|
||||||
|
assert "+2.1%" in text
|
||||||
|
assert "5d" in text
|
||||||
|
|
||||||
|
def test_update_appends_reflection(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["pending"] is False
|
||||||
|
assert e["reflection"] == "Momentum confirmed."
|
||||||
|
assert e["decision"] == DECISION_BUY.strip()
|
||||||
|
|
||||||
|
def test_update_preserves_other_entries(self, tmp_path):
|
||||||
|
"""Only the matching entry is modified; all other entries remain unchanged."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", "Rating: Hold\nHold AAPL.")
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_SELL)
|
||||||
|
log.update_with_outcome("AAPL", "2026-01-11", 0.01, -0.01, 5, "Neutral result.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
nvda, aapl, msft = entries
|
||||||
|
assert nvda["ticker"] == "NVDA" and nvda["pending"] is True
|
||||||
|
assert aapl["ticker"] == "AAPL" and aapl["pending"] is False
|
||||||
|
assert aapl["reflection"] == "Neutral result."
|
||||||
|
assert msft["ticker"] == "MSFT" and msft["pending"] is True
|
||||||
|
|
||||||
|
def test_update_atomic_write(self, tmp_path):
|
||||||
|
"""A pre-existing .tmp file is overwritten; the log is correctly updated."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
stale_tmp = tmp_path / "trading_memory.tmp"
|
||||||
|
stale_tmp.write_text("GARBAGE CONTENT — should be overwritten", encoding="utf-8")
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Correct.")
|
||||||
|
assert not stale_tmp.exists()
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["reflection"] == "Correct."
|
||||||
|
assert entries[0]["pending"] is False
|
||||||
|
|
||||||
|
def test_update_noop_when_no_log_path(self):
|
||||||
|
log = TradingMemoryLog(config=None)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.05, 0.02, 5, "Reflection")
|
||||||
|
|
||||||
|
def test_formatting_roundtrip_after_update(self, tmp_path):
|
||||||
|
"""All fields intact and blank line between tag and DECISION preserved after update."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["pending"] is False
|
||||||
|
assert e["decision"] == DECISION_BUY.strip()
|
||||||
|
assert e["reflection"] == "Momentum confirmed."
|
||||||
|
assert e["raw"] == "+4.2%"
|
||||||
|
assert e["alpha"] == "+2.1%"
|
||||||
|
assert e["holding"] == "5d"
|
||||||
|
raw_text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | +4.2% | +2.1% | 5d]\n\nDECISION:" in raw_text
|
||||||
|
|
||||||
|
# Reflector.reflect_on_final_decision
|
||||||
|
|
||||||
|
def test_reflect_on_final_decision_returns_llm_output(self):
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "Directionally correct. Thesis confirmed."
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
result = reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_BUY, raw_return=0.042, alpha_return=0.021
|
||||||
|
)
|
||||||
|
assert result == "Directionally correct. Thesis confirmed."
|
||||||
|
mock_llm.invoke.assert_called_once()
|
||||||
|
|
||||||
|
def test_reflect_on_final_decision_includes_returns_in_prompt(self):
|
||||||
|
"""Return figures are present in the human message sent to the LLM."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "Incorrect call."
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_SELL, raw_return=-0.08, alpha_return=-0.05
|
||||||
|
)
|
||||||
|
messages = mock_llm.invoke.call_args[0][0]
|
||||||
|
human_content = next(content for role, content in messages if role == "human")
|
||||||
|
assert "-8.0%" in human_content
|
||||||
|
assert "-5.0%" in human_content
|
||||||
|
assert "Exit position immediately." in human_content
|
||||||
|
|
||||||
|
# TradingAgentsGraph._fetch_returns
|
||||||
|
|
||||||
|
def test_fetch_returns_valid_ticker(self):
|
||||||
|
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||||
|
spy_prices = [400.0, 402.0, 404.0, 403.0, 405.0, 406.0]
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
def _make_ticker(sym):
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||||
|
return m
|
||||||
|
mock_ticker_cls.side_effect = _make_ticker
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||||
|
assert raw is not None and alpha is not None and days is not None
|
||||||
|
assert isinstance(raw, float) and isinstance(alpha, float) and isinstance(days, int)
|
||||||
|
assert days == 5
|
||||||
|
|
||||||
|
def test_fetch_returns_too_recent(self):
|
||||||
|
"""Only 1 data point available → returns (None, None, None), no crash."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df([100.0])
|
||||||
|
mock_ticker_cls.return_value = m
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-04-19")
|
||||||
|
assert raw is None and alpha is None and days is None
|
||||||
|
|
||||||
|
def test_fetch_returns_delisted(self):
|
||||||
|
"""Empty DataFrame → returns (None, None, None), no crash."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = pd.DataFrame({"Close": []})
|
||||||
|
mock_ticker_cls.return_value = m
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "XXXXXFAKE", "2026-01-10")
|
||||||
|
assert raw is None and alpha is None and days is None
|
||||||
|
|
||||||
|
def test_fetch_returns_spy_shorter_than_stock(self):
|
||||||
|
"""SPY having fewer rows than the stock must not raise IndexError."""
|
||||||
|
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||||
|
spy_prices = [400.0, 402.0, 403.0]
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
def _make_ticker(sym):
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||||
|
return m
|
||||||
|
mock_ticker_cls.side_effect = _make_ticker
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||||
|
assert raw is not None and alpha is not None and days is not None
|
||||||
|
assert days == 2
|
||||||
|
|
||||||
|
# TradingAgentsGraph._resolve_pending_entries
|
||||||
|
|
||||||
|
def test_resolve_skips_other_tickers(self, tmp_path):
|
||||||
|
"""Pending AAPL entry is not resolved when the run is for NVDA."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("AAPL", "2026-01-10", DECISION_BUY)
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.memory_log = log
|
||||||
|
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||||
|
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||||
|
mock_graph._fetch_returns.assert_not_called()
|
||||||
|
assert len(log.get_pending_entries()) == 1
|
||||||
|
|
||||||
|
def test_resolve_marks_entry_completed(self, tmp_path):
|
||||||
|
"""After resolve, get_pending_entries() is empty and the entry has a REFLECTION."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
mock_reflector = MagicMock()
|
||||||
|
mock_reflector.reflect_on_final_decision.return_value = "Momentum confirmed."
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.memory_log = log
|
||||||
|
mock_graph.reflector = mock_reflector
|
||||||
|
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||||
|
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||||
|
assert log.get_pending_entries() == []
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["pending"] is False
|
||||||
|
assert entries[0]["reflection"] == "Momentum confirmed."
|
||||||
|
assert "+5.0%" in entries[0]["raw"]
|
||||||
|
assert "+2.0%" in entries[0]["alpha"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Portfolio Manager injection: past_context in state and prompt
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPortfolioManagerInjection:
|
||||||
|
|
||||||
|
# past_context in initial state
|
||||||
|
|
||||||
|
def test_past_context_in_initial_state(self):
|
||||||
|
propagator = Propagator()
|
||||||
|
state = propagator.create_initial_state("NVDA", "2026-01-10", past_context="some context")
|
||||||
|
assert "past_context" in state
|
||||||
|
assert state["past_context"] == "some context"
|
||||||
|
|
||||||
|
def test_past_context_defaults_to_empty(self):
|
||||||
|
propagator = Propagator()
|
||||||
|
state = propagator.create_initial_state("NVDA", "2026-01-10")
|
||||||
|
assert state["past_context"] == ""
|
||||||
|
|
||||||
|
# PM prompt
|
||||||
|
|
||||||
|
def test_pm_prompt_includes_past_context(self):
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_pm_llm(captured)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
state = _make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call.")
|
||||||
|
pm_node(state)
|
||||||
|
assert "Lessons from prior decisions and outcomes" in captured["prompt"]
|
||||||
|
assert "Great call." in captured["prompt"]
|
||||||
|
|
||||||
|
def test_pm_no_past_context_no_section(self):
|
||||||
|
"""PM prompt omits the lessons section entirely when past_context is empty."""
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_pm_llm(captured)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
state = _make_pm_state(past_context="")
|
||||||
|
pm_node(state)
|
||||||
|
assert "Lessons from prior decisions" not in captured["prompt"]
|
||||||
|
|
||||||
|
def test_pm_returns_rendered_markdown_with_rating(self):
|
||||||
|
"""The structured PortfolioDecision is rendered to markdown that
|
||||||
|
downstream consumers (memory log, signal processor, CLI display)
|
||||||
|
can parse without any extra LLM call."""
|
||||||
|
captured = {}
|
||||||
|
decision = PortfolioDecision(
|
||||||
|
rating=PortfolioRating.OVERWEIGHT,
|
||||||
|
executive_summary="Build position gradually over the next two weeks.",
|
||||||
|
investment_thesis="AI capex cycle remains intact; institutional flows constructive.",
|
||||||
|
price_target=215.0,
|
||||||
|
time_horizon="3-6 months",
|
||||||
|
)
|
||||||
|
llm = _structured_pm_llm(captured, decision)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
result = pm_node(_make_pm_state())
|
||||||
|
md = result["final_trade_decision"]
|
||||||
|
assert "**Rating**: Overweight" in md
|
||||||
|
assert "**Executive Summary**: Build position gradually" in md
|
||||||
|
assert "**Investment Thesis**: AI capex cycle" in md
|
||||||
|
assert "**Price Target**: 215.0" in md
|
||||||
|
assert "**Time Horizon**: 3-6 months" in md
|
||||||
|
|
||||||
|
def test_pm_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
"""If a provider does not support with_structured_output, the agent
|
||||||
|
falls back to a plain invoke and returns whatever prose the model
|
||||||
|
produced, so the pipeline never blocks."""
|
||||||
|
plain_response = "**Rating**: Sell\n\nExit ahead of guidance."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
result = pm_node(_make_pm_state())
|
||||||
|
assert result["final_trade_decision"] == plain_response
|
||||||
|
|
||||||
|
# get_past_context ordering and limits
|
||||||
|
|
||||||
|
def test_same_ticker_prioritised(self, tmp_path):
|
||||||
|
"""Same-ticker entries in same-ticker section; cross-ticker entries in cross-ticker section."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_resolve_entry(log, "NVDA", "2026-01-05", DECISION_BUY, "Momentum confirmed.")
|
||||||
|
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued.")
|
||||||
|
result = log.get_past_context("NVDA")
|
||||||
|
assert "Past analyses of NVDA" in result
|
||||||
|
assert "Recent cross-ticker lessons" in result
|
||||||
|
same_block, cross_block = result.split("Recent cross-ticker lessons")
|
||||||
|
assert "NVDA" in same_block
|
||||||
|
assert "AAPL" in cross_block
|
||||||
|
|
||||||
|
def test_cross_ticker_reflection_only(self, tmp_path):
|
||||||
|
"""Cross-ticker entries show only the REFLECTION text, not the full DECISION."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued correction.")
|
||||||
|
result = log.get_past_context("NVDA")
|
||||||
|
assert "Overvalued correction." in result
|
||||||
|
assert "Exit position immediately." not in result
|
||||||
|
|
||||||
|
def test_n_same_limit_respected(self, tmp_path):
|
||||||
|
"""More than 5 same-ticker completed entries → only 5 injected."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(7):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
result = log.get_past_context("NVDA", n_same=5)
|
||||||
|
lessons_present = sum(1 for i in range(7) if f"Lesson {i}." in result)
|
||||||
|
assert lessons_present == 5
|
||||||
|
|
||||||
|
def test_n_cross_limit_respected(self, tmp_path):
|
||||||
|
"""More than 3 cross-ticker completed entries → only 3 injected."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
tickers = ["AAPL", "MSFT", "TSLA", "AMZN", "GOOG"]
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
_resolve_entry(log, ticker, f"2026-01-{i+1:02d}", DECISION_BUY, f"{ticker} lesson.")
|
||||||
|
result = log.get_past_context("NVDA", n_cross=3)
|
||||||
|
cross_count = sum(result.count(f"{t} lesson.") for t in tickers)
|
||||||
|
assert cross_count == 3
|
||||||
|
|
||||||
|
# Full A→B→C integration cycle
|
||||||
|
|
||||||
|
def test_full_cycle_store_resolve_inject(self, tmp_path):
|
||||||
|
"""store pending → resolve with outcome → past_context non-empty for PM."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
assert len(log.get_pending_entries()) == 1
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-05", 0.05, 0.02, 5, "Correct call.")
|
||||||
|
assert log.get_pending_entries() == []
|
||||||
|
past_ctx = log.get_past_context("NVDA")
|
||||||
|
assert past_ctx != ""
|
||||||
|
assert "NVDA" in past_ctx
|
||||||
|
assert "Correct call." in past_ctx
|
||||||
|
assert "DECISION:" in past_ctx
|
||||||
|
assert "REFLECTION:" in past_ctx
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy removal: BM25 / FinancialSituationMemory fully gone
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLegacyRemoval:
|
||||||
|
|
||||||
|
def test_financial_situation_memory_removed(self):
|
||||||
|
"""FinancialSituationMemory must not be importable from the memory module."""
|
||||||
|
import tradingagents.agents.utils.memory as m
|
||||||
|
assert not hasattr(m, "FinancialSituationMemory")
|
||||||
|
|
||||||
|
def test_bm25_not_imported(self):
|
||||||
|
"""rank_bm25 must not be present in the memory module namespace."""
|
||||||
|
import tradingagents.agents.utils.memory as m
|
||||||
|
assert not hasattr(m, "BM25Okapi")
|
||||||
|
|
||||||
|
def test_reflect_and_remember_removed(self):
|
||||||
|
"""TradingAgentsGraph must not expose reflect_and_remember."""
|
||||||
|
assert not hasattr(TradingAgentsGraph, "reflect_and_remember")
|
||||||
|
|
||||||
|
def test_portfolio_manager_no_memory_param(self):
|
||||||
|
"""create_portfolio_manager accepts only llm; passing memory= raises TypeError."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
create_portfolio_manager(mock_llm)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
create_portfolio_manager(mock_llm, memory=MagicMock())
|
||||||
|
|
||||||
|
def test_full_pipeline_no_regression(self, tmp_path):
|
||||||
|
"""propagate() completes and stores the decision after the redesign."""
|
||||||
|
import functools
|
||||||
|
|
||||||
|
fake_state = {
|
||||||
|
"final_trade_decision": "Rating: Buy\nBuy NVDA.",
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"trade_date": "2026-01-10",
|
||||||
|
"market_report": "",
|
||||||
|
"sentiment_report": "",
|
||||||
|
"news_report": "",
|
||||||
|
"fundamentals_report": "",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"bull_history": "", "bear_history": "", "history": "",
|
||||||
|
"current_response": "", "judge_decision": "",
|
||||||
|
},
|
||||||
|
"investment_plan": "",
|
||||||
|
"trader_investment_plan": "",
|
||||||
|
"risk_debate_state": {
|
||||||
|
"aggressive_history": "", "conservative_history": "",
|
||||||
|
"neutral_history": "", "history": "", "judge_decision": "",
|
||||||
|
"current_aggressive_response": "", "current_conservative_response": "",
|
||||||
|
"current_neutral_response": "", "count": 1, "latest_speaker": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_graph.memory_log = TradingMemoryLog({"memory_log_path": str(tmp_path / "mem.md")})
|
||||||
|
mock_graph.log_states_dict = {}
|
||||||
|
mock_graph.debug = False
|
||||||
|
mock_graph.config = {"results_dir": str(tmp_path)}
|
||||||
|
mock_graph.graph.invoke.return_value = fake_state
|
||||||
|
mock_graph.propagator.create_initial_state.return_value = fake_state
|
||||||
|
mock_graph.propagator.get_graph_args.return_value = {}
|
||||||
|
mock_graph.signal_processor.process_signal.return_value = "Buy"
|
||||||
|
# Bind the real _run_graph so propagate's call to self._run_graph executes
|
||||||
|
# the actual write path instead of the auto-MagicMock.
|
||||||
|
mock_graph._run_graph = functools.partial(
|
||||||
|
TradingAgentsGraph._run_graph, mock_graph
|
||||||
|
)
|
||||||
|
TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
|
||||||
|
entries = mock_graph.memory_log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["ticker"] == "NVDA"
|
||||||
|
assert entries[0]["pending"] is True
|
||||||
55
tests/test_model_validation.py
Normal file
55
tests/test_model_validation.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLLMClient(BaseLLMClient):
|
||||||
|
def __init__(self, provider: str, model: str):
|
||||||
|
self.provider = provider
|
||||||
|
super().__init__(model)
|
||||||
|
|
||||||
|
def get_llm(self):
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
return object()
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
return validate_model(self.provider, self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class ModelValidationTests(unittest.TestCase):
|
||||||
|
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||||
|
for provider, models in get_known_models().items():
|
||||||
|
if provider in ("ollama", "openrouter"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
with self.subTest(provider=provider, model=model):
|
||||||
|
self.assertTrue(validate_model(provider, model))
|
||||||
|
|
||||||
|
def test_unknown_model_emits_warning_for_strict_provider(self):
|
||||||
|
client = DummyLLMClient("openai", "not-a-real-openai-model")
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
self.assertEqual(len(caught), 1)
|
||||||
|
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
||||||
|
self.assertIn("openai", str(caught[0].message))
|
||||||
|
|
||||||
|
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
|
||||||
|
for provider in ("openrouter", "ollama"):
|
||||||
|
client = DummyLLMClient(provider, "custom-model-name")
|
||||||
|
|
||||||
|
with self.subTest(provider=provider):
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
self.assertEqual(caught, [])
|
||||||
52
tests/test_safe_ticker_component.py
Normal file
52
tests/test_safe_ticker_component.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Tests for the ticker path-component validator that blocks directory traversal."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSafeTickerComponent(unittest.TestCase):
|
||||||
|
def test_accepts_common_ticker_formats(self):
|
||||||
|
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
|
||||||
|
self.assertEqual(safe_ticker_component(ticker), ticker)
|
||||||
|
|
||||||
|
def test_rejects_path_separators(self):
|
||||||
|
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_null_byte_and_whitespace(self):
|
||||||
|
for bad in ("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_empty_or_non_string(self):
|
||||||
|
for bad in ("", None, 123, b"AAPL"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_overlong_input(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component("A" * 33)
|
||||||
|
|
||||||
|
def test_rejects_dot_only_values(self):
|
||||||
|
# '.' and '..' pass the regex but traverse when used as a path
|
||||||
|
# component (e.g. ``Path(results_dir) / ticker / "logs"``).
|
||||||
|
for bad in (".", "..", "...", "...."):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_traversal_string_does_not_escape_join(self):
|
||||||
|
"""Sanity: sanitized values stay within base when joined."""
|
||||||
|
base = os.path.realpath("/tmp/cache")
|
||||||
|
ticker = safe_ticker_component("AAPL")
|
||||||
|
joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
|
||||||
|
self.assertTrue(joined.startswith(base + os.sep))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
90
tests/test_signal_processing.py
Normal file
90
tests/test_signal_processing.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for the shared rating heuristic and the SignalProcessor adapter.
|
||||||
|
|
||||||
|
The Portfolio Manager produces a typed PortfolioDecision via structured
|
||||||
|
output and renders it to markdown that always contains a ``**Rating**: X``
|
||||||
|
header. The deterministic heuristic in ``tradingagents.agents.utils.rating``
|
||||||
|
is therefore sufficient to extract the rating downstream — no second LLM
|
||||||
|
call is needed — and SignalProcessor is now a thin adapter that delegates
|
||||||
|
to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||||
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Heuristic parser
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestParseRating:
|
||||||
|
def test_explicit_label_buy(self):
|
||||||
|
assert parse_rating("Rating: Buy\nReasoning here.") == "Buy"
|
||||||
|
|
||||||
|
def test_explicit_label_overweight(self):
|
||||||
|
assert parse_rating("Rating: Overweight\nDetails.") == "Overweight"
|
||||||
|
|
||||||
|
def test_explicit_label_with_markdown_bold_value(self):
|
||||||
|
# Regression: Rating: **Sell** — markdown around the value.
|
||||||
|
assert parse_rating("Rating: **Sell**\nExit immediately.") == "Sell"
|
||||||
|
|
||||||
|
def test_explicit_label_with_markdown_bold_label(self):
|
||||||
|
assert parse_rating("**Rating**: Underweight\nTrim exposure.") == "Underweight"
|
||||||
|
|
||||||
|
def test_rendered_pm_markdown_shape(self):
|
||||||
|
# The exact shape produced by render_pm_decision must always parse.
|
||||||
|
text = (
|
||||||
|
"**Rating**: Buy\n\n"
|
||||||
|
"**Executive Summary**: Enter at $189-192, 6% portfolio cap.\n\n"
|
||||||
|
"**Investment Thesis**: AI capex cycle intact; institutional flows constructive."
|
||||||
|
)
|
||||||
|
assert parse_rating(text) == "Buy"
|
||||||
|
|
||||||
|
def test_explicit_label_wins_over_prose_with_markdown(self):
|
||||||
|
text = (
|
||||||
|
"The buy thesis is weakened by guidance.\n"
|
||||||
|
"Rating: **Sell**\n"
|
||||||
|
"Exit before earnings."
|
||||||
|
)
|
||||||
|
assert parse_rating(text) == "Sell"
|
||||||
|
|
||||||
|
def test_no_rating_returns_default(self):
|
||||||
|
assert parse_rating("No clear directional signal at this time.") == "Hold"
|
||||||
|
|
||||||
|
def test_no_rating_custom_default(self):
|
||||||
|
assert parse_rating("Plain prose.", default="Underweight") == "Underweight"
|
||||||
|
|
||||||
|
def test_all_five_tiers_recognised(self):
|
||||||
|
for r in RATINGS_5_TIER:
|
||||||
|
assert parse_rating(f"Rating: {r}") == r
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SignalProcessor: thin adapter over the heuristic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSignalProcessor:
|
||||||
|
def test_returns_rating_from_pm_markdown(self):
|
||||||
|
sp = SignalProcessor()
|
||||||
|
md = "**Rating**: Overweight\n\n**Executive Summary**: Build gradually."
|
||||||
|
assert sp.process_signal(md) == "Overweight"
|
||||||
|
|
||||||
|
def test_makes_no_llm_calls(self):
|
||||||
|
"""SignalProcessor must not invoke the LLM it was constructed with —
|
||||||
|
the rating is parseable from the rendered PM markdown directly."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
sp = SignalProcessor(llm)
|
||||||
|
sp.process_signal("Rating: Buy\nDetails.")
|
||||||
|
llm.invoke.assert_not_called()
|
||||||
|
llm.with_structured_output.assert_not_called()
|
||||||
|
|
||||||
|
def test_default_when_no_rating_present(self):
|
||||||
|
sp = SignalProcessor()
|
||||||
|
assert sp.process_signal("Plain prose without a recommendation.") == "Hold"
|
||||||
232
tests/test_structured_agents.py
Normal file
232
tests/test_structured_agents.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""Tests for structured-output agents (Trader and Research Manager).
|
||||||
|
|
||||||
|
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
||||||
|
(which exercises the full memory-log → PM injection cycle). This file
|
||||||
|
covers the parallel schemas, render functions, and graceful-fallback
|
||||||
|
behavior we added for the Trader and Research Manager so all three
|
||||||
|
decision-making agents share the same shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
|
from tradingagents.agents.schemas import (
|
||||||
|
PortfolioRating,
|
||||||
|
ResearchPlan,
|
||||||
|
TraderAction,
|
||||||
|
TraderProposal,
|
||||||
|
render_research_plan,
|
||||||
|
render_trader_proposal,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Render functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderTraderProposal:
|
||||||
|
def test_minimal_required_fields(self):
|
||||||
|
p = TraderProposal(action=TraderAction.HOLD, reasoning="Balanced setup; no edge.")
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "**Action**: Hold" in md
|
||||||
|
assert "**Reasoning**: Balanced setup; no edge." in md
|
||||||
|
# The trailing FINAL TRANSACTION PROPOSAL line is preserved for the
|
||||||
|
# analyst stop-signal text and any external code that greps for it.
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in md
|
||||||
|
|
||||||
|
def test_optional_fields_included_when_present(self):
|
||||||
|
p = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="Strong technicals + fundamentals.",
|
||||||
|
entry_price=189.5,
|
||||||
|
stop_loss=178.0,
|
||||||
|
position_sizing="6% of portfolio",
|
||||||
|
)
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "**Action**: Buy" in md
|
||||||
|
assert "**Entry Price**: 189.5" in md
|
||||||
|
assert "**Stop Loss**: 178.0" in md
|
||||||
|
assert "**Position Sizing**: 6% of portfolio" in md
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in md
|
||||||
|
|
||||||
|
def test_optional_fields_omitted_when_absent(self):
|
||||||
|
p = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.")
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "Entry Price" not in md
|
||||||
|
assert "Stop Loss" not in md
|
||||||
|
assert "Position Sizing" not in md
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **SELL**" in md
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderResearchPlan:
|
||||||
|
def test_required_fields(self):
|
||||||
|
p = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.OVERWEIGHT,
|
||||||
|
rationale="Bull case carried; tailwinds intact.",
|
||||||
|
strategic_actions="Build position over two weeks; cap at 5%.",
|
||||||
|
)
|
||||||
|
md = render_research_plan(p)
|
||||||
|
assert "**Recommendation**: Overweight" in md
|
||||||
|
assert "**Rationale**: Bull case carried" in md
|
||||||
|
assert "**Strategic Actions**: Build position" in md
|
||||||
|
|
||||||
|
def test_all_5_tier_ratings_render(self):
|
||||||
|
for rating in PortfolioRating:
|
||||||
|
p = ResearchPlan(
|
||||||
|
recommendation=rating,
|
||||||
|
rationale="r",
|
||||||
|
strategic_actions="s",
|
||||||
|
)
|
||||||
|
md = render_research_plan(p)
|
||||||
|
assert f"**Recommendation**: {rating.value}" in md
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trader agent: structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trader_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None):
|
||||||
|
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||||
|
prompt and returns a real TraderProposal so render_trader_proposal works.
|
||||||
|
"""
|
||||||
|
if proposal is None:
|
||||||
|
proposal = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="Strong setup.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or proposal
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTraderAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
proposal = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="AI capex cycle intact; institutional flows constructive.",
|
||||||
|
entry_price=189.5,
|
||||||
|
stop_loss=178.0,
|
||||||
|
position_sizing="6% of portfolio",
|
||||||
|
)
|
||||||
|
llm = _structured_trader_llm(captured, proposal)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
result = trader(_make_trader_state())
|
||||||
|
plan = result["trader_investment_plan"]
|
||||||
|
assert "**Action**: Buy" in plan
|
||||||
|
assert "**Entry Price**: 189.5" in plan
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in plan
|
||||||
|
# The same rendered markdown is also added to messages for downstream agents.
|
||||||
|
assert plan in result["messages"][0].content
|
||||||
|
|
||||||
|
def test_prompt_includes_investment_plan(self):
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_trader_llm(captured)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
trader(_make_trader_state())
|
||||||
|
# The investment plan is in the user message of the captured prompt.
|
||||||
|
prompt = captured["prompt"]
|
||||||
|
assert any("Proposed Investment Plan" in m["content"] for m in prompt)
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain_response = (
|
||||||
|
"**Action**: Sell\n\nGuidance cut hits margins.\n\n"
|
||||||
|
"FINAL TRANSACTION PROPOSAL: **SELL**"
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
result = trader(_make_trader_state())
|
||||||
|
assert result["trader_investment_plan"] == plain_response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Research Manager agent: structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rm_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"history": "Bull and bear arguments here.",
|
||||||
|
"bull_history": "Bull says...",
|
||||||
|
"bear_history": "Bear says...",
|
||||||
|
"current_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None):
|
||||||
|
if plan is None:
|
||||||
|
plan = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.HOLD,
|
||||||
|
rationale="Balanced view across both sides.",
|
||||||
|
strategic_actions="Hold current position; reassess after earnings.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or plan
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestResearchManagerAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
plan = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.OVERWEIGHT,
|
||||||
|
rationale="Bull case is stronger; AI tailwind intact.",
|
||||||
|
strategic_actions="Build position gradually over two weeks.",
|
||||||
|
)
|
||||||
|
llm = _structured_rm_llm(captured, plan)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
result = rm(_make_rm_state())
|
||||||
|
ip = result["investment_plan"]
|
||||||
|
assert "**Recommendation**: Overweight" in ip
|
||||||
|
assert "**Rationale**: Bull case" in ip
|
||||||
|
assert "**Strategic Actions**: Build position" in ip
|
||||||
|
|
||||||
|
def test_prompt_uses_5_tier_rating_scale(self):
|
||||||
|
"""The RM prompt must list all five tiers so the schema enum matches user expectations."""
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_rm_llm(captured)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
rm(_make_rm_state())
|
||||||
|
prompt = captured["prompt"]
|
||||||
|
for tier in ("Buy", "Overweight", "Hold", "Underweight", "Sell"):
|
||||||
|
assert f"**{tier}**" in prompt, f"missing {tier} in prompt"
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
result = rm(_make_rm_state())
|
||||||
|
assert result["investment_plan"] == plain_response
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from cli.utils import normalize_ticker_symbol
|
from cli.utils import normalize_ticker_symbol
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
class TickerSymbolHandlingTests(unittest.TestCase):
|
class TickerSymbolHandlingTests(unittest.TestCase):
|
||||||
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
||||||
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
import os
|
|
||||||
os.environ.setdefault("PYTHONUTF8", "1")
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from .utils.agent_utils import create_msg_delete
|
from .utils.agent_utils import create_msg_delete
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
from .utils.memory import FinancialSituationMemory
|
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
@@ -20,7 +19,6 @@ from .managers.portfolio_manager import create_portfolio_manager
|
|||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FinancialSituationMemory",
|
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"create_msg_delete",
|
"create_msg_delete",
|
||||||
"InvestDebateState",
|
"InvestDebateState",
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
@@ -8,6 +6,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||||||
get_fundamentals,
|
get_fundamentals,
|
||||||
get_income_statement,
|
get_income_statement,
|
||||||
get_insider_transactions,
|
get_insider_transactions,
|
||||||
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
@@ -27,7 +26,8 @@ def create_fundamentals_analyst(llm):
|
|||||||
system_message = (
|
system_message = (
|
||||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
|
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
|
||||||
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements.",
|
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
|
||||||
|
+ get_language_instruction(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
get_language_instruction,
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
@@ -47,6 +46,7 @@ Volume-Based Indicators:
|
|||||||
|
|
||||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
|
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
get_language_instruction,
|
||||||
get_news,
|
get_news,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
@@ -22,6 +21,7 @@ def create_news_analyst(llm):
|
|||||||
system_message = (
|
system_message = (
|
||||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||||
import json
|
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news
|
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +15,7 @@ def create_social_media_analyst(llm):
|
|||||||
system_message = (
|
system_message = (
|
||||||
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|||||||
@@ -1,25 +1,43 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
"""Portfolio Manager: synthesises the risk-analyst debate into the final decision.
|
||||||
|
|
||||||
|
Uses LangChain's ``with_structured_output`` so the LLM produces a typed
|
||||||
|
``PortfolioDecision`` directly, in a single call. The result is rendered
|
||||||
|
back to markdown for storage in ``final_trade_decision`` so memory log,
|
||||||
|
CLI display, and saved reports continue to consume the same shape they do
|
||||||
|
today. When a provider does not expose structured output, the agent falls
|
||||||
|
back gracefully to free-text generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_portfolio_manager(llm, memory):
|
def create_portfolio_manager(llm):
|
||||||
|
structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager")
|
||||||
|
|
||||||
def portfolio_manager_node(state) -> dict:
|
def portfolio_manager_node(state) -> dict:
|
||||||
|
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
|
|
||||||
history = state["risk_debate_state"]["history"]
|
history = state["risk_debate_state"]["history"]
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
market_research_report = state["market_report"]
|
research_plan = state["investment_plan"]
|
||||||
news_report = state["news_report"]
|
trader_plan = state["trader_investment_plan"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
trader_plan = state["investment_plan"]
|
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
past_context = state.get("past_context", "")
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
lessons_line = (
|
||||||
|
f"- Lessons from prior decisions and outcomes:\n{past_context}\n"
|
||||||
past_memory_str = ""
|
if past_context
|
||||||
for i, rec in enumerate(past_memories, 1):
|
else ""
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
)
|
||||||
|
|
||||||
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
|
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
|
||||||
|
|
||||||
@@ -35,27 +53,26 @@ def create_portfolio_manager(llm, memory):
|
|||||||
- **Sell**: Exit position or avoid entry
|
- **Sell**: Exit position or avoid entry
|
||||||
|
|
||||||
**Context:**
|
**Context:**
|
||||||
- Trader's proposed plan: **{trader_plan}**
|
- Research Manager's investment plan: **{research_plan}**
|
||||||
- Lessons from past decisions: **{past_memory_str}**
|
- Trader's transaction proposal: **{trader_plan}**
|
||||||
|
{lessons_line}
|
||||||
**Required Output Structure:**
|
|
||||||
1. **Rating**: State one of Buy / Overweight / Hold / Underweight / Sell.
|
|
||||||
2. **Executive Summary**: A concise action plan covering entry strategy, position sizing, key risk levels, and time horizon.
|
|
||||||
3. **Investment Thesis**: Detailed reasoning anchored in the analysts' debate and past reflections.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Risk Analysts Debate History:**
|
**Risk Analysts Debate History:**
|
||||||
{history}
|
{history}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
Be decisive and ground every conclusion in specific evidence from the analysts."""
|
Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
final_trade_decision = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
prompt,
|
||||||
|
render_pm_decision,
|
||||||
|
"Portfolio Manager",
|
||||||
|
)
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"judge_decision": response.content,
|
"judge_decision": final_trade_decision,
|
||||||
"history": risk_debate_state["history"],
|
"history": risk_debate_state["history"],
|
||||||
"aggressive_history": risk_debate_state["aggressive_history"],
|
"aggressive_history": risk_debate_state["aggressive_history"],
|
||||||
"conservative_history": risk_debate_state["conservative_history"],
|
"conservative_history": risk_debate_state["conservative_history"],
|
||||||
@@ -69,7 +86,7 @@ Be decisive and ground every conclusion in specific evidence from the analysts."
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"risk_debate_state": new_risk_debate_state,
|
"risk_debate_state": new_risk_debate_state,
|
||||||
"final_trade_decision": response.content,
|
"final_trade_decision": final_trade_decision,
|
||||||
}
|
}
|
||||||
|
|
||||||
return portfolio_manager_node
|
return portfolio_manager_node
|
||||||
|
|||||||
@@ -1,60 +1,64 @@
|
|||||||
import time
|
"""Research Manager: turns the bull/bear debate into a structured investment plan for the trader."""
|
||||||
import json
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_research_manager(llm, memory):
|
def create_research_manager(llm):
|
||||||
|
structured_llm = bind_structured(llm, ResearchPlan, "Research Manager")
|
||||||
|
|
||||||
def research_manager_node(state) -> dict:
|
def research_manager_node(state) -> dict:
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
history = state["investment_debate_state"].get("history", "")
|
history = state["investment_debate_state"].get("history", "")
|
||||||
market_research_report = state["market_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
news_report = state["news_report"]
|
|
||||||
fundamentals_report = state["fundamentals_report"]
|
|
||||||
|
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
prompt = f"""As the Research Manager and debate facilitator, your role is to critically evaluate this round of debate and deliver a clear, actionable investment plan for the trader.
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
|
||||||
|
|
||||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
|
||||||
|
|
||||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
|
||||||
|
|
||||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
|
||||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
|
||||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
|
||||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
|
||||||
|
|
||||||
Here are your past reflections on mistakes:
|
|
||||||
\"{past_memory_str}\"
|
|
||||||
|
|
||||||
{instrument_context}
|
{instrument_context}
|
||||||
|
|
||||||
Here is the debate:
|
---
|
||||||
Debate History:
|
|
||||||
|
**Rating Scale** (use exactly one):
|
||||||
|
- **Buy**: Strong conviction in the bull thesis; recommend taking or growing the position
|
||||||
|
- **Overweight**: Constructive view; recommend gradually increasing exposure
|
||||||
|
- **Hold**: Balanced view; recommend maintaining the current position
|
||||||
|
- **Underweight**: Cautious view; recommend trimming exposure
|
||||||
|
- **Sell**: Strong conviction in the bear thesis; recommend exiting or avoiding the position
|
||||||
|
|
||||||
|
Commit to a clear stance whenever the debate's strongest arguments warrant one; reserve Hold for situations where the evidence on both sides is genuinely balanced.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Debate History:**
|
||||||
{history}"""
|
{history}"""
|
||||||
response = llm.invoke(prompt)
|
|
||||||
|
investment_plan = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
prompt,
|
||||||
|
render_research_plan,
|
||||||
|
"Research Manager",
|
||||||
|
)
|
||||||
|
|
||||||
new_investment_debate_state = {
|
new_investment_debate_state = {
|
||||||
"judge_decision": response.content,
|
"judge_decision": investment_plan,
|
||||||
"history": investment_debate_state.get("history", ""),
|
"history": investment_debate_state.get("history", ""),
|
||||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||||
"current_response": response.content,
|
"current_response": investment_plan,
|
||||||
"count": investment_debate_state["count"],
|
"count": investment_debate_state["count"],
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"investment_debate_state": new_investment_debate_state,
|
"investment_debate_state": new_investment_debate_state,
|
||||||
"investment_plan": response.content,
|
"investment_plan": investment_plan,
|
||||||
}
|
}
|
||||||
|
|
||||||
return research_manager_node
|
return research_manager_node
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
from langchain_core.messages import AIMessage
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_bear_researcher(llm, memory):
|
def create_bear_researcher(llm):
|
||||||
def bear_node(state) -> dict:
|
def bear_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
@@ -15,13 +12,6 @@ def create_bear_researcher(llm, memory):
|
|||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
@@ -40,8 +30,7 @@ Latest world affairs news: {news_report}
|
|||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bull argument: {current_response}
|
Last bull argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.
|
||||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
from langchain_core.messages import AIMessage
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_bull_researcher(llm, memory):
|
def create_bull_researcher(llm):
|
||||||
def bull_node(state) -> dict:
|
def bull_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
@@ -15,13 +12,6 @@ def create_bull_researcher(llm, memory):
|
|||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
@@ -38,8 +28,7 @@ Latest world affairs news: {news_report}
|
|||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bear argument: {current_response}
|
Last bear argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
|
||||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_aggressive_debator(llm):
|
def create_aggressive_debator(llm):
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
from langchain_core.messages import AIMessage
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_conservative_debator(llm):
|
def create_conservative_debator(llm):
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_neutral_debator(llm):
|
def create_neutral_debator(llm):
|
||||||
|
|||||||
228
tradingagents/agents/schemas.py
Normal file
228
tradingagents/agents/schemas.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Pydantic schemas used by agents that produce structured output.
|
||||||
|
|
||||||
|
The framework's primary artifact is still prose: each agent's natural-language
|
||||||
|
reasoning is what users read in the saved markdown reports and what the
|
||||||
|
downstream agents read as context. Structured output is layered onto the
|
||||||
|
three decision-making agents (Research Manager, Trader, Portfolio Manager)
|
||||||
|
so that:
|
||||||
|
|
||||||
|
- Their outputs follow consistent section headers across runs and providers
|
||||||
|
- Each provider's native structured-output mode is used (json_schema for
|
||||||
|
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic)
|
||||||
|
- Schema field descriptions become the model's output instructions, freeing
|
||||||
|
the prompt body to focus on context and the rating-scale guidance
|
||||||
|
- A render helper turns the parsed Pydantic instance back into the same
|
||||||
|
markdown shape the rest of the system already consumes, so display,
|
||||||
|
memory log, and saved reports keep working unchanged
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared rating types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioRating(str, Enum):
|
||||||
|
"""5-tier rating used by the Research Manager and Portfolio Manager."""
|
||||||
|
|
||||||
|
BUY = "Buy"
|
||||||
|
OVERWEIGHT = "Overweight"
|
||||||
|
HOLD = "Hold"
|
||||||
|
UNDERWEIGHT = "Underweight"
|
||||||
|
SELL = "Sell"
|
||||||
|
|
||||||
|
|
||||||
|
class TraderAction(str, Enum):
|
||||||
|
"""3-tier transaction direction used by the Trader.
|
||||||
|
|
||||||
|
The Trader's job is to translate the Research Manager's investment plan
|
||||||
|
into a concrete transaction proposal: should the desk execute a Buy, a
|
||||||
|
Sell, or sit on Hold this round. Position sizing and the nuanced
|
||||||
|
Overweight / Underweight calls happen later at the Portfolio Manager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BUY = "Buy"
|
||||||
|
HOLD = "Hold"
|
||||||
|
SELL = "Sell"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Research Manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchPlan(BaseModel):
|
||||||
|
"""Structured investment plan produced by the Research Manager.
|
||||||
|
|
||||||
|
Hand-off to the Trader: the recommendation pins the directional view,
|
||||||
|
the rationale captures which side of the bull/bear debate carried the
|
||||||
|
argument, and the strategic actions translate that into concrete
|
||||||
|
instructions the trader can execute against.
|
||||||
|
"""
|
||||||
|
|
||||||
|
recommendation: PortfolioRating = Field(
|
||||||
|
description=(
|
||||||
|
"The investment recommendation. Exactly one of Buy / Overweight / "
|
||||||
|
"Hold / Underweight / Sell. Reserve Hold for situations where the "
|
||||||
|
"evidence on both sides is genuinely balanced; otherwise commit to "
|
||||||
|
"the side with the stronger arguments."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
rationale: str = Field(
|
||||||
|
description=(
|
||||||
|
"Conversational summary of the key points from both sides of the "
|
||||||
|
"debate, ending with which arguments led to the recommendation. "
|
||||||
|
"Speak naturally, as if to a teammate."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
strategic_actions: str = Field(
|
||||||
|
description=(
|
||||||
|
"Concrete steps for the trader to implement the recommendation, "
|
||||||
|
"including position sizing guidance consistent with the rating."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_research_plan(plan: ResearchPlan) -> str:
|
||||||
|
"""Render a ResearchPlan to markdown for storage and the trader's prompt context."""
|
||||||
|
return "\n".join([
|
||||||
|
f"**Recommendation**: {plan.recommendation.value}",
|
||||||
|
"",
|
||||||
|
f"**Rationale**: {plan.rationale}",
|
||||||
|
"",
|
||||||
|
f"**Strategic Actions**: {plan.strategic_actions}",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trader
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TraderProposal(BaseModel):
|
||||||
|
"""Structured transaction proposal produced by the Trader.
|
||||||
|
|
||||||
|
The trader reads the Research Manager's investment plan and the analyst
|
||||||
|
reports, then turns them into a concrete transaction: what action to
|
||||||
|
take, the reasoning that justifies it, and the practical levels for
|
||||||
|
entry, stop-loss, and sizing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
action: TraderAction = Field(
|
||||||
|
description="The transaction direction. Exactly one of Buy / Hold / Sell.",
|
||||||
|
)
|
||||||
|
reasoning: str = Field(
|
||||||
|
description=(
|
||||||
|
"The case for this action, anchored in the analysts' reports and "
|
||||||
|
"the research plan. Two to four sentences."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
entry_price: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional entry price target in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
stop_loss: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional stop-loss price in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
position_sizing: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_trader_proposal(proposal: TraderProposal) -> str:
|
||||||
|
"""Render a TraderProposal to markdown.
|
||||||
|
|
||||||
|
The trailing ``FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`` line is
|
||||||
|
preserved for backward compatibility with the analyst stop-signal text
|
||||||
|
and any external code that greps for it.
|
||||||
|
"""
|
||||||
|
parts = [
|
||||||
|
f"**Action**: {proposal.action.value}",
|
||||||
|
"",
|
||||||
|
f"**Reasoning**: {proposal.reasoning}",
|
||||||
|
]
|
||||||
|
if proposal.entry_price is not None:
|
||||||
|
parts.extend(["", f"**Entry Price**: {proposal.entry_price}"])
|
||||||
|
if proposal.stop_loss is not None:
|
||||||
|
parts.extend(["", f"**Stop Loss**: {proposal.stop_loss}"])
|
||||||
|
if proposal.position_sizing:
|
||||||
|
parts.extend(["", f"**Position Sizing**: {proposal.position_sizing}"])
|
||||||
|
parts.extend([
|
||||||
|
"",
|
||||||
|
f"FINAL TRANSACTION PROPOSAL: **{proposal.action.value.upper()}**",
|
||||||
|
])
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Portfolio Manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioDecision(BaseModel):
|
||||||
|
"""Structured output produced by the Portfolio Manager.
|
||||||
|
|
||||||
|
The model fills every field as part of its primary LLM call; no separate
|
||||||
|
extraction pass is required. Field descriptions double as the model's
|
||||||
|
output instructions, so the prompt body only needs to convey context and
|
||||||
|
the rating-scale guidance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rating: PortfolioRating = Field(
|
||||||
|
description=(
|
||||||
|
"The final position rating. Exactly one of Buy / Overweight / Hold / "
|
||||||
|
"Underweight / Sell, picked based on the analysts' debate."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
executive_summary: str = Field(
|
||||||
|
description=(
|
||||||
|
"A concise action plan covering entry strategy, position sizing, "
|
||||||
|
"key risk levels, and time horizon. Two to four sentences."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
investment_thesis: str = Field(
|
||||||
|
description=(
|
||||||
|
"Detailed reasoning anchored in specific evidence from the analysts' "
|
||||||
|
"debate. If prior lessons are referenced in the prompt context, "
|
||||||
|
"incorporate them; otherwise rely solely on the current analysis."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
price_target: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional target price in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
time_horizon: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional recommended holding period, e.g. '3-6 months'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_pm_decision(decision: PortfolioDecision) -> str:
|
||||||
|
"""Render a PortfolioDecision back to the markdown shape the rest of the system expects.
|
||||||
|
|
||||||
|
Memory log, CLI display, and saved report files all read this markdown,
|
||||||
|
so the rendered output preserves the exact section headers (``**Rating**``,
|
||||||
|
``**Executive Summary**``, ``**Investment Thesis**``) that downstream
|
||||||
|
parsers and the report writers already handle.
|
||||||
|
"""
|
||||||
|
parts = [
|
||||||
|
f"**Rating**: {decision.rating.value}",
|
||||||
|
"",
|
||||||
|
f"**Executive Summary**: {decision.executive_summary}",
|
||||||
|
"",
|
||||||
|
f"**Investment Thesis**: {decision.investment_thesis}",
|
||||||
|
]
|
||||||
|
if decision.price_target is not None:
|
||||||
|
parts.extend(["", f"**Price Target**: {decision.price_target}"])
|
||||||
|
if decision.time_horizon:
|
||||||
|
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
||||||
|
return "\n".join(parts)
|
||||||
@@ -1,48 +1,60 @@
|
|||||||
|
"""Trader: turns the Research Manager's investment plan into a concrete transaction proposal."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_trader(llm, memory):
|
def create_trader(llm):
|
||||||
|
structured_llm = bind_structured(llm, TraderProposal, "Trader")
|
||||||
|
|
||||||
def trader_node(state, name):
|
def trader_node(state, name):
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
instrument_context = build_instrument_context(company_name)
|
instrument_context = build_instrument_context(company_name)
|
||||||
investment_plan = state["investment_plan"]
|
investment_plan = state["investment_plan"]
|
||||||
market_research_report = state["market_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
news_report = state["news_report"]
|
|
||||||
fundamentals_report = state["fundamentals_report"]
|
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
if past_memories:
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
else:
|
|
||||||
past_memory_str = "No past memories found."
|
|
||||||
|
|
||||||
context = {
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Apply lessons from past decisions to strengthen your analysis. Here are reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
|
"content": (
|
||||||
|
"You are a trading agent analyzing market data to make investment decisions. "
|
||||||
|
"Based on your analysis, provide a specific recommendation to buy, sell, or hold. "
|
||||||
|
"Anchor your reasoning in the analysts' reports and the research plan."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"Based on a comprehensive analysis by a team of analysts, here is an investment "
|
||||||
|
f"plan tailored for {company_name}. {instrument_context} This plan incorporates "
|
||||||
|
f"insights from current technical market trends, macroeconomic indicators, and "
|
||||||
|
f"social media sentiment. Use this plan as a foundation for evaluating your next "
|
||||||
|
f"trading decision.\n\nProposed Investment Plan: {investment_plan}\n\n"
|
||||||
|
f"Leverage these insights to make an informed and strategic decision."
|
||||||
|
),
|
||||||
},
|
},
|
||||||
context,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
result = llm.invoke(messages)
|
trader_plan = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
messages,
|
||||||
|
render_trader_proposal,
|
||||||
|
"Trader",
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [AIMessage(content=trader_plan)],
|
||||||
"trader_investment_plan": result.content,
|
"trader_investment_plan": trader_plan,
|
||||||
"sender": name,
|
"sender": name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
from typing import Annotated, Sequence
|
from typing import Annotated
|
||||||
from datetime import date, timedelta, datetime
|
from typing_extensions import TypedDict
|
||||||
from typing_extensions import TypedDict, Optional
|
from langgraph.graph import MessagesState
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
@@ -74,3 +70,4 @@ class AgentState(MessagesState):
|
|||||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||||
]
|
]
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||||
|
past_context: Annotated[str, "Memory log context injected at run start (same-ticker decisions + cross-ticker lessons)"]
|
||||||
|
|||||||
@@ -20,6 +20,20 @@ from tradingagents.agents.utils.news_data_tools import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_language_instruction() -> str:
|
||||||
|
"""Return a prompt instruction for the configured output language.
|
||||||
|
|
||||||
|
Returns empty string when English (default), so no extra tokens are used.
|
||||||
|
Only applied to user-facing agents (analysts, portfolio manager).
|
||||||
|
Internal debate agents stay in English for reasoning quality.
|
||||||
|
"""
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
lang = get_config().get("output_language", "English")
|
||||||
|
if lang.strip().lower() == "english":
|
||||||
|
return ""
|
||||||
|
return f" Write your entire response in {lang}."
|
||||||
|
|
||||||
|
|
||||||
def build_instrument_context(ticker: str) -> str:
|
def build_instrument_context(ticker: str) -> str:
|
||||||
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,144 +1,300 @@
|
|||||||
"""Financial situation memory using BM25 for lexical similarity matching.
|
"""Append-only markdown decision log for TradingAgents."""
|
||||||
|
|
||||||
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
from typing import List, Optional
|
||||||
no token limits, works offline with any LLM provider.
|
from pathlib import Path
|
||||||
"""
|
|
||||||
|
|
||||||
from rank_bm25 import BM25Okapi
|
|
||||||
from typing import List, Tuple
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
|
||||||
"""Memory system for storing and retrieving financial situations using BM25."""
|
|
||||||
|
|
||||||
def __init__(self, name: str, config: dict = None):
|
class TradingMemoryLog:
|
||||||
"""Initialize the memory system.
|
"""Append-only markdown log of trading decisions and reflections."""
|
||||||
|
|
||||||
Args:
|
# HTML comment: cannot appear in LLM prose output, safe as a hard delimiter
|
||||||
name: Name identifier for this memory instance
|
_SEPARATOR = "\n\n<!-- ENTRY_END -->\n\n"
|
||||||
config: Configuration dict (kept for API compatibility, not used for BM25)
|
# Precompiled patterns — avoids re-compilation on every load_entries() call
|
||||||
"""
|
_DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL)
|
||||||
self.name = name
|
_REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL)
|
||||||
self.documents: List[str] = []
|
|
||||||
self.recommendations: List[str] = []
|
|
||||||
self.bm25 = None
|
|
||||||
|
|
||||||
def _tokenize(self, text: str) -> List[str]:
|
def __init__(self, config: dict = None):
|
||||||
"""Tokenize text for BM25 indexing.
|
cfg = config or {}
|
||||||
|
self._log_path = None
|
||||||
|
path = cfg.get("memory_log_path")
|
||||||
|
if path:
|
||||||
|
self._log_path = Path(path).expanduser()
|
||||||
|
self._log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Optional cap on resolved entries. None disables rotation.
|
||||||
|
self._max_entries = cfg.get("memory_log_max_entries")
|
||||||
|
|
||||||
Simple whitespace + punctuation tokenization with lowercasing.
|
# --- Write path (Phase A) ---
|
||||||
"""
|
|
||||||
# Lowercase and split on non-alphanumeric characters
|
|
||||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def _rebuild_index(self):
|
def store_decision(
|
||||||
"""Rebuild the BM25 index after adding documents."""
|
self,
|
||||||
if self.documents:
|
ticker: str,
|
||||||
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
|
trade_date: str,
|
||||||
self.bm25 = BM25Okapi(tokenized_docs)
|
final_trade_decision: str,
|
||||||
else:
|
) -> None:
|
||||||
self.bm25 = None
|
"""Append pending entry at end of propagate(). No LLM call."""
|
||||||
|
if not self._log_path:
|
||||||
|
return
|
||||||
|
# Idempotency guard: fast raw-text scan instead of full parse
|
||||||
|
if self._log_path.exists():
|
||||||
|
raw = self._log_path.read_text(encoding="utf-8")
|
||||||
|
for line in raw.splitlines():
|
||||||
|
if line.startswith(f"[{trade_date} | {ticker} |") and line.endswith("| pending]"):
|
||||||
|
return
|
||||||
|
rating = parse_rating(final_trade_decision)
|
||||||
|
tag = f"[{trade_date} | {ticker} | {rating} | pending]"
|
||||||
|
entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}"
|
||||||
|
with open(self._log_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(entry)
|
||||||
|
|
||||||
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
|
# --- Read path (Phase A) ---
|
||||||
"""Add financial situations and their corresponding advice.
|
|
||||||
|
|
||||||
Args:
|
def load_entries(self) -> List[dict]:
|
||||||
situations_and_advice: List of tuples (situation, recommendation)
|
"""Parse all entries from log. Returns list of dicts."""
|
||||||
"""
|
if not self._log_path or not self._log_path.exists():
|
||||||
for situation, recommendation in situations_and_advice:
|
|
||||||
self.documents.append(situation)
|
|
||||||
self.recommendations.append(recommendation)
|
|
||||||
|
|
||||||
# Rebuild BM25 index with new documents
|
|
||||||
self._rebuild_index()
|
|
||||||
|
|
||||||
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
|
||||||
"""Find matching recommendations using BM25 similarity.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_situation: The current financial situation to match against
|
|
||||||
n_matches: Number of top matches to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with matched_situation, recommendation, and similarity_score
|
|
||||||
"""
|
|
||||||
if not self.documents or self.bm25 is None:
|
|
||||||
return []
|
return []
|
||||||
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
|
raw_entries = [e.strip() for e in text.split(self._SEPARATOR) if e.strip()]
|
||||||
|
entries = []
|
||||||
|
for raw in raw_entries:
|
||||||
|
parsed = self._parse_entry(raw)
|
||||||
|
if parsed:
|
||||||
|
entries.append(parsed)
|
||||||
|
return entries
|
||||||
|
|
||||||
# Tokenize query
|
def get_pending_entries(self) -> List[dict]:
|
||||||
query_tokens = self._tokenize(current_situation)
|
"""Return entries with outcome:pending (for Phase B)."""
|
||||||
|
return [e for e in self.load_entries() if e.get("pending")]
|
||||||
|
|
||||||
# Get BM25 scores for all documents
|
def get_past_context(self, ticker: str, n_same: int = 5, n_cross: int = 3) -> str:
|
||||||
scores = self.bm25.get_scores(query_tokens)
|
"""Return formatted past context string for agent prompt injection."""
|
||||||
|
entries = [e for e in self.load_entries() if not e.get("pending")]
|
||||||
|
if not entries:
|
||||||
|
return ""
|
||||||
|
|
||||||
# Get top-n indices sorted by score (descending)
|
same, cross = [], []
|
||||||
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
|
for e in reversed(entries):
|
||||||
|
if len(same) >= n_same and len(cross) >= n_cross:
|
||||||
|
break
|
||||||
|
if e["ticker"] == ticker and len(same) < n_same:
|
||||||
|
same.append(e)
|
||||||
|
elif e["ticker"] != ticker and len(cross) < n_cross:
|
||||||
|
cross.append(e)
|
||||||
|
|
||||||
# Build results
|
if not same and not cross:
|
||||||
results = []
|
return ""
|
||||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
|
||||||
|
|
||||||
for idx in top_indices:
|
parts = []
|
||||||
# Normalize score to 0-1 range for consistency
|
if same:
|
||||||
normalized_score = scores[idx] / max_score if max_score > 0 else 0
|
parts.append(f"Past analyses of {ticker} (most recent first):")
|
||||||
results.append({
|
parts.extend(self._format_full(e) for e in same)
|
||||||
"matched_situation": self.documents[idx],
|
if cross:
|
||||||
"recommendation": self.recommendations[idx],
|
parts.append("Recent cross-ticker lessons:")
|
||||||
"similarity_score": normalized_score,
|
parts.extend(self._format_reflection_only(e) for e in cross)
|
||||||
})
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
return results
|
# --- Update path (Phase B) ---
|
||||||
|
|
||||||
def clear(self):
|
def update_with_outcome(
|
||||||
"""Clear all stored memories."""
|
self,
|
||||||
self.documents = []
|
ticker: str,
|
||||||
self.recommendations = []
|
trade_date: str,
|
||||||
self.bm25 = None
|
raw_return: float,
|
||||||
|
alpha_return: float,
|
||||||
|
holding_days: int,
|
||||||
|
reflection: str,
|
||||||
|
) -> None:
|
||||||
|
"""Replace pending tag and append REFLECTION section using atomic write.
|
||||||
|
|
||||||
|
Finds the first pending entry matching (trade_date, ticker), updates
|
||||||
if __name__ == "__main__":
|
its tag with return figures, and appends a REFLECTION section. Uses
|
||||||
# Example usage
|
a temp-file + os.replace() so a crash mid-write never corrupts the log.
|
||||||
matcher = FinancialSituationMemory("test_memory")
|
|
||||||
|
|
||||||
# Example data
|
|
||||||
example_data = [
|
|
||||||
(
|
|
||||||
"High inflation rate with rising interest rates and declining consumer spending",
|
|
||||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
|
||||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
|
||||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Market showing signs of sector rotation with rising yields",
|
|
||||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add the example situations and recommendations
|
|
||||||
matcher.add_situations(example_data)
|
|
||||||
|
|
||||||
# Example query
|
|
||||||
current_situation = """
|
|
||||||
Market showing increased volatility in tech sector, with institutional investors
|
|
||||||
reducing positions and rising interest rates affecting growth stock valuations
|
|
||||||
"""
|
"""
|
||||||
|
if not self._log_path or not self._log_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
blocks = text.split(self._SEPARATOR)
|
||||||
|
|
||||||
for i, rec in enumerate(recommendations, 1):
|
pending_prefix = f"[{trade_date} | {ticker} |"
|
||||||
print(f"\nMatch {i}:")
|
raw_pct = f"{raw_return:+.1%}"
|
||||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
alpha_pct = f"{alpha_return:+.1%}"
|
||||||
print(f"Matched Situation: {rec['matched_situation']}")
|
|
||||||
print(f"Recommendation: {rec['recommendation']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
updated = False
|
||||||
print(f"Error during recommendation: {str(e)}")
|
new_blocks = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
new_blocks.append(block)
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
|
||||||
|
if (
|
||||||
|
not updated
|
||||||
|
and tag_line.startswith(pending_prefix)
|
||||||
|
and tag_line.endswith("| pending]")
|
||||||
|
):
|
||||||
|
# Parse rating from the existing pending tag
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
rating = fields[2]
|
||||||
|
new_tag = (
|
||||||
|
f"[{trade_date} | {ticker} | {rating}"
|
||||||
|
f" | {raw_pct} | {alpha_pct} | {holding_days}d]"
|
||||||
|
)
|
||||||
|
rest = "\n".join(lines[1:])
|
||||||
|
new_blocks.append(
|
||||||
|
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{reflection}"
|
||||||
|
)
|
||||||
|
updated = True
|
||||||
|
else:
|
||||||
|
new_blocks.append(block)
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_blocks = self._apply_rotation(new_blocks)
|
||||||
|
new_text = self._SEPARATOR.join(new_blocks)
|
||||||
|
tmp_path = self._log_path.with_suffix(".tmp")
|
||||||
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
|
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
|
||||||
|
"""Apply multiple outcome updates in a single read + atomic write.
|
||||||
|
|
||||||
|
Each element of updates must have keys: ticker, trade_date,
|
||||||
|
raw_return, alpha_return, holding_days, reflection.
|
||||||
|
"""
|
||||||
|
if not self._log_path or not self._log_path.exists() or not updates:
|
||||||
|
return
|
||||||
|
|
||||||
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
|
blocks = text.split(self._SEPARATOR)
|
||||||
|
|
||||||
|
# Build lookup keyed by (trade_date, ticker) for O(1) dispatch
|
||||||
|
update_map = {(u["trade_date"], u["ticker"]): u for u in updates}
|
||||||
|
|
||||||
|
new_blocks = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
new_blocks.append(block)
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
|
||||||
|
matched = False
|
||||||
|
for (trade_date, ticker), upd in list(update_map.items()):
|
||||||
|
pending_prefix = f"[{trade_date} | {ticker} |"
|
||||||
|
if tag_line.startswith(pending_prefix) and tag_line.endswith("| pending]"):
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
rating = fields[2]
|
||||||
|
raw_pct = f"{upd['raw_return']:+.1%}"
|
||||||
|
alpha_pct = f"{upd['alpha_return']:+.1%}"
|
||||||
|
new_tag = (
|
||||||
|
f"[{trade_date} | {ticker} | {rating}"
|
||||||
|
f" | {raw_pct} | {alpha_pct} | {upd['holding_days']}d]"
|
||||||
|
)
|
||||||
|
rest = "\n".join(lines[1:])
|
||||||
|
new_blocks.append(
|
||||||
|
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{upd['reflection']}"
|
||||||
|
)
|
||||||
|
del update_map[(trade_date, ticker)]
|
||||||
|
matched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not matched:
|
||||||
|
new_blocks.append(block)
|
||||||
|
|
||||||
|
new_blocks = self._apply_rotation(new_blocks)
|
||||||
|
new_text = self._SEPARATOR.join(new_blocks)
|
||||||
|
tmp_path = self._log_path.with_suffix(".tmp")
|
||||||
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
|
# --- Helpers ---
|
||||||
|
|
||||||
|
def _apply_rotation(self, blocks: List[str]) -> List[str]:
|
||||||
|
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
||||||
|
|
||||||
|
Pending blocks are always kept (they represent unprocessed work).
|
||||||
|
Returns ``blocks`` unchanged when rotation is disabled or under cap.
|
||||||
|
"""
|
||||||
|
if not self._max_entries or self._max_entries <= 0:
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
# Tag each block with (kept, is_resolved) by parsing tag-line markers.
|
||||||
|
decisions = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
decisions.append((block, False))
|
||||||
|
continue
|
||||||
|
tag_line = stripped.splitlines()[0].strip()
|
||||||
|
is_resolved = (
|
||||||
|
tag_line.startswith("[")
|
||||||
|
and tag_line.endswith("]")
|
||||||
|
and not tag_line.endswith("| pending]")
|
||||||
|
)
|
||||||
|
decisions.append((block, is_resolved))
|
||||||
|
|
||||||
|
resolved_count = sum(1 for _, r in decisions if r)
|
||||||
|
if resolved_count <= self._max_entries:
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
to_drop = resolved_count - self._max_entries
|
||||||
|
kept: List[str] = []
|
||||||
|
for block, is_resolved in decisions:
|
||||||
|
if is_resolved and to_drop > 0:
|
||||||
|
to_drop -= 1
|
||||||
|
continue
|
||||||
|
kept.append(block)
|
||||||
|
return kept
|
||||||
|
|
||||||
|
def _parse_entry(self, raw: str) -> Optional[dict]:
|
||||||
|
lines = raw.strip().splitlines()
|
||||||
|
if not lines:
|
||||||
|
return None
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
if not (tag_line.startswith("[") and tag_line.endswith("]")):
|
||||||
|
return None
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
if len(fields) < 4:
|
||||||
|
return None
|
||||||
|
entry = {
|
||||||
|
"date": fields[0],
|
||||||
|
"ticker": fields[1],
|
||||||
|
"rating": fields[2],
|
||||||
|
"pending": fields[3] == "pending",
|
||||||
|
"raw": fields[3] if fields[3] != "pending" else None,
|
||||||
|
"alpha": fields[4] if len(fields) > 4 else None,
|
||||||
|
"holding": fields[5] if len(fields) > 5 else None,
|
||||||
|
}
|
||||||
|
body = "\n".join(lines[1:]).strip()
|
||||||
|
decision_match = self._DECISION_RE.search(body)
|
||||||
|
reflection_match = self._REFLECTION_RE.search(body)
|
||||||
|
entry["decision"] = decision_match.group(1).strip() if decision_match else ""
|
||||||
|
entry["reflection"] = reflection_match.group(1).strip() if reflection_match else ""
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def _format_full(self, e: dict) -> str:
|
||||||
|
raw = e["raw"] or "n/a"
|
||||||
|
alpha = e["alpha"] or "n/a"
|
||||||
|
holding = e["holding"] or "n/a"
|
||||||
|
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {raw} | {alpha} | {holding}]"
|
||||||
|
parts = [tag, f"DECISION:\n{e['decision']}"]
|
||||||
|
if e["reflection"]:
|
||||||
|
parts.append(f"REFLECTION:\n{e['reflection']}")
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _format_reflection_only(self, e: dict) -> str:
|
||||||
|
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {e['raw'] or 'n/a'}]"
|
||||||
|
if e["reflection"]:
|
||||||
|
return f"{tag}\n{e['reflection']}"
|
||||||
|
text = e["decision"][:300]
|
||||||
|
suffix = "..." if len(e["decision"]) > 300 else ""
|
||||||
|
return f"{tag}\n{text}{suffix}"
|
||||||
|
|||||||
50
tradingagents/agents/utils/rating.py
Normal file
50
tradingagents/agents/utils/rating.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Shared 5-tier rating vocabulary and a deterministic heuristic parser.
|
||||||
|
|
||||||
|
The same five-tier scale (Buy, Overweight, Hold, Underweight, Sell) is used by:
|
||||||
|
- The Research Manager (investment plan recommendation)
|
||||||
|
- The Portfolio Manager (final position decision)
|
||||||
|
- The signal processor (rating extracted for downstream consumers)
|
||||||
|
- The memory log (rating tag stored alongside each decision entry)
|
||||||
|
|
||||||
|
Centralising it here avoids drift between those call sites.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
||||||
|
RATINGS_5_TIER: Tuple[str, ...] = (
|
||||||
|
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
||||||
|
)
|
||||||
|
|
||||||
|
_RATING_SET = {r.lower() for r in RATINGS_5_TIER}
|
||||||
|
|
||||||
|
# Matches "Rating: X" / "rating - X" / "Rating: **X**" — tolerates markdown
|
||||||
|
# bold wrappers and either a colon or hyphen separator.
|
||||||
|
_RATING_LABEL_RE = re.compile(r"rating.*?[:\-][\s*]*(\w+)", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rating(text: str, default: str = "Hold") -> str:
|
||||||
|
"""Heuristically extract a 5-tier rating from prose text.
|
||||||
|
|
||||||
|
Two-pass strategy:
|
||||||
|
1. Look for an explicit "Rating: X" label (tolerant of markdown bold).
|
||||||
|
2. Fall back to the first 5-tier rating word found anywhere in the text.
|
||||||
|
|
||||||
|
Returns a Title-cased rating string, or ``default`` if no rating word appears.
|
||||||
|
"""
|
||||||
|
for line in text.splitlines():
|
||||||
|
m = _RATING_LABEL_RE.search(line)
|
||||||
|
if m and m.group(1).lower() in _RATING_SET:
|
||||||
|
return m.group(1).capitalize()
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
for word in line.lower().split():
|
||||||
|
clean = word.strip("*:.,")
|
||||||
|
if clean in _RATING_SET:
|
||||||
|
return clean.capitalize()
|
||||||
|
|
||||||
|
return default
|
||||||
73
tradingagents/agents/utils/structured.py
Normal file
73
tradingagents/agents/utils/structured.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Shared helpers for invoking an agent with structured output and a graceful fallback.
|
||||||
|
|
||||||
|
The Portfolio Manager, Trader, and Research Manager all follow the same
|
||||||
|
canonical pattern:
|
||||||
|
|
||||||
|
1. At agent creation, wrap the LLM with ``with_structured_output(Schema)``
|
||||||
|
so the model returns a typed Pydantic instance. If the provider does
|
||||||
|
not support structured output (rare; mostly older Ollama models), the
|
||||||
|
wrap is skipped and the agent uses free-text generation instead.
|
||||||
|
2. At invocation, run the structured call and render the result back to
|
||||||
|
markdown. If the structured call itself fails for any reason
|
||||||
|
(malformed JSON from a weak model, transient provider issue), fall
|
||||||
|
back to a plain ``llm.invoke`` so the pipeline never blocks.
|
||||||
|
|
||||||
|
Centralising the pattern here keeps the agent factories small and ensures
|
||||||
|
all three agents log the same warnings when fallback fires.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
|
||||||
|
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
||||||
|
|
||||||
|
Logs a warning when the binding fails so the user understands the agent
|
||||||
|
will use free-text generation for every call instead of one-shot fallback.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return llm.with_structured_output(schema)
|
||||||
|
except (NotImplementedError, AttributeError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"%s: provider does not support with_structured_output (%s); "
|
||||||
|
"falling back to free-text generation",
|
||||||
|
agent_name, exc,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_structured_or_freetext(
|
||||||
|
structured_llm: Optional[Any],
|
||||||
|
plain_llm: Any,
|
||||||
|
prompt: Any,
|
||||||
|
render: Callable[[T], str],
|
||||||
|
agent_name: str,
|
||||||
|
) -> str:
|
||||||
|
"""Run the structured call and render to markdown; fall back to free-text on any failure.
|
||||||
|
|
||||||
|
``prompt`` is whatever the underlying LLM accepts (a string for chat
|
||||||
|
invocations, a list of message dicts for chat models that take that
|
||||||
|
shape). The same value is forwarded to the free-text path so the
|
||||||
|
fallback sees the same input the structured call did.
|
||||||
|
"""
|
||||||
|
if structured_llm is not None:
|
||||||
|
try:
|
||||||
|
result = structured_llm.invoke(prompt)
|
||||||
|
return render(result)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"%s: structured-output invocation failed (%s); retrying once as free text",
|
||||||
|
agent_name, exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = plain_llm.invoke(prompt)
|
||||||
|
return response.content
|
||||||
@@ -22,10 +22,11 @@ def get_indicators(
|
|||||||
"""
|
"""
|
||||||
# LLMs sometimes pass multiple indicators as a comma-separated string;
|
# LLMs sometimes pass multiple indicators as a comma-separated string;
|
||||||
# split and process each individually.
|
# split and process each individually.
|
||||||
indicators = [i.strip() for i in indicator.split(",") if i.strip()]
|
indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()]
|
||||||
if len(indicators) > 1:
|
|
||||||
results = []
|
results = []
|
||||||
for ind in indicators:
|
for ind in indicators:
|
||||||
|
try:
|
||||||
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
||||||
|
except ValueError as e:
|
||||||
|
results.append(str(e))
|
||||||
return "\n\n".join(results)
|
return "\n\n".join(results)
|
||||||
return route_to_vendor("get_indicators", symbol, indicator.strip(), curr_date, look_back_days)
|
|
||||||
@@ -1,6 +1,23 @@
|
|||||||
from .alpha_vantage_common import _make_api_request
|
from .alpha_vantage_common import _make_api_request
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_reports_by_date(result, curr_date: str):
|
||||||
|
"""Filter annualReports/quarterlyReports to exclude entries after curr_date.
|
||||||
|
|
||||||
|
Prevents look-ahead bias by removing fiscal periods that end after
|
||||||
|
the simulation's current date.
|
||||||
|
"""
|
||||||
|
if not curr_date or not isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
for key in ("annualReports", "quarterlyReports"):
|
||||||
|
if key in result:
|
||||||
|
result[key] = [
|
||||||
|
r for r in result[key]
|
||||||
|
if r.get("fiscalDateEnding", "") <= curr_date
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||||
@@ -19,59 +36,20 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
|||||||
return _make_api_request("OVERVIEW", params)
|
return _make_api_request("OVERVIEW", params)
|
||||||
|
|
||||||
|
|
||||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
"""
|
"""Retrieve balance sheet data for a given ticker symbol using Alpha Vantage."""
|
||||||
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
result = _make_api_request("BALANCE_SHEET", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
Args:
|
|
||||||
ticker (str): Ticker symbol of the company
|
|
||||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Balance sheet data with normalized fields
|
|
||||||
"""
|
|
||||||
params = {
|
|
||||||
"symbol": ticker,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _make_api_request("BALANCE_SHEET", params)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
"""
|
"""Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage."""
|
||||||
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
|
result = _make_api_request("CASH_FLOW", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
Args:
|
|
||||||
ticker (str): Ticker symbol of the company
|
|
||||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Cash flow statement data with normalized fields
|
|
||||||
"""
|
|
||||||
params = {
|
|
||||||
"symbol": ticker,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _make_api_request("CASH_FLOW", params)
|
|
||||||
|
|
||||||
|
|
||||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
"""
|
"""Retrieve income statement data for a given ticker symbol using Alpha Vantage."""
|
||||||
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
result = _make_api_request("INCOME_STATEMENT", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
Args:
|
|
||||||
ticker (str): Ticker symbol of the company
|
|
||||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
|
||||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Income statement data with normalized fields
|
|
||||||
"""
|
|
||||||
params = {
|
|
||||||
"symbol": ticker,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _make_api_request("INCOME_STATEMENT", params)
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from stockstats import wrap
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
import os
|
import os
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
from .utils import safe_ticker_component
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -44,6 +45,68 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||||
|
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
|
||||||
|
|
||||||
|
Downloads 15 years of data up to today and caches per symbol. On
|
||||||
|
subsequent calls the cache is reused. Rows after curr_date are
|
||||||
|
filtered out so backtests never see future prices.
|
||||||
|
"""
|
||||||
|
# Reject ticker values that would escape the cache directory when
|
||||||
|
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
||||||
|
safe_symbol = safe_ticker_component(symbol)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
curr_date_dt = pd.to_datetime(curr_date)
|
||||||
|
|
||||||
|
# Cache uses a fixed window (15y to today) so one file per symbol
|
||||||
|
today_date = pd.Timestamp.today()
|
||||||
|
start_date = today_date - pd.DateOffset(years=5)
|
||||||
|
start_str = start_date.strftime("%Y-%m-%d")
|
||||||
|
end_str = today_date.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||||
|
data_file = os.path.join(
|
||||||
|
config["data_cache_dir"],
|
||||||
|
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(data_file):
|
||||||
|
data = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8")
|
||||||
|
else:
|
||||||
|
data = yf_retry(lambda: yf.download(
|
||||||
|
symbol,
|
||||||
|
start=start_str,
|
||||||
|
end=end_str,
|
||||||
|
multi_level_index=False,
|
||||||
|
progress=False,
|
||||||
|
auto_adjust=True,
|
||||||
|
))
|
||||||
|
data = data.reset_index()
|
||||||
|
data.to_csv(data_file, index=False, encoding="utf-8")
|
||||||
|
|
||||||
|
data = _clean_dataframe(data)
|
||||||
|
|
||||||
|
# Filter to curr_date to prevent look-ahead bias in backtesting
|
||||||
|
data = data[data["Date"] <= curr_date_dt]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def filter_financials_by_date(data: pd.DataFrame, curr_date: str) -> pd.DataFrame:
|
||||||
|
"""Drop financial statement columns (fiscal period timestamps) after curr_date.
|
||||||
|
|
||||||
|
yfinance financial statements use fiscal period end dates as columns.
|
||||||
|
Columns after curr_date represent future data and are removed to
|
||||||
|
prevent look-ahead bias.
|
||||||
|
"""
|
||||||
|
if not curr_date or data.empty:
|
||||||
|
return data
|
||||||
|
cutoff = pd.Timestamp(curr_date)
|
||||||
|
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
|
||||||
|
return data.loc[:, mask]
|
||||||
|
|
||||||
|
|
||||||
class StockstatsUtils:
|
class StockstatsUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_stock_stats(
|
def get_stock_stats(
|
||||||
@@ -55,42 +118,10 @@ class StockstatsUtils:
|
|||||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
config = get_config()
|
data = load_ohlcv(symbol, curr_date)
|
||||||
|
|
||||||
today_date = pd.Timestamp.today()
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
|
||||||
|
|
||||||
end_date = today_date
|
|
||||||
start_date = today_date - pd.DateOffset(years=15)
|
|
||||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
|
||||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Ensure cache directory exists
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
|
||||||
|
|
||||||
data_file = os.path.join(
|
|
||||||
config["data_cache_dir"],
|
|
||||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
|
||||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
|
||||||
else:
|
|
||||||
data = yf_retry(lambda: yf.download(
|
|
||||||
symbol,
|
|
||||||
start=start_date_str,
|
|
||||||
end=end_date_str,
|
|
||||||
multi_level_index=False,
|
|
||||||
progress=False,
|
|
||||||
auto_adjust=True,
|
|
||||||
))
|
|
||||||
data = data.reset_index()
|
|
||||||
data.to_csv(data_file, index=False)
|
|
||||||
|
|
||||||
data = _clean_dataframe(data)
|
|
||||||
df = wrap(data)
|
df = wrap(data)
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||||
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
df[indicator] # trigger stockstats to calculate the indicator
|
df[indicator] # trigger stockstats to calculate the indicator
|
||||||
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import date, timedelta, datetime
|
from datetime import date, timedelta, datetime
|
||||||
@@ -6,9 +7,43 @@ from typing import Annotated
|
|||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
|
# Tickers can contain letters, digits, dot, dash, underscore, and caret
|
||||||
|
# (for index symbols like ^GSPC). Anything else is rejected so the value
|
||||||
|
# never escapes a containing directory when interpolated into a path.
|
||||||
|
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
|
||||||
|
"""Validate ``value`` is safe to interpolate into a filesystem path.
|
||||||
|
|
||||||
|
Tickers come from user CLI input or from LLM tool calls, both of which
|
||||||
|
can be influenced by attacker-controlled content (e.g. prompt injection
|
||||||
|
embedded in fetched news). Without validation, a value like
|
||||||
|
``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
|
||||||
|
escapes the configured cache, checkpoint, or results directory.
|
||||||
|
|
||||||
|
Returns ``value`` unchanged when it matches the allowed pattern; raises
|
||||||
|
``ValueError`` otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str) or not value:
|
||||||
|
raise ValueError(f"ticker must be a non-empty string, got {value!r}")
|
||||||
|
if len(value) > max_len:
|
||||||
|
raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")
|
||||||
|
if not _TICKER_PATH_RE.fullmatch(value):
|
||||||
|
raise ValueError(
|
||||||
|
f"ticker contains characters not allowed in a filesystem path: {value!r}"
|
||||||
|
)
|
||||||
|
# The regex above allows '.', so values like '.', '..', '...' would pass,
|
||||||
|
# and as a path component they traverse the parent directory. Reject any
|
||||||
|
# value that's only dots.
|
||||||
|
if set(value) == {"."}:
|
||||||
|
raise ValueError(f"ticker cannot consist solely of dots: {value!r}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||||
if save_path:
|
if save_path:
|
||||||
data.to_csv(save_path)
|
data.to_csv(save_path, encoding="utf-8")
|
||||||
print(f"{tag} saved to {save_path}")
|
print(f"{tag} saved to {save_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
import os
|
import os
|
||||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
|
||||||
|
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
@@ -194,58 +195,9 @@ def _get_stock_stats_bulk(
|
|||||||
Fetches data once and calculates indicator for all available dates.
|
Fetches data once and calculates indicator for all available dates.
|
||||||
Returns dict mapping date strings to indicator values.
|
Returns dict mapping date strings to indicator values.
|
||||||
"""
|
"""
|
||||||
from .config import get_config
|
|
||||||
import pandas as pd
|
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
import os
|
|
||||||
|
|
||||||
config = get_config()
|
data = load_ohlcv(symbol, curr_date)
|
||||||
online = config["data_vendors"]["technical_indicators"] != "local"
|
|
||||||
|
|
||||||
if not online:
|
|
||||||
# Local data path
|
|
||||||
try:
|
|
||||||
data = pd.read_csv(
|
|
||||||
os.path.join(
|
|
||||||
config.get("data_cache_dir", "data"),
|
|
||||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
|
||||||
),
|
|
||||||
on_bad_lines="skip",
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
|
||||||
else:
|
|
||||||
# Online data fetching with caching
|
|
||||||
today_date = pd.Timestamp.today()
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
|
||||||
|
|
||||||
end_date = today_date
|
|
||||||
start_date = today_date - pd.DateOffset(years=15)
|
|
||||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
|
||||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
|
||||||
|
|
||||||
data_file = os.path.join(
|
|
||||||
config["data_cache_dir"],
|
|
||||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
|
||||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
|
||||||
else:
|
|
||||||
data = yf_retry(lambda: yf.download(
|
|
||||||
symbol,
|
|
||||||
start=start_date_str,
|
|
||||||
end=end_date_str,
|
|
||||||
multi_level_index=False,
|
|
||||||
progress=False,
|
|
||||||
auto_adjust=True,
|
|
||||||
))
|
|
||||||
data = data.reset_index()
|
|
||||||
data.to_csv(data_file, index=False)
|
|
||||||
|
|
||||||
data = _clean_dataframe(data)
|
|
||||||
df = wrap(data)
|
df = wrap(data)
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
@@ -353,7 +305,7 @@ def get_fundamentals(
|
|||||||
def get_balance_sheet(
|
def get_balance_sheet(
|
||||||
ticker: Annotated[str, "ticker symbol of the company"],
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get balance sheet data from yfinance."""
|
"""Get balance sheet data from yfinance."""
|
||||||
try:
|
try:
|
||||||
@@ -364,6 +316,8 @@ def get_balance_sheet(
|
|||||||
else:
|
else:
|
||||||
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No balance sheet data found for symbol '{ticker}'"
|
return f"No balance sheet data found for symbol '{ticker}'"
|
||||||
|
|
||||||
@@ -383,7 +337,7 @@ def get_balance_sheet(
|
|||||||
def get_cashflow(
|
def get_cashflow(
|
||||||
ticker: Annotated[str, "ticker symbol of the company"],
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get cash flow data from yfinance."""
|
"""Get cash flow data from yfinance."""
|
||||||
try:
|
try:
|
||||||
@@ -394,6 +348,8 @@ def get_cashflow(
|
|||||||
else:
|
else:
|
||||||
data = yf_retry(lambda: ticker_obj.cashflow)
|
data = yf_retry(lambda: ticker_obj.cashflow)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No cash flow data found for symbol '{ticker}'"
|
return f"No cash flow data found for symbol '{ticker}'"
|
||||||
|
|
||||||
@@ -413,7 +369,7 @@ def get_cashflow(
|
|||||||
def get_income_statement(
|
def get_income_statement(
|
||||||
ticker: Annotated[str, "ticker symbol of the company"],
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get income statement data from yfinance."""
|
"""Get income statement data from yfinance."""
|
||||||
try:
|
try:
|
||||||
@@ -424,6 +380,8 @@ def get_income_statement(
|
|||||||
else:
|
else:
|
||||||
data = yf_retry(lambda: ticker_obj.income_stmt)
|
data = yf_retry(lambda: ticker_obj.income_stmt)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No income statement data found for symbol '{ticker}'"
|
return f"No income statement data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import yfinance as yf
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
|
from .stockstats_utils import yf_retry
|
||||||
|
|
||||||
|
|
||||||
def _extract_article_data(article: dict) -> dict:
|
def _extract_article_data(article: dict) -> dict:
|
||||||
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||||
@@ -64,7 +66,7 @@ def get_news_yfinance(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
stock = yf.Ticker(ticker)
|
stock = yf.Ticker(ticker)
|
||||||
news = stock.get_news(count=20)
|
news = yf_retry(lambda: stock.get_news(count=20))
|
||||||
|
|
||||||
if not news:
|
if not news:
|
||||||
return f"No news found for {ticker}"
|
return f"No news found for {ticker}"
|
||||||
@@ -131,11 +133,11 @@ def get_global_news_yfinance(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for query in search_queries:
|
for query in search_queries:
|
||||||
search = yf.Search(
|
search = yf_retry(lambda q=query: yf.Search(
|
||||||
query=query,
|
query=q,
|
||||||
news_count=limit,
|
news_count=limit,
|
||||||
enable_fuzzy_query=True,
|
enable_fuzzy_query=True,
|
||||||
)
|
))
|
||||||
|
|
||||||
if search.news:
|
if search.news:
|
||||||
for article in search.news:
|
for article in search.news:
|
||||||
@@ -167,6 +169,11 @@ def get_global_news_yfinance(
|
|||||||
# Handle both flat and nested structures
|
# Handle both flat and nested structures
|
||||||
if "content" in article:
|
if "content" in article:
|
||||||
data = _extract_article_data(article)
|
data = _extract_article_data(article)
|
||||||
|
# Skip articles published after curr_date (look-ahead guard)
|
||||||
|
if data.get("pub_date"):
|
||||||
|
pub_naive = data["pub_date"].replace(tzinfo=None) if hasattr(data["pub_date"], "replace") else data["pub_date"]
|
||||||
|
if pub_naive > curr_dt + relativedelta(days=1):
|
||||||
|
continue
|
||||||
title = data["title"]
|
title = data["title"]
|
||||||
publisher = data["publisher"]
|
publisher = data["publisher"]
|
||||||
link = data["link"]
|
link = data["link"]
|
||||||
|
|||||||
@@ -1,21 +1,36 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||||
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"memory_log_path": os.getenv("TRADINGAGENTS_MEMORY_LOG_PATH", os.path.join(_TRADINGAGENTS_HOME, "memory", "trading_memory.md")),
|
||||||
"dataflows/data_cache",
|
# Optional cap on the number of resolved memory log entries. When set,
|
||||||
),
|
# the oldest resolved entries are pruned once this limit is exceeded.
|
||||||
|
# Pending entries are never pruned. None disables rotation entirely.
|
||||||
|
"memory_log_max_entries": None,
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai",
|
||||||
"deep_think_llm": "gpt-5.2",
|
"deep_think_llm": "gpt-5.4",
|
||||||
"quick_think_llm": "gpt-5-mini",
|
"quick_think_llm": "gpt-5.4-mini",
|
||||||
"backend_url": "https://api.openai.com/v1",
|
# When None, each provider's client falls back to its own default endpoint
|
||||||
|
# (api.openai.com for OpenAI, generativelanguage.googleapis.com for Gemini, ...).
|
||||||
|
# The CLI overrides this per provider when the user picks one. Keeping a
|
||||||
|
# provider-specific URL here would leak (e.g. OpenAI's /v1 was previously
|
||||||
|
# being forwarded to Gemini, producing malformed request URLs).
|
||||||
|
"backend_url": None,
|
||||||
# Provider-specific thinking configuration
|
# Provider-specific thinking configuration
|
||||||
"google_thinking_level": None, # "high", "minimal", etc.
|
"google_thinking_level": None, # "high", "minimal", etc.
|
||||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||||
"anthropic_effort": None, # "high", "medium", "low"
|
"anthropic_effort": None, # "high", "medium", "low"
|
||||||
|
# Checkpoint/resume: when True, LangGraph saves state after each node
|
||||||
|
# so a crashed run can resume from the last successful step.
|
||||||
|
"checkpoint_enabled": False,
|
||||||
|
# Output language for analyst reports and final decision
|
||||||
|
# Internal agent debate stays in English for reasoning quality
|
||||||
|
"output_language": "English",
|
||||||
# Debate and discussion settings
|
# Debate and discussion settings
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
|
|||||||
90
tradingagents/graph/checkpointer.py
Normal file
90
tradingagents/graph/checkpointer.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""LangGraph checkpoint support for resumable analysis runs.
|
||||||
|
|
||||||
|
Per-ticker SQLite databases so concurrent tickers don't contend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
|
def _db_path(data_dir: str | Path, ticker: str) -> Path:
|
||||||
|
"""Return the SQLite checkpoint DB path for a ticker."""
|
||||||
|
# Reject ticker values that would escape the checkpoints directory.
|
||||||
|
safe = safe_ticker_component(ticker).upper()
|
||||||
|
p = Path(data_dir) / "checkpoints"
|
||||||
|
p.mkdir(parents=True, exist_ok=True)
|
||||||
|
return p / f"{safe}.db"
|
||||||
|
|
||||||
|
|
||||||
|
def thread_id(ticker: str, date: str) -> str:
|
||||||
|
"""Deterministic thread ID for a ticker+date pair."""
|
||||||
|
return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
|
||||||
|
"""Context manager yielding a SqliteSaver backed by a per-ticker DB."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
conn = sqlite3.connect(str(db), check_same_thread=False)
|
||||||
|
try:
|
||||||
|
saver = SqliteSaver(conn)
|
||||||
|
saver.setup()
|
||||||
|
yield saver
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
|
||||||
|
"""Check whether a resumable checkpoint exists for ticker+date."""
|
||||||
|
return checkpoint_step(data_dir, ticker, date) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
|
||||||
|
"""Return the step number of the latest checkpoint, or None if none exists."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
if not db.exists():
|
||||||
|
return None
|
||||||
|
tid = thread_id(ticker, date)
|
||||||
|
with get_checkpointer(data_dir, ticker) as saver:
|
||||||
|
config = {"configurable": {"thread_id": tid}}
|
||||||
|
cp = saver.get_tuple(config)
|
||||||
|
if cp is None:
|
||||||
|
return None
|
||||||
|
return cp.metadata.get("step")
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_checkpoints(data_dir: str | Path) -> int:
|
||||||
|
"""Remove all checkpoint DBs. Returns number of files deleted."""
|
||||||
|
cp_dir = Path(data_dir) / "checkpoints"
|
||||||
|
if not cp_dir.exists():
|
||||||
|
return 0
|
||||||
|
dbs = list(cp_dir.glob("*.db"))
|
||||||
|
for db in dbs:
|
||||||
|
db.unlink()
|
||||||
|
return len(dbs)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
|
||||||
|
"""Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
if not db.exists():
|
||||||
|
return
|
||||||
|
tid = thread_id(ticker, date)
|
||||||
|
conn = sqlite3.connect(str(db))
|
||||||
|
try:
|
||||||
|
for table in ("writes", "checkpoints"):
|
||||||
|
conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
@@ -16,13 +16,14 @@ class Propagator:
|
|||||||
self.max_recur_limit = max_recur_limit
|
self.max_recur_limit = max_recur_limit
|
||||||
|
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
self, company_name: str, trade_date: str
|
self, company_name: str, trade_date: str, past_context: str = ""
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph."""
|
"""Create the initial state for the agent graph."""
|
||||||
return {
|
return {
|
||||||
"messages": [("human", company_name)],
|
"messages": [("human", company_name)],
|
||||||
"company_of_interest": company_name,
|
"company_of_interest": company_name,
|
||||||
"trade_date": str(trade_date),
|
"trade_date": str(trade_date),
|
||||||
|
"past_context": past_context,
|
||||||
"investment_debate_state": InvestDebateState(
|
"investment_debate_state": InvestDebateState(
|
||||||
{
|
{
|
||||||
"bull_history": "",
|
"bull_history": "",
|
||||||
|
|||||||
@@ -1,121 +1,53 @@
|
|||||||
# TradingAgents/graph/reflection.py
|
# TradingAgents/graph/reflection.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class Reflector:
|
class Reflector:
|
||||||
"""Handles reflection on decisions and updating memory."""
|
"""Handles reflection on trading decisions."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: Any):
|
||||||
"""Initialize the reflector with an LLM."""
|
"""Initialize the reflector with an LLM."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
self.reflection_system_prompt = self._get_reflection_prompt()
|
self.log_reflection_prompt = self._get_log_reflection_prompt()
|
||||||
|
|
||||||
def _get_reflection_prompt(self) -> str:
|
def _get_log_reflection_prompt(self) -> str:
|
||||||
"""Get the system prompt for reflection."""
|
"""Concise prompt for reflect_on_final_decision (Phase B log entries).
|
||||||
return """
|
|
||||||
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
|
||||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
|
||||||
|
|
||||||
1. Reasoning:
|
Produces 2-4 sentences of plain prose — compact enough to be re-injected
|
||||||
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite.
|
into future agent prompts without bloating the context window.
|
||||||
- Analyze the contributing factors to each success or mistake. Consider:
|
|
||||||
- Market intelligence.
|
|
||||||
- Technical indicators.
|
|
||||||
- Technical signals.
|
|
||||||
- Price movement analysis.
|
|
||||||
- Overall market data analysis
|
|
||||||
- News analysis.
|
|
||||||
- Social media and sentiment analysis.
|
|
||||||
- Fundamental data analysis.
|
|
||||||
- Weight the importance of each factor in the decision-making process.
|
|
||||||
|
|
||||||
2. Improvement:
|
|
||||||
- For any incorrect decisions, propose revisions to maximize returns.
|
|
||||||
- Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date).
|
|
||||||
|
|
||||||
3. Summary:
|
|
||||||
- Summarize the lessons learned from the successes and mistakes.
|
|
||||||
- Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained.
|
|
||||||
|
|
||||||
4. Query:
|
|
||||||
- Extract key insights from the summary into a concise sentence of no more than 1000 tokens.
|
|
||||||
- Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference.
|
|
||||||
|
|
||||||
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
|
||||||
"""
|
"""
|
||||||
|
return (
|
||||||
|
"You are a trading analyst reviewing your own past decision now that the outcome is known.\n"
|
||||||
|
"Write exactly 2-4 sentences of plain prose (no bullets, no headers, no markdown).\n\n"
|
||||||
|
"Cover in order:\n"
|
||||||
|
"1. Was the directional call correct? (cite the alpha figure)\n"
|
||||||
|
"2. Which part of the investment thesis held or failed?\n"
|
||||||
|
"3. One concrete lesson to apply to the next similar analysis.\n\n"
|
||||||
|
"Be specific and terse. Your output will be stored verbatim in a decision log "
|
||||||
|
"and re-read by future analysts, so every word must earn its place."
|
||||||
|
)
|
||||||
|
|
||||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
def reflect_on_final_decision(
|
||||||
"""Extract the current market situation from the state."""
|
self,
|
||||||
curr_market_report = current_state["market_report"]
|
final_decision: str,
|
||||||
curr_sentiment_report = current_state["sentiment_report"]
|
raw_return: float,
|
||||||
curr_news_report = current_state["news_report"]
|
alpha_return: float,
|
||||||
curr_fundamentals_report = current_state["fundamentals_report"]
|
|
||||||
|
|
||||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
|
||||||
|
|
||||||
def _reflect_on_component(
|
|
||||||
self, component_type: str, report: str, situation: str, returns_losses
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate reflection for a component."""
|
"""Single reflection call on the final trade decision with outcome context.
|
||||||
|
|
||||||
|
Used by Phase B deferred reflection. The final_trade_decision already
|
||||||
|
synthesises all analyst insights, so no separate market context is needed.
|
||||||
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
("system", self.reflection_system_prompt),
|
("system", self.log_reflection_prompt),
|
||||||
(
|
(
|
||||||
"human",
|
"human",
|
||||||
f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}",
|
(
|
||||||
|
f"Raw return: {raw_return:+.1%}\n"
|
||||||
|
f"Alpha vs SPY: {alpha_return:+.1%}\n\n"
|
||||||
|
f"Final Decision:\n{final_decision}"
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
return self.quick_thinking_llm.invoke(messages).content
|
||||||
result = self.quick_thinking_llm.invoke(messages).content
|
|
||||||
return result
|
|
||||||
|
|
||||||
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
|
||||||
"""Reflect on bull researcher's analysis and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"BULL", bull_debate_history, situation, returns_losses
|
|
||||||
)
|
|
||||||
bull_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
|
|
||||||
"""Reflect on bear researcher's analysis and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"BEAR", bear_debate_history, situation, returns_losses
|
|
||||||
)
|
|
||||||
bear_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_trader(self, current_state, returns_losses, trader_memory):
|
|
||||||
"""Reflect on trader's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
trader_decision = current_state["trader_investment_plan"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"TRADER", trader_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
trader_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
|
|
||||||
"""Reflect on investment judge's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"INVEST JUDGE", judge_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
invest_judge_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_portfolio_manager(self, current_state, returns_losses, portfolio_manager_memory):
|
|
||||||
"""Reflect on portfolio manager's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"PORTFOLIO MANAGER", judge_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
portfolio_manager_memory.add_situations([(situation, result)])
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any, Dict
|
||||||
from langchain_openai import ChatOpenAI
|
from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.graph import END, StateGraph, START
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
@@ -16,25 +15,15 @@ class GraphSetup:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quick_thinking_llm: ChatOpenAI,
|
quick_thinking_llm: Any,
|
||||||
deep_thinking_llm: ChatOpenAI,
|
deep_thinking_llm: Any,
|
||||||
tool_nodes: Dict[str, ToolNode],
|
tool_nodes: Dict[str, ToolNode],
|
||||||
bull_memory,
|
|
||||||
bear_memory,
|
|
||||||
trader_memory,
|
|
||||||
invest_judge_memory,
|
|
||||||
portfolio_manager_memory,
|
|
||||||
conditional_logic: ConditionalLogic,
|
conditional_logic: ConditionalLogic,
|
||||||
):
|
):
|
||||||
"""Initialize with required components."""
|
"""Initialize with required components."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
self.deep_thinking_llm = deep_thinking_llm
|
self.deep_thinking_llm = deep_thinking_llm
|
||||||
self.tool_nodes = tool_nodes
|
self.tool_nodes = tool_nodes
|
||||||
self.bull_memory = bull_memory
|
|
||||||
self.bear_memory = bear_memory
|
|
||||||
self.trader_memory = trader_memory
|
|
||||||
self.invest_judge_memory = invest_judge_memory
|
|
||||||
self.portfolio_manager_memory = portfolio_manager_memory
|
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
@@ -86,24 +75,16 @@ class GraphSetup:
|
|||||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||||
|
|
||||||
# Create researcher and manager nodes
|
# Create researcher and manager nodes
|
||||||
bull_researcher_node = create_bull_researcher(
|
bull_researcher_node = create_bull_researcher(self.quick_thinking_llm)
|
||||||
self.quick_thinking_llm, self.bull_memory
|
bear_researcher_node = create_bear_researcher(self.quick_thinking_llm)
|
||||||
)
|
research_manager_node = create_research_manager(self.deep_thinking_llm)
|
||||||
bear_researcher_node = create_bear_researcher(
|
trader_node = create_trader(self.quick_thinking_llm)
|
||||||
self.quick_thinking_llm, self.bear_memory
|
|
||||||
)
|
|
||||||
research_manager_node = create_research_manager(
|
|
||||||
self.deep_thinking_llm, self.invest_judge_memory
|
|
||||||
)
|
|
||||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
|
||||||
|
|
||||||
# Create risk analysis nodes
|
# Create risk analysis nodes
|
||||||
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
|
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
|
||||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||||
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
|
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
|
||||||
portfolio_manager_node = create_portfolio_manager(
|
portfolio_manager_node = create_portfolio_manager(self.deep_thinking_llm)
|
||||||
self.deep_thinking_llm, self.portfolio_manager_memory
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create workflow
|
# Create workflow
|
||||||
workflow = StateGraph(AgentState)
|
workflow = StateGraph(AgentState)
|
||||||
@@ -198,5 +179,4 @@ class GraphSetup:
|
|||||||
|
|
||||||
workflow.add_edge("Portfolio Manager", END)
|
workflow.add_edge("Portfolio Manager", END)
|
||||||
|
|
||||||
# Compile and return
|
return workflow
|
||||||
return workflow.compile()
|
|
||||||
|
|||||||
@@ -1,33 +1,31 @@
|
|||||||
# TradingAgents/graph/signal_processing.py
|
"""Extract the 5-tier portfolio rating from the Portfolio Manager's decision.
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
The Portfolio Manager produces a typed ``PortfolioDecision`` via structured
|
||||||
|
output and renders it to markdown that always carries a ``**Rating**: X``
|
||||||
|
header (see :func:`tradingagents.agents.schemas.render_pm_decision`). The
|
||||||
|
deterministic heuristic in :mod:`tradingagents.agents.utils.rating` is more
|
||||||
|
than sufficient to extract that rating; no extra LLM call is needed.
|
||||||
|
|
||||||
|
This module exists for backwards compatibility with callers that expect a
|
||||||
|
``SignalProcessor.process_signal(text)`` interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
|
|
||||||
class SignalProcessor:
|
class SignalProcessor:
|
||||||
"""Processes trading signals to extract actionable decisions."""
|
"""Read the 5-tier rating out of a Portfolio Manager decision."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: Any = None):
|
||||||
"""Initialize with an LLM for processing."""
|
# The LLM argument is accepted for backwards compatibility but no
|
||||||
|
# longer used: the PM's structured output guarantees the rating is
|
||||||
|
# parseable from the rendered markdown without a second LLM call.
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
||||||
def process_signal(self, full_signal: str) -> str:
|
def process_signal(self, full_signal: str) -> str:
|
||||||
"""
|
"""Return one of Buy / Overweight / Hold / Underweight / Sell."""
|
||||||
Process a full trading signal to extract the core decision.
|
return parse_rating(full_signal)
|
||||||
|
|
||||||
Args:
|
|
||||||
full_signal: Complete trading signal text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Extracted rating (BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, or SELL)
|
|
||||||
"""
|
|
||||||
messages = [
|
|
||||||
(
|
|
||||||
"system",
|
|
||||||
"You are an efficient assistant that extracts the trading decision from analyst reports. "
|
|
||||||
"Extract the rating as exactly one of: BUY, OVERWEIGHT, HOLD, UNDERWEIGHT, SELL. "
|
|
||||||
"Output only the single rating word, nothing else.",
|
|
||||||
),
|
|
||||||
("human", full_signal),
|
|
||||||
]
|
|
||||||
|
|
||||||
return self.quick_thinking_llm.invoke(messages).content
|
|
||||||
|
|||||||
@@ -1,18 +1,24 @@
|
|||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from datetime import date
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
|
import yfinance as yf
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.llm_clients import create_llm_client
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
@@ -33,6 +39,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||||||
get_global_news
|
get_global_news
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
from .setup import GraphSetup
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
@@ -66,10 +73,8 @@ class TradingAgentsGraph:
|
|||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
||||||
# Create necessary directories
|
# Create necessary directories
|
||||||
os.makedirs(
|
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
os.makedirs(self.config["results_dir"], exist_ok=True)
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize LLMs with provider-specific thinking configuration
|
# Initialize LLMs with provider-specific thinking configuration
|
||||||
llm_kwargs = self._get_provider_kwargs()
|
llm_kwargs = self._get_provider_kwargs()
|
||||||
@@ -94,12 +99,7 @@ class TradingAgentsGraph:
|
|||||||
self.deep_thinking_llm = deep_client.get_llm()
|
self.deep_thinking_llm = deep_client.get_llm()
|
||||||
self.quick_thinking_llm = quick_client.get_llm()
|
self.quick_thinking_llm = quick_client.get_llm()
|
||||||
|
|
||||||
# Initialize memories
|
self.memory_log = TradingMemoryLog(self.config)
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
|
||||||
self.portfolio_manager_memory = FinancialSituationMemory("portfolio_manager_memory", self.config)
|
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
@@ -113,11 +113,6 @@ class TradingAgentsGraph:
|
|||||||
self.quick_thinking_llm,
|
self.quick_thinking_llm,
|
||||||
self.deep_thinking_llm,
|
self.deep_thinking_llm,
|
||||||
self.tool_nodes,
|
self.tool_nodes,
|
||||||
self.bull_memory,
|
|
||||||
self.bear_memory,
|
|
||||||
self.trader_memory,
|
|
||||||
self.invest_judge_memory,
|
|
||||||
self.portfolio_manager_memory,
|
|
||||||
self.conditional_logic,
|
self.conditional_logic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,8 +125,10 @@ class TradingAgentsGraph:
|
|||||||
self.ticker = None
|
self.ticker = None
|
||||||
self.log_states_dict = {} # date to full state dict
|
self.log_states_dict = {} # date to full state dict
|
||||||
|
|
||||||
# Set up the graph
|
# Set up the graph: keep the workflow for recompilation with a checkpointer.
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
@@ -191,19 +188,133 @@ class TradingAgentsGraph:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def propagate(self, company_name, trade_date):
|
def _fetch_returns(
|
||||||
"""Run the trading agents graph for a company on a specific date."""
|
self, ticker: str, trade_date: str, holding_days: int = 5
|
||||||
|
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
|
||||||
|
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
||||||
|
|
||||||
|
Returns (raw_return, alpha_return, actual_holding_days) or
|
||||||
|
(None, None, None) if price data is unavailable (too recent, delisted,
|
||||||
|
or network error).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start = datetime.strptime(trade_date, "%Y-%m-%d")
|
||||||
|
end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays
|
||||||
|
end_str = end.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
|
||||||
|
spy = yf.Ticker("SPY").history(start=trade_date, end=end_str)
|
||||||
|
|
||||||
|
if len(stock) < 2 or len(spy) < 2:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
actual_days = min(holding_days, len(stock) - 1, len(spy) - 1)
|
||||||
|
raw = float(
|
||||||
|
(stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
|
||||||
|
/ stock["Close"].iloc[0]
|
||||||
|
)
|
||||||
|
spy_ret = float(
|
||||||
|
(spy["Close"].iloc[actual_days] - spy["Close"].iloc[0])
|
||||||
|
/ spy["Close"].iloc[0]
|
||||||
|
)
|
||||||
|
alpha = raw - spy_ret
|
||||||
|
return raw, alpha, actual_days
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Could not resolve outcome for %s on %s (will retry next run): %s",
|
||||||
|
ticker, trade_date, e,
|
||||||
|
)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
def _resolve_pending_entries(self, ticker: str) -> None:
|
||||||
|
"""Resolve pending log entries for ticker at the start of a new run.
|
||||||
|
|
||||||
|
Fetches returns for each same-ticker pending entry, generates reflections,
|
||||||
|
then writes all updates in a single atomic batch write to avoid redundant I/O.
|
||||||
|
Skips entries whose price data is not yet available (too recent or delisted).
|
||||||
|
|
||||||
|
Trade-off: only same-ticker entries are resolved per run. Entries for
|
||||||
|
other tickers accumulate until that ticker is run again.
|
||||||
|
"""
|
||||||
|
pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker]
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
for entry in pending:
|
||||||
|
raw, alpha, days = self._fetch_returns(ticker, entry["date"])
|
||||||
|
if raw is None:
|
||||||
|
continue # price not available yet — try again next run
|
||||||
|
reflection = self.reflector.reflect_on_final_decision(
|
||||||
|
final_decision=entry.get("decision", ""),
|
||||||
|
raw_return=raw,
|
||||||
|
alpha_return=alpha,
|
||||||
|
)
|
||||||
|
updates.append({
|
||||||
|
"ticker": ticker,
|
||||||
|
"trade_date": entry["date"],
|
||||||
|
"raw_return": raw,
|
||||||
|
"alpha_return": alpha,
|
||||||
|
"holding_days": days,
|
||||||
|
"reflection": reflection,
|
||||||
|
})
|
||||||
|
|
||||||
|
if updates:
|
||||||
|
self.memory_log.batch_update_with_outcomes(updates)
|
||||||
|
|
||||||
|
def propagate(self, company_name, trade_date):
|
||||||
|
"""Run the trading agents graph for a company on a specific date.
|
||||||
|
|
||||||
|
When ``checkpoint_enabled`` is set in config, the graph is recompiled
|
||||||
|
with a per-ticker SqliteSaver so a crashed run can resume from the last
|
||||||
|
successful node on a subsequent invocation with the same ticker+date.
|
||||||
|
"""
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
|
|
||||||
# Initialize state
|
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
|
||||||
|
self._resolve_pending_entries(company_name)
|
||||||
|
|
||||||
|
# Recompile with a checkpointer if the user opted in.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
self._checkpointer_ctx = get_checkpointer(
|
||||||
|
self.config["data_cache_dir"], company_name
|
||||||
|
)
|
||||||
|
saver = self._checkpointer_ctx.__enter__()
|
||||||
|
self.graph = self.workflow.compile(checkpointer=saver)
|
||||||
|
|
||||||
|
step = checkpoint_step(
|
||||||
|
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||||
|
)
|
||||||
|
if step is not None:
|
||||||
|
logger.info(
|
||||||
|
"Resuming from step %d for %s on %s", step, company_name, trade_date
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._run_graph(company_name, trade_date)
|
||||||
|
finally:
|
||||||
|
if self._checkpointer_ctx is not None:
|
||||||
|
self._checkpointer_ctx.__exit__(None, None, None)
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
|
||||||
|
def _run_graph(self, company_name, trade_date):
|
||||||
|
"""Execute the graph and write the resulting state to disk and memory log."""
|
||||||
|
# Initialize state — inject memory log context for PM.
|
||||||
|
past_context = self.memory_log.get_past_context(company_name)
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date, past_context=past_context
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
# Inject thread_id so same ticker+date resumes, different date starts fresh.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
tid = thread_id(company_name, str(trade_date))
|
||||||
|
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
# Debug mode with tracing
|
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in self.graph.stream(init_agent_state, **args):
|
for chunk in self.graph.stream(init_agent_state, **args):
|
||||||
if len(chunk["messages"]) == 0:
|
if len(chunk["messages"]) == 0:
|
||||||
@@ -211,19 +322,29 @@ class TradingAgentsGraph:
|
|||||||
else:
|
else:
|
||||||
chunk["messages"][-1].pretty_print()
|
chunk["messages"][-1].pretty_print()
|
||||||
trace.append(chunk)
|
trace.append(chunk)
|
||||||
|
|
||||||
final_state = trace[-1]
|
final_state = trace[-1]
|
||||||
else:
|
else:
|
||||||
# Standard mode without tracing
|
|
||||||
final_state = self.graph.invoke(init_agent_state, **args)
|
final_state = self.graph.invoke(init_agent_state, **args)
|
||||||
|
|
||||||
# Store current state for reflection
|
# Store current state for reflection.
|
||||||
self.curr_state = final_state
|
self.curr_state = final_state
|
||||||
|
|
||||||
# Log state
|
# Log state to disk.
|
||||||
self._log_state(trade_date, final_state)
|
self._log_state(trade_date, final_state)
|
||||||
|
|
||||||
# Return decision and processed signal
|
# Store decision for deferred reflection on the next same-ticker run.
|
||||||
|
self.memory_log.store_decision(
|
||||||
|
ticker=company_name,
|
||||||
|
trade_date=trade_date,
|
||||||
|
final_trade_decision=final_state["final_trade_decision"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear checkpoint on successful completion to avoid stale state.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
clear_checkpoint(
|
||||||
|
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||||
|
)
|
||||||
|
|
||||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||||
|
|
||||||
def _log_state(self, trade_date, final_state):
|
def _log_state(self, trade_date, final_state):
|
||||||
@@ -258,34 +379,15 @@ class TradingAgentsGraph:
|
|||||||
"final_trade_decision": final_state["final_trade_decision"],
|
"final_trade_decision": final_state["final_trade_decision"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save to file
|
# Save to file. Reject ticker values that would escape the
|
||||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
# results directory when joined as a path component.
|
||||||
|
safe_ticker = safe_ticker_component(self.ticker)
|
||||||
|
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(
|
log_path = directory / f"full_states_log_{trade_date}.json"
|
||||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
|
with open(log_path, "w", encoding="utf-8") as f:
|
||||||
"w",
|
json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
|
||||||
encoding="utf-8",
|
|
||||||
) as f:
|
|
||||||
json.dump(self.log_states_dict, f, indent=4)
|
|
||||||
|
|
||||||
def reflect_and_remember(self, returns_losses):
|
|
||||||
"""Reflect on decisions and update memory based on returns."""
|
|
||||||
self.reflector.reflect_bull_researcher(
|
|
||||||
self.curr_state, returns_losses, self.bull_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_bear_researcher(
|
|
||||||
self.curr_state, returns_losses, self.bear_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_trader(
|
|
||||||
self.curr_state, returns_losses, self.trader_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_invest_judge(
|
|
||||||
self.curr_state, returns_losses, self.invest_judge_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_portfolio_manager(
|
|
||||||
self.curr_state, returns_losses, self.portfolio_manager_memory
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_signal(self, full_signal):
|
def process_signal(self, full_signal):
|
||||||
"""Process a signal to extract the core decision."""
|
"""Process a signal to extract the core decision."""
|
||||||
|
|||||||
@@ -5,20 +5,11 @@
|
|||||||
### 1. `validate_model()` is never called
|
### 1. `validate_model()` is never called
|
||||||
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
||||||
|
|
||||||
### 2. Inconsistent parameter handling
|
### 2. ~~Inconsistent parameter handling~~ (Fixed)
|
||||||
| Client | API Key Param | Special Params |
|
- GoogleClient now accepts unified `api_key` and maps it to `google_api_key`
|
||||||
|--------|---------------|----------------|
|
|
||||||
| OpenAI | `api_key` | `reasoning_effort` |
|
|
||||||
| Anthropic | `api_key` | `thinking_config` → `thinking` |
|
|
||||||
| Google | `google_api_key` | `thinking_budget` |
|
|
||||||
|
|
||||||
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
|
### 3. ~~`base_url` accepted but ignored~~ (Fixed)
|
||||||
|
- All clients now pass `base_url` to their respective LLM constructors
|
||||||
|
|
||||||
### 3. `base_url` accepted but ignored
|
### 4. ~~Update validators.py with models from CLI~~ (Fixed)
|
||||||
- `AnthropicClient`: accepts `base_url` but never uses it
|
- Synced in v0.2.2
|
||||||
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
|
|
||||||
|
|
||||||
**Fix:** Remove unused `base_url` from clients that don't support it
|
|
||||||
|
|
||||||
### 4. Update validators.py with models from CLI
|
|
||||||
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete
|
|
||||||
|
|||||||
@@ -31,8 +31,12 @@ class AnthropicClient(BaseLLMClient):
|
|||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
"""Return configured ChatAnthropic instance."""
|
"""Return configured ChatAnthropic instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
|
if self.base_url:
|
||||||
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
for key in _PASSTHROUGH_KWARGS:
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|||||||
52
tradingagents/llm_clients/azure_client.py
Normal file
52
tradingagents/llm_clients/azure_client.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"timeout", "max_retries", "api_key", "reasoning_effort",
|
||||||
|
"callbacks", "http_client", "http_async_client",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
|
||||||
|
"""AzureChatOpenAI with normalized content output."""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIClient(BaseLLMClient):
|
||||||
|
"""Client for Azure OpenAI deployments.
|
||||||
|
|
||||||
|
Requires environment variables:
|
||||||
|
AZURE_OPENAI_API_KEY: API key
|
||||||
|
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
|
||||||
|
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
|
||||||
|
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured AzureChatOpenAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
|
||||||
|
llm_kwargs = {
|
||||||
|
"model": self.model,
|
||||||
|
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
return NormalizedAzureChatOpenAI(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Azure accepts any deployed model name."""
|
||||||
|
return True
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
def normalize_content(response):
|
def normalize_content(response):
|
||||||
@@ -29,6 +30,27 @@ class BaseLLMClient(ABC):
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def get_provider_name(self) -> str:
|
||||||
|
"""Return the provider name used in warning messages."""
|
||||||
|
provider = getattr(self, "provider", None)
|
||||||
|
if provider:
|
||||||
|
return str(provider)
|
||||||
|
return self.__class__.__name__.removesuffix("Client").lower()
|
||||||
|
|
||||||
|
def warn_if_unknown_model(self) -> None:
|
||||||
|
"""Warn when the model is outside the known list for the provider."""
|
||||||
|
if self.validate_model():
|
||||||
|
return
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
f"Model '{self.model}' is not in the known model list for "
|
||||||
|
f"provider '{self.get_provider_name()}'. Continuing anyway."
|
||||||
|
),
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
"""Return the configured LLM instance."""
|
"""Return the configured LLM instance."""
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
|
||||||
from .anthropic_client import AnthropicClient
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
from .google_client import GoogleClient
|
_OPENAI_COMPATIBLE = (
|
||||||
|
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
@@ -14,17 +16,15 @@ def create_llm_client(
|
|||||||
) -> BaseLLMClient:
|
) -> BaseLLMClient:
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
|
Provider modules are imported lazily so that simply importing this
|
||||||
|
factory (e.g. during test collection) does not pull in heavy LLM SDKs
|
||||||
|
or fail when their API keys are absent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
provider: LLM provider name
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
- http_client: Custom httpx.Client for SSL proxy or certificate customization
|
|
||||||
- http_async_client: Custom httpx.AsyncClient for async operations
|
|
||||||
- timeout: Request timeout in seconds
|
|
||||||
- max_retries: Maximum retry attempts
|
|
||||||
- api_key: API key for the provider
|
|
||||||
- callbacks: LangChain callbacks
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured BaseLLMClient instance
|
Configured BaseLLMClient instance
|
||||||
@@ -34,16 +34,20 @@ def create_llm_client(
|
|||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
if provider_lower in _OPENAI_COMPATIBLE:
|
||||||
|
from .openai_client import OpenAIClient
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "xai":
|
|
||||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
|
||||||
|
|
||||||
if provider_lower == "anthropic":
|
if provider_lower == "anthropic":
|
||||||
|
from .anthropic_client import AnthropicClient
|
||||||
return AnthropicClient(model, base_url, **kwargs)
|
return AnthropicClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
|
from .google_client import GoogleClient
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "azure":
|
||||||
|
from .azure_client import AzureOpenAIClient
|
||||||
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|||||||
@@ -25,12 +25,21 @@ class GoogleClient(BaseLLMClient):
|
|||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"):
|
if self.base_url:
|
||||||
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
# Unified api_key maps to provider-specific google_api_key
|
||||||
|
google_api_key = self.kwargs.get("api_key") or self.kwargs.get("google_api_key")
|
||||||
|
if google_api_key:
|
||||||
|
llm_kwargs["google_api_key"] = google_api_key
|
||||||
|
|
||||||
# Map thinking_level to appropriate API param based on model
|
# Map thinking_level to appropriate API param based on model
|
||||||
# Gemini 3 Pro: low, high
|
# Gemini 3 Pro: low, high
|
||||||
# Gemini 3 Flash: minimal, low, medium, high
|
# Gemini 3 Flash: minimal, low, medium, high
|
||||||
|
|||||||
136
tradingagents/llm_clients/model_catalog.py
Normal file
136
tradingagents/llm_clients/model_catalog.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""Shared model catalog for CLI selections and validation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
ModelOption = Tuple[str, str]
|
||||||
|
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_OPTIONS: ProviderModeOptions = {
|
||||||
|
"openai": {
|
||||||
|
"quick": [
|
||||||
|
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
|
||||||
|
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
|
||||||
|
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||||
|
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||||
|
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||||
|
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
|
||||||
|
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"anthropic": {
|
||||||
|
"quick": [
|
||||||
|
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||||
|
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||||
|
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||||
|
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||||
|
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||||
|
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"google": {
|
||||||
|
"quick": [
|
||||||
|
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||||
|
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||||
|
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||||
|
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"xai": {
|
||||||
|
"quick": [
|
||||||
|
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||||
|
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||||
|
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||||
|
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||||
|
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||||
|
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"quick": [
|
||||||
|
("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
|
||||||
|
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"qwen": {
|
||||||
|
"quick": [
|
||||||
|
("Qwen 3.5 Flash", "qwen3.5-flash"),
|
||||||
|
("Qwen Plus", "qwen-plus"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Qwen 3.6 Plus", "qwen3.6-plus"),
|
||||||
|
("Qwen 3.5 Plus", "qwen3.5-plus"),
|
||||||
|
("Qwen 3 Max", "qwen3-max"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"glm": {
|
||||||
|
"quick": [
|
||||||
|
("GLM-4.7", "glm-4.7"),
|
||||||
|
("GLM-5", "glm-5"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GLM-5.1", "glm-5.1"),
|
||||||
|
("GLM-5", "glm-5"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||||
|
"ollama": {
|
||||||
|
"quick": [
|
||||||
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
||||||
|
"""Return shared model options for a provider and selection mode."""
|
||||||
|
return MODEL_OPTIONS[provider.lower()][mode]
|
||||||
|
|
||||||
|
|
||||||
|
def get_known_models() -> Dict[str, List[str]]:
|
||||||
|
"""Build known model names from the shared CLI catalog."""
|
||||||
|
return {
|
||||||
|
provider: sorted(
|
||||||
|
{
|
||||||
|
value
|
||||||
|
for options in mode_options.values()
|
||||||
|
for _, value in options
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for provider, mode_options in MODEL_OPTIONS.items()
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from .base_client import BaseLLMClient, normalize_content
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
@@ -11,13 +12,97 @@ class NormalizedChatOpenAI(ChatOpenAI):
|
|||||||
"""ChatOpenAI with normalized content output.
|
"""ChatOpenAI with normalized content output.
|
||||||
|
|
||||||
The Responses API returns content as a list of typed blocks
|
The Responses API returns content as a list of typed blocks
|
||||||
(reasoning, text, etc.). This normalizes to string for consistent
|
(reasoning, text, etc.). ``invoke`` normalizes to string for
|
||||||
downstream handling.
|
consistent downstream handling. ``with_structured_output`` defaults
|
||||||
|
to function-calling so the Responses-API parse path is avoided
|
||||||
|
(langchain-openai's parse path emits noisy
|
||||||
|
PydanticSerializationUnexpectedValue warnings per call without
|
||||||
|
affecting correctness).
|
||||||
|
|
||||||
|
Provider-specific quirks (e.g. DeepSeek's thinking mode) live in
|
||||||
|
purpose-built subclasses below so this base class stays small.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def invoke(self, input, config=None, **kwargs):
|
def invoke(self, input, config=None, **kwargs):
|
||||||
return normalize_content(super().invoke(input, config, **kwargs))
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
def with_structured_output(self, schema, *, method=None, **kwargs):
|
||||||
|
if method is None:
|
||||||
|
method = "function_calling"
|
||||||
|
return super().with_structured_output(schema, method=method, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _input_to_messages(input_: Any) -> list:
|
||||||
|
"""Normalise a langchain LLM input to a list of message objects.
|
||||||
|
|
||||||
|
Accepts a list of messages, a ``ChatPromptValue`` (from a
|
||||||
|
ChatPromptTemplate), or anything else (treated as no messages).
|
||||||
|
Used by providers that need to walk the outgoing message history;
|
||||||
|
in particular DeepSeek thinking-mode propagation must work for
|
||||||
|
both bare-list invocations and ChatPromptTemplate-driven ones, so
|
||||||
|
treating only ``list`` here would silently skip half the call sites.
|
||||||
|
"""
|
||||||
|
if isinstance(input_, list):
|
||||||
|
return input_
|
||||||
|
if hasattr(input_, "to_messages"):
|
||||||
|
return input_.to_messages()
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
|
||||||
|
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
|
||||||
|
|
||||||
|
Two quirks that don't apply to other OpenAI-compatible providers:
|
||||||
|
|
||||||
|
1. **Thinking-mode round-trip.** When DeepSeek's thinking models return
|
||||||
|
a response with ``reasoning_content``, that field must be echoed
|
||||||
|
back as part of the assistant message on the next turn or the API
|
||||||
|
fails with HTTP 400. ``_create_chat_result`` captures the field on
|
||||||
|
receive and ``_get_request_payload`` re-attaches it on send.
|
||||||
|
|
||||||
|
2. **deepseek-reasoner has no tool_choice.** Structured output via
|
||||||
|
function-calling is unavailable, so we raise NotImplementedError
|
||||||
|
and let the agent factories fall back to free-text generation
|
||||||
|
(see ``tradingagents/agents/utils/structured.py``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||||
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||||
|
outgoing = payload.get("messages", [])
|
||||||
|
for message_dict, message in zip(outgoing, _input_to_messages(input_)):
|
||||||
|
if not isinstance(message, AIMessage):
|
||||||
|
continue
|
||||||
|
reasoning = message.additional_kwargs.get("reasoning_content")
|
||||||
|
if reasoning is not None:
|
||||||
|
message_dict["reasoning_content"] = reasoning
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _create_chat_result(self, response, generation_info=None):
|
||||||
|
chat_result = super()._create_chat_result(response, generation_info)
|
||||||
|
response_dict = (
|
||||||
|
response
|
||||||
|
if isinstance(response, dict)
|
||||||
|
else response.model_dump(
|
||||||
|
exclude={"choices": {"__all__": {"message": {"parsed"}}}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for generation, choice in zip(
|
||||||
|
chat_result.generations, response_dict.get("choices", [])
|
||||||
|
):
|
||||||
|
reasoning = choice.get("message", {}).get("reasoning_content")
|
||||||
|
if reasoning is not None:
|
||||||
|
generation.message.additional_kwargs["reasoning_content"] = reasoning
|
||||||
|
return chat_result
|
||||||
|
|
||||||
|
def with_structured_output(self, schema, *, method=None, **kwargs):
|
||||||
|
if self.model_name == "deepseek-reasoner":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"deepseek-reasoner does not support tool_choice; structured "
|
||||||
|
"output is unavailable. Agent factories fall back to "
|
||||||
|
"free-text generation automatically."
|
||||||
|
)
|
||||||
|
return super().with_structured_output(schema, method=method, **kwargs)
|
||||||
|
|
||||||
# Kwargs forwarded from user config to ChatOpenAI
|
# Kwargs forwarded from user config to ChatOpenAI
|
||||||
_PASSTHROUGH_KWARGS = (
|
_PASSTHROUGH_KWARGS = (
|
||||||
"timeout", "max_retries", "reasoning_effort",
|
"timeout", "max_retries", "reasoning_effort",
|
||||||
@@ -27,6 +112,9 @@ _PASSTHROUGH_KWARGS = (
|
|||||||
# Provider base URLs and API key env vars
|
# Provider base URLs and API key env vars
|
||||||
_PROVIDER_CONFIG = {
|
_PROVIDER_CONFIG = {
|
||||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||||
|
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
|
||||||
|
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
|
||||||
|
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
|
||||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||||
"ollama": ("http://localhost:11434/v1", None),
|
"ollama": ("http://localhost:11434/v1", None),
|
||||||
}
|
}
|
||||||
@@ -53,12 +141,15 @@ class OpenAIClient(BaseLLMClient):
|
|||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
"""Return configured ChatOpenAI instance."""
|
"""Return configured ChatOpenAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
# Provider-specific base URL and auth
|
# Provider-specific base URL and auth. An explicit base_url on the
|
||||||
|
# client (e.g. a corporate proxy) takes precedence over the
|
||||||
|
# provider default so users can route through their own gateway.
|
||||||
if self.provider in _PROVIDER_CONFIG:
|
if self.provider in _PROVIDER_CONFIG:
|
||||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||||
llm_kwargs["base_url"] = base_url
|
llm_kwargs["base_url"] = self.base_url or default_base
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
api_key = os.environ.get(api_key_env)
|
api_key = os.environ.get(api_key_env)
|
||||||
if api_key:
|
if api_key:
|
||||||
@@ -78,7 +169,10 @@ class OpenAIClient(BaseLLMClient):
|
|||||||
if self.provider == "openai":
|
if self.provider == "openai":
|
||||||
llm_kwargs["use_responses_api"] = True
|
llm_kwargs["use_responses_api"] = True
|
||||||
|
|
||||||
return NormalizedChatOpenAI(**llm_kwargs)
|
# DeepSeek's thinking-mode quirks live in their own subclass so the
|
||||||
|
# base NormalizedChatOpenAI stays free of provider-specific branches.
|
||||||
|
chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI
|
||||||
|
return chat_cls(**llm_kwargs)
|
||||||
|
|
||||||
def validate_model(self) -> bool:
|
def validate_model(self) -> bool:
|
||||||
"""Validate model for the provider."""
|
"""Validate model for the provider."""
|
||||||
|
|||||||
@@ -1,53 +1,12 @@
|
|||||||
"""Model name validators for each provider.
|
"""Model name validators for each provider."""
|
||||||
|
|
||||||
|
from .model_catalog import get_known_models
|
||||||
|
|
||||||
Only validates model names - does NOT enforce limits.
|
|
||||||
Let LLM providers use their own defaults for unspecified params.
|
|
||||||
"""
|
|
||||||
|
|
||||||
VALID_MODELS = {
|
VALID_MODELS = {
|
||||||
"openai": [
|
provider: models
|
||||||
# GPT-5 series
|
for provider, models in get_known_models().items()
|
||||||
"gpt-5.4-pro",
|
if provider not in ("ollama", "openrouter")
|
||||||
"gpt-5.4",
|
|
||||||
"gpt-5.2",
|
|
||||||
"gpt-5.1",
|
|
||||||
"gpt-5",
|
|
||||||
"gpt-5-mini",
|
|
||||||
"gpt-5-nano",
|
|
||||||
# GPT-4.1 series
|
|
||||||
"gpt-4.1",
|
|
||||||
"gpt-4.1-mini",
|
|
||||||
"gpt-4.1-nano",
|
|
||||||
],
|
|
||||||
"anthropic": [
|
|
||||||
# Claude 4.6 series (latest)
|
|
||||||
"claude-opus-4-6",
|
|
||||||
"claude-sonnet-4-6",
|
|
||||||
# Claude 4.5 series
|
|
||||||
"claude-opus-4-5",
|
|
||||||
"claude-sonnet-4-5",
|
|
||||||
"claude-haiku-4-5",
|
|
||||||
],
|
|
||||||
"google": [
|
|
||||||
# Gemini 3.1 series (preview)
|
|
||||||
"gemini-3.1-pro-preview",
|
|
||||||
"gemini-3.1-flash-lite-preview",
|
|
||||||
# Gemini 3 series (preview)
|
|
||||||
"gemini-3-flash-preview",
|
|
||||||
# Gemini 2.5 series
|
|
||||||
"gemini-2.5-pro",
|
|
||||||
"gemini-2.5-flash",
|
|
||||||
"gemini-2.5-flash-lite",
|
|
||||||
],
|
|
||||||
"xai": [
|
|
||||||
# Grok 4.1 series
|
|
||||||
"grok-4-1-fast-reasoning",
|
|
||||||
"grok-4-1-fast-non-reasoning",
|
|
||||||
# Grok 4 series
|
|
||||||
"grok-4-0709",
|
|
||||||
"grok-4-fast-reasoning",
|
|
||||||
"grok-4-fast-non-reasoning",
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user