mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-21 15:26:20 +03:00
Compare commits
155 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5cb7cbd61 | ||
|
|
78d063dc5c | ||
|
|
819e813a14 | ||
|
|
800862405d | ||
|
|
f10daa2824 | ||
|
|
879e2bb5da | ||
|
|
9f7abfcbd5 | ||
|
|
d13e9b7946 | ||
|
|
6b384f74f9 | ||
|
|
384fe1a3d2 | ||
|
|
0fcf13624e | ||
|
|
d0dd0420ad | ||
|
|
faaeebac70 | ||
|
|
0011b5ebf5 | ||
|
|
4f057e290c | ||
|
|
9e00c8117f | ||
|
|
78fe77f4e6 | ||
|
|
e1316686f8 | ||
|
|
9482cae188 | ||
|
|
19d22b54a9 | ||
|
|
704b7627f2 | ||
|
|
22bb91bd83 | ||
|
|
afdc6d4ec1 | ||
|
|
e2c850eb17 | ||
|
|
c405867bde | ||
|
|
db7e0a67e2 | ||
|
|
7e9e7b83c7 | ||
|
|
2c97bad45c | ||
|
|
7c37249f80 | ||
|
|
4016fd4efa | ||
|
|
bba147798f | ||
|
|
0fda24515f | ||
|
|
4cbd4b086f | ||
|
|
ebd2e12e67 | ||
|
|
f85f5d9f5d | ||
|
|
8e7654f0df | ||
|
|
872b063e69 | ||
|
|
6abc768c1d | ||
|
|
8536ccacdd | ||
|
|
fa4d01c23a | ||
|
|
b0f6058299 | ||
|
|
59d6b2152d | ||
|
|
10c136f49c | ||
|
|
4f965bf46a | ||
|
|
bdb9c29d44 | ||
|
|
bdc5fc62d3 | ||
|
|
78fb66aed1 | ||
|
|
7269f877c1 | ||
|
|
28d5cc661f | ||
|
|
7004dfe554 | ||
|
|
4641c03340 | ||
|
|
e75d17bc51 | ||
|
|
6cddd26d6e | ||
|
|
c61242a28c | ||
|
|
58e99421bd | ||
|
|
46e1b600b8 | ||
|
|
ae8c8aebe8 | ||
|
|
f3f58bdbdc | ||
|
|
e1113880a1 | ||
|
|
bd6a5b75b5 | ||
|
|
8793336dad | ||
|
|
047b38971c | ||
|
|
f5026009f9 | ||
|
|
589b351f2a | ||
|
|
6c9c9ce1fd | ||
|
|
b8b2825783 | ||
|
|
318adda0c6 | ||
|
|
c3ba3bf428 | ||
|
|
7cca9c924e | ||
|
|
bd9b1e5efa | ||
|
|
77755f0431 | ||
|
|
0b13145dc0 | ||
|
|
3ff28f3559 | ||
|
|
7d200d834a | ||
|
|
08bfe70a69 | ||
|
|
f362a160c3 | ||
|
|
64f07671b9 | ||
|
|
b19c5c18fb | ||
|
|
551fd7f074 | ||
|
|
b0f9d180f9 | ||
|
|
9cc283ac22 | ||
|
|
fe9c8d5d31 | ||
|
|
eec6ca4b53 | ||
|
|
3642f5917c | ||
|
|
907bc8022a | ||
|
|
8a60662070 | ||
|
|
f047f26df0 | ||
|
|
35856ff33e | ||
|
|
5fec171a1e | ||
|
|
50c82a25b5 | ||
|
|
8b3068d091 | ||
|
|
66a02b3193 | ||
|
|
e9470b69c4 | ||
|
|
b4b133eb2d | ||
|
|
80aab35119 | ||
|
|
393d4c6a1b | ||
|
|
aba1880c8c | ||
|
|
6cd35179fa | ||
|
|
102b026d23 | ||
|
|
224941d8c2 | ||
|
|
93b87d5119 | ||
|
|
54cdb146d0 | ||
|
|
b06936f420 | ||
|
|
b75940e901 | ||
|
|
3d040f8da4 | ||
|
|
50961b2477 | ||
|
|
a3761bdd66 | ||
|
|
d4dadb82fc | ||
|
|
79051580b8 | ||
|
|
13b826a31d | ||
|
|
b2ef960da7 | ||
|
|
a5dcc7da45 | ||
|
|
7bb2941b07 | ||
|
|
32be17c606 | ||
|
|
c07dcf026b | ||
|
|
d23fb539e9 | ||
|
|
b01051b9f4 | ||
|
|
8fdbbcca3d | ||
|
|
86bc0e793f | ||
|
|
7fc9c28a94 | ||
|
|
7bcc2cbd8a | ||
|
|
6211b1132a | ||
|
|
8b04ec307f | ||
|
|
0ab323c2c6 | ||
|
|
a6734d71bc | ||
|
|
a438acdbbd | ||
|
|
c73e374e7c | ||
|
|
f704828f89 | ||
|
|
fda4f664e8 | ||
|
|
718df34932 | ||
|
|
43aa9c5d09 | ||
|
|
26c5ba5a78 | ||
|
|
78ea029a0b | ||
|
|
ee3d499894 | ||
|
|
7abff0f354 | ||
|
|
b575bd0941 | ||
|
|
b8f712b170 | ||
|
|
52284ce13c | ||
|
|
11804f88ff | ||
|
|
1e86e74314 | ||
|
|
c2f897fc67 | ||
|
|
ed32081f57 | ||
|
|
2af7ef3d79 | ||
|
|
383deb72aa | ||
|
|
7eaf4d995f | ||
|
|
da84ef43aa | ||
|
|
90b23e72f5 | ||
|
|
417b09712c | ||
|
|
570644d939 | ||
|
|
9647359246 | ||
|
|
99789f9cd1 | ||
|
|
a879868396 | ||
|
|
0013415378 | ||
|
|
0fdfd35867 | ||
|
|
e994e56c23 |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
.git
|
||||||
|
.venv
|
||||||
|
.env
|
||||||
|
.claude
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
.DS_Store
|
||||||
|
__pycache__
|
||||||
|
*.egg-info
|
||||||
|
build
|
||||||
|
dist
|
||||||
|
results
|
||||||
|
eval_results
|
||||||
|
Dockerfile
|
||||||
|
docker-compose.yml
|
||||||
5
.env.enterprise.example
Normal file
5
.env.enterprise.example
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Azure OpenAI
|
||||||
|
AZURE_OPENAI_API_KEY=
|
||||||
|
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
|
||||||
|
AZURE_OPENAI_DEPLOYMENT_NAME=
|
||||||
|
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API
|
||||||
32
.env.example
Normal file
32
.env.example
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# LLM Providers (set the one you use)
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
XAI_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
|
DASHSCOPE_API_KEY=
|
||||||
|
DASHSCOPE_CN_API_KEY=
|
||||||
|
ZHIPU_API_KEY=
|
||||||
|
ZHIPU_CN_API_KEY=
|
||||||
|
MINIMAX_API_KEY=
|
||||||
|
MINIMAX_CN_API_KEY=
|
||||||
|
OPENROUTER_API_KEY=
|
||||||
|
|
||||||
|
# Optional: point at a remote Ollama server. When unset, defaults to
|
||||||
|
# the local instance at http://localhost:11434/v1. Convention follows
|
||||||
|
# the broader Ollama ecosystem; both the CLI dropdown and programmatic
|
||||||
|
# client pick this up.
|
||||||
|
#OLLAMA_BASE_URL=http://your-ollama-host:11434/v1
|
||||||
|
|
||||||
|
# Optional: override DEFAULT_CONFIG without editing code.
|
||||||
|
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
|
||||||
|
# in tradingagents/default_config.py. Values are coerced to the type of
|
||||||
|
# the existing default (bool / int / str), so "true"/"3" work as expected.
|
||||||
|
#TRADINGAGENTS_LLM_PROVIDER=openai
|
||||||
|
#TRADINGAGENTS_DEEP_THINK_LLM=gpt-5.4
|
||||||
|
#TRADINGAGENTS_QUICK_THINK_LLM=gpt-5.4-mini
|
||||||
|
#TRADINGAGENTS_LLM_BACKEND_URL=
|
||||||
|
#TRADINGAGENTS_OUTPUT_LANGUAGE=English
|
||||||
|
#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
|
||||||
|
#TRADINGAGENTS_MAX_RISK_ROUNDS=1
|
||||||
|
#TRADINGAGENTS_CHECKPOINT_ENABLED=false
|
||||||
223
.gitignore
vendored
223
.gitignore
vendored
@@ -1,8 +1,219 @@
|
|||||||
env/
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
.DS_Store
|
*.py[codz]
|
||||||
*.csv
|
*$py.class
|
||||||
src/
|
|
||||||
eval_results/
|
# C extensions
|
||||||
eval_data/
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
**/data_cache/
|
||||||
|
|||||||
341
CHANGELOG.md
Normal file
341
CHANGELOG.md
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to TradingAgents are documented here.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
|
and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
Breaking changes within the 0.x line are called out explicitly.
|
||||||
|
|
||||||
|
## [0.2.5] — 2026-05-11
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Grounded Sentiment Analyst.** The renamed `sentiment_analyst` now reads
|
||||||
|
real Yahoo News, StockTwits, and Reddit data before generating its report,
|
||||||
|
replacing the prior flow that could fabricate social posts under prompt
|
||||||
|
pressure. (#557, #607)
|
||||||
|
- **MiniMax provider** with the full M2.x catalog (M2.7 / M2.5 / M2.1 / M2
|
||||||
|
plus highspeed variants, 204K context). Dual-region: Global
|
||||||
|
(`MINIMAX_API_KEY`) and China (`MINIMAX_CN_API_KEY`).
|
||||||
|
- **Dual-region Qwen and GLM** with separate keys per region — international
|
||||||
|
(`DASHSCOPE_API_KEY`, `ZHIPU_API_KEY`) and China (`DASHSCOPE_CN_API_KEY`,
|
||||||
|
`ZHIPU_CN_API_KEY`), selectable via a secondary region prompt. (#758)
|
||||||
|
- **`TRADINGAGENTS_*` env-var configurability for `DEFAULT_CONFIG`.** Override
|
||||||
|
`llm_provider`, deep/quick model IDs, `backend_url`, `output_language`,
|
||||||
|
debate-round counts, checkpoint flag, and benchmark ticker via `.env` with
|
||||||
|
type-aware coercion (string / int / bool). (#602)
|
||||||
|
- **Interactive API-key detection in the CLI.** When the selected provider's
|
||||||
|
key is missing, the CLI prompts for it and persists the value to `.env`
|
||||||
|
so the analysis run continues without restart.
|
||||||
|
- **Remote Ollama support.** `OLLAMA_BASE_URL` points the CLI and the
|
||||||
|
programmatic client at a remote `ollama-serve`. The CLI surfaces the
|
||||||
|
resolved endpoint and warns on common malformed inputs. Adds a
|
||||||
|
`"Custom model ID"` option for models pulled via `ollama pull`. (#648, #768)
|
||||||
|
- **Configurable news-fetch parameters** in `DEFAULT_CONFIG` — per-ticker
|
||||||
|
article limit, macro headline limit, lookback window, and macro search
|
||||||
|
queries. (#606, #683)
|
||||||
|
- **Configurable alpha benchmark** for non-US tickers. Replaces hardcoded
|
||||||
|
SPY with regional indices for `.NS` (^NSEI), `.T` (^N225), `.HK` (^HSI),
|
||||||
|
`.L` (^FTSE), `.TO` (^GSPTSE), `.AX` (^AXJO), `.BO` (^BSESN); explicit
|
||||||
|
`benchmark_ticker` override available. Eliminates FX drift dominating
|
||||||
|
alpha for non-USD listings. (#628, #684)
|
||||||
|
- **Multi-language output covers every user-facing agent** — researchers,
|
||||||
|
risk debators, research manager, and trader, ending the previous
|
||||||
|
partial-localization reports. (#575)
|
||||||
|
- **Model catalog refresh.** OpenAI GPT-5.5 frontier, Anthropic Claude Opus
|
||||||
|
4.7, Gemini 3.1 Flash-Lite GA, xAI Grok 4.20, Qwen 3.6 line. Versioned IDs
|
||||||
|
only; auto-shifting aliases moved to the `"Custom model ID"` option.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **Sentiment Analyst** is now consistently named across the CLI dropdown,
|
||||||
|
status panel, and final reports (previously the backend was renamed but
|
||||||
|
the CLI still said "Social Analyst"). The `AnalystType.SOCIAL = "social"`
|
||||||
|
wire value is kept for saved-config back-compat.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- **Structured output works on DeepSeek V4 / reasoner and MiniMax M2.x.**
|
||||||
|
Those providers reject `tool_choice` per their tool-calling docs; the
|
||||||
|
binding flow now skips it automatically via a capability table.
|
||||||
|
- **`pip install .` installations pick up the project `.env`** when running
|
||||||
|
the CLI as a console script. (#747)
|
||||||
|
- **Reports save end-to-end** — streamed chunks were previously dropped from
|
||||||
|
`complete_report.md`. (#719, #736)
|
||||||
|
- **Ticker prompt preserves exchange suffixes** (`.SH`, `.SZ`, `.SS`, `.HK`,
|
||||||
|
`.T`, etc.) for A-share, HK, Tokyo, and other non-US flows. (#770)
|
||||||
|
- **Docker permission errors** no longer block first-run write to
|
||||||
|
`~/.tradingagents/`. (#519, #627, #672, #771)
|
||||||
|
- **Config state no longer leaks between runs** when sub-dicts are mutated;
|
||||||
|
`set_config` partial updates preserve sibling defaults. (#788)
|
||||||
|
- **`max_recur_limit` config actually applies** — previously read but not
|
||||||
|
forwarded to the propagator. (#764)
|
||||||
|
- **Missing-API-key error** names the exact env var to set. (#680)
|
||||||
|
- **Quieter startup** — suppressed the noisy upstream
|
||||||
|
`LangChainPendingDeprecationWarning` from langgraph-checkpoint; will be
|
||||||
|
removed once that package ships its fix.
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- **Ticker path-traversal validation** at every filesystem-path site (cache,
|
||||||
|
checkpoint database, results) so a malicious ticker cannot escape its
|
||||||
|
intended directory. (#618)
|
||||||
|
|
||||||
|
## [0.2.4] — 2026-04-25
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Structured-output decision agents.** Research Manager, Trader, and Portfolio
|
||||||
|
Manager now use `llm.with_structured_output(Schema)` on their primary call
|
||||||
|
and return typed Pydantic instances. Each provider's native structured-output
|
||||||
|
mode is used (`json_schema` for OpenAI / xAI, `response_schema` for Gemini,
|
||||||
|
tool-use for Anthropic, function-calling for OpenAI-compatible providers).
|
||||||
|
Render helpers preserve the existing markdown shape so memory log, CLI
|
||||||
|
display, and saved reports keep working unchanged. (#434)
|
||||||
|
- **LangGraph checkpoint resume** — opt-in via `--checkpoint`. State is saved
|
||||||
|
after each node so crashed or interrupted runs resume from the last
|
||||||
|
successful step. Per-ticker SQLite databases under
|
||||||
|
`~/.tradingagents/cache/checkpoints/`. `--clear-checkpoints` resets them. (#594)
|
||||||
|
- **Persistent decision log** replacing the per-agent BM25 memory. Decisions
|
||||||
|
are stored automatically at the end of `propagate()`; the next same-ticker
|
||||||
|
run resolves prior pending entries with realised return, alpha vs SPY, and
|
||||||
|
a one-paragraph reflection. Override path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
|
||||||
|
Optional `memory_log_max_entries` config caps resolved entries; pending
|
||||||
|
entries are never pruned. (#578, #563, #564, #579)
|
||||||
|
- **DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), and Azure OpenAI**
|
||||||
|
providers, plus dynamic OpenRouter model selection.
|
||||||
|
- **Docker support** — multi-stage build with separate dev and runtime images.
|
||||||
|
- **`scripts/smoke_structured_output.py`** — diagnostic that exercises the
|
||||||
|
three structured-output agents against any provider so contributors can
|
||||||
|
verify their setup with one command.
|
||||||
|
- **5-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell) used
|
||||||
|
consistently by Research Manager, Portfolio Manager, signal processor, and
|
||||||
|
the memory log; Trader keeps 3-tier (Buy / Hold / Sell) since transaction
|
||||||
|
direction is naturally ternary.
|
||||||
|
- **Pytest fixtures** — lazy LLM client imports plus placeholder API keys so
|
||||||
|
the test suite runs cleanly without credentials. (#588)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **`backend_url` default is now `None`** rather than the OpenAI URL. Each
|
||||||
|
provider client falls back to its native default. The previous default
|
||||||
|
leaked the OpenAI URL into non-OpenAI clients (e.g. Gemini), producing
|
||||||
|
malformed request URLs for Python users who switched providers without
|
||||||
|
overriding `backend_url`. The CLI flow is unaffected.
|
||||||
|
- All file I/O passes explicit `encoding="utf-8"` so Windows users no longer
|
||||||
|
hit `UnicodeEncodeError` with the cp1252 default. (#543, #550, #576)
|
||||||
|
- Cache and log directories moved to `~/.tradingagents/` to resolve Docker
|
||||||
|
permission issues. (#519)
|
||||||
|
- `SignalProcessor` reads the rating from the Portfolio Manager's rendered
|
||||||
|
markdown via a deterministic heuristic — no extra LLM call.
|
||||||
|
- OpenAI structured-output calls default to `method="function_calling"` to
|
||||||
|
avoid noisy `PydanticSerializationUnexpectedValue` warnings emitted by
|
||||||
|
langchain-openai's Responses-API parse path. Same typed result, no warnings.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Empty memory no longer triggers fabricated past-lessons in agent prompts;
|
||||||
|
the memory-log redesign makes this structurally impossible since only the
|
||||||
|
Portfolio Manager consults memory and only when entries exist. (#572)
|
||||||
|
- Tool-call logging processes every chunk message, not just the last one, and
|
||||||
|
memory score normalization handles empty score arrays. (#534, #531)
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `FinancialSituationMemory` (the per-agent BM25 system) and the dead
|
||||||
|
`reflect_and_remember()` plumbing; subsumed by the persistent decision log.
|
||||||
|
- Hardcoded Google endpoint that caused 404 when `langchain-google-genai`
|
||||||
|
changed its API path. (#493, #496)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
Thanks to everyone who shaped this release through code, design, and reports:
|
||||||
|
|
||||||
|
- [@claytonbrown](https://github.com/claytonbrown) — checkpoint resume (#594), test fixtures (#588), design feedback on cost tracking (#582) and structured validation (#583)
|
||||||
|
- [@Bcardo](https://github.com/Bcardo) — memory-log redesign (#579), empty-memory hallucination report (#572), encoding fix proposal (#570)
|
||||||
|
- [@voidborne-d](https://github.com/voidborne-d) — memory persistence design (#564), portfolio manager state fix (#503)
|
||||||
|
- [@mannubaveja007](https://github.com/mannubaveja007) — structured-output feature request (#434)
|
||||||
|
- [@kelder66](https://github.com/kelder66) — RAM-only memory issue (#563)
|
||||||
|
- [@Gujiassh](https://github.com/Gujiassh) — tool-call logging fix (#534), test stub PR (#533)
|
||||||
|
- [@iuyup](https://github.com/iuyup) — memory score normalization fix (#531)
|
||||||
|
- [@kaihg](https://github.com/kaihg) — Google base_url fix (#496)
|
||||||
|
- [@32ryh98yfe](https://github.com/32ryh98yfe) — Gemini 404 report (#493)
|
||||||
|
- [@uppb](https://github.com/uppb) — OpenRouter dynamic model selection (#482)
|
||||||
|
- [@guoz14](https://github.com/guoz14) — OpenRouter limited-model report (#337)
|
||||||
|
- [@samchenku](https://github.com/samchenku) — indicator name normalization (#490)
|
||||||
|
- [@JasonOA888](https://github.com/JasonOA888) — y_finance pandas import fix (#488)
|
||||||
|
- [@tiffanychum](https://github.com/tiffanychum) — stale import cleanup (#499)
|
||||||
|
- [@zaizou](https://github.com/zaizou) — Docker permission issue (#519)
|
||||||
|
- [@Stosman123](https://github.com/Stosman123), [@mauropuga](https://github.com/mauropuga), [@hotwind2015](https://github.com/hotwind2015) — Windows encoding bug reports (#543, #550, #576)
|
||||||
|
- [@nnishad](https://github.com/nnishad), [@atharvajoshi01](https://github.com/atharvajoshi01) — encoding fix proposals (#568, #549)
|
||||||
|
|
||||||
|
## [0.2.3] — 2026-03-29
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Multi-language output** for analyst reports and final decisions, with a
|
||||||
|
CLI selector. Internal agent debate stays in English for reasoning quality. (#472)
|
||||||
|
- **GPT-5.4 family models** in the default catalog, with deep/quick model split.
|
||||||
|
- **Unified model catalog** as a single source of truth for CLI options and
|
||||||
|
provider validation.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- `base_url` is forwarded to Google and Anthropic clients so corporate proxies
|
||||||
|
work consistently across providers. (#427)
|
||||||
|
- Standardised the Google `api_key` parameter to the unified `api_key` form.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Backtesting fetchers no longer leak look-ahead data when `curr_date` is in
|
||||||
|
the middle of a fetched window. (#475)
|
||||||
|
- Invalid indicator names from the LLM are caught at the tool boundary instead
|
||||||
|
of crashing the run. (#429)
|
||||||
|
- yfinance news fetchers respect the same exponential-backoff retry as price
|
||||||
|
fetchers. (#445)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@ahmedk20](https://github.com/ahmedk20) — multi-language output (#472)
|
||||||
|
- [@CadeYu](https://github.com/CadeYu) — model catalog typing (#464)
|
||||||
|
- [@javierdejesusda](https://github.com/javierdejesusda) — unified Google API key parameter (#453)
|
||||||
|
- [@voidborne-d](https://github.com/voidborne-d) — yfinance news retry (#445)
|
||||||
|
- [@kostakost2](https://github.com/kostakost2) — look-ahead bias report (#475)
|
||||||
|
- [@lu-zhengda](https://github.com/lu-zhengda) — proxy/base_url support request (#427)
|
||||||
|
- [@VamsiKrishna2021](https://github.com/VamsiKrishna2021) — invalid indicator crash report (#429)
|
||||||
|
|
||||||
|
## [0.2.2] — 2026-03-22
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Five-tier rating scale** (Buy / Overweight / Hold / Underweight / Sell)
|
||||||
|
introduced for the Portfolio Manager.
|
||||||
|
- **Anthropic effort level** support for Claude models.
|
||||||
|
- **OpenAI Responses API** path for native OpenAI models.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- `risk_manager` renamed to `portfolio_manager` to match the role description
|
||||||
|
shown in the CLI display.
|
||||||
|
- Exchange-qualified tickers (e.g. `7203.T`, `BRK.B`) preserved across all
|
||||||
|
agent prompts and tool calls.
|
||||||
|
- Process-level UTF-8 default attempted for cross-platform consistency
|
||||||
|
(note: this approach did not actually take effect; replaced in v0.2.4 with
|
||||||
|
explicit per-call `encoding="utf-8"` arguments).
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- yfinance rate-limit errors are retried with exponential backoff. (#426)
|
||||||
|
- HTTP client SSL customisation is supported for environments that need
|
||||||
|
custom certificate bundles. (#379)
|
||||||
|
- Report-section writes handle list-of-string content gracefully.
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@CadeYu](https://github.com/CadeYu) — exchange-qualified ticker preservation (#413)
|
||||||
|
- [@yang1002378395-cmyk](https://github.com/yang1002378395-cmyk) — HTTP client SSL customisation (#379)
|
||||||
|
|
||||||
|
## [0.2.1] — 2026-03-15
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Patched `langchain-core` vulnerability (LangGrinch). (#335)
|
||||||
|
- Removed `chainlit` dependency affected by CVE-2026-22218.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- `pyproject.toml` build-system configuration; the project now installs via
|
||||||
|
modern packaging tooling.
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `setup.py` — dependencies consolidated to `pyproject.toml`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Risk manager reads the correct fundamental report source. (#341)
|
||||||
|
- All `open()` calls receive an explicit UTF-8 encoding (initial pass).
|
||||||
|
- `get_indicators` tool handles comma-separated indicator names from the LLM. (#368)
|
||||||
|
- `Propagation` initialises every debate-state field so risk debaters never
|
||||||
|
see missing keys.
|
||||||
|
- Stock data parsing tolerates malformed CSVs and NaN values.
|
||||||
|
- Conditional debate logic respects the configured round count. (#361)
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
- [@RinZ27](https://github.com/RinZ27) — `langchain-core` security patch (#335)
|
||||||
|
- [@Ljx-007](https://github.com/Ljx-007) — risk manager fundamental-report fix (#341)
|
||||||
|
- [@makk9](https://github.com/makk9) — debate-rounds config issue (#361)
|
||||||
|
|
||||||
|
## [0.2.0] — 2026-02-04
|
||||||
|
|
||||||
|
This is the largest release since the initial public version. The framework
|
||||||
|
moved from single-provider to a multi-provider architecture and grew several
|
||||||
|
production-ready surfaces.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Multi-provider LLM support** (OpenAI, Google, Anthropic, xAI, OpenRouter,
|
||||||
|
Ollama) via a factory pattern, with provider-specific thinking configurations.
|
||||||
|
- **Alpha Vantage** integration as a configurable primary data provider, with
|
||||||
|
yfinance as a community-stability fallback.
|
||||||
|
- **Footer statistics** in the CLI: real-time tracking of LLM calls, tool
|
||||||
|
calls, and token usage via LangChain callbacks.
|
||||||
|
- **Post-analysis report saving** — the framework writes per-section markdown
|
||||||
|
files (analyst reports, debate transcripts, final decision) when a run
|
||||||
|
completes.
|
||||||
|
- **Announcements panel** — fetches updates from `api.tauric.ai/v1/announcements`
|
||||||
|
for the CLI welcome screen.
|
||||||
|
- **Tool fallbacks** so a single vendor outage does not stop the pipeline.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Risky / Safe risk debaters renamed to **Aggressive / Conservative** for
|
||||||
|
consistency with the displayed agent labels.
|
||||||
|
- Default data vendor switched to balance reliability and quota across
|
||||||
|
community deployments.
|
||||||
|
- Ollama and OpenRouter model lists updated; default endpoints clarified.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Analyst status tracking and message deduplication in the live display.
|
||||||
|
- Infinite-loop guard in the agent loop; reflection and logging hardened.
|
||||||
|
- Various data-vendor implementation bugs and tool-signature mismatches.
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
|
||||||
|
This release is the first with substantial outside contributions; many community
|
||||||
|
PRs from late 2025 also landed here.
|
||||||
|
|
||||||
|
- [@luohy15](https://github.com/luohy15) — Alpha Vantage data-vendor integration (#235)
|
||||||
|
- [@EdwardoSunny](https://github.com/EdwardoSunny) — yfinance fetching optimisations (#245)
|
||||||
|
- [@Mirza-Samad-Ahmed-Baig](https://github.com/Mirza-Samad-Ahmed-Baig) — infinite-loop guard, reflection, and logging fixes (#89)
|
||||||
|
- [@ZeroAct](https://github.com/ZeroAct) — saved results path support (#29)
|
||||||
|
- [@Zhongyi-Lu](https://github.com/Zhongyi-Lu) — `.env` gitignore (#49)
|
||||||
|
- [@csoboy](https://github.com/csoboy) — local Ollama setup (#53)
|
||||||
|
- [@chauhang](https://github.com/chauhang) — initial Docker support attempt (#47, later reverted; the merged Docker support shipped in v0.2.4)
|
||||||
|
|
||||||
|
## [0.1.1] — 2025-06-07
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- Static site assets that had been bundled with v0.1.0; the public site now
|
||||||
|
lives separately.
|
||||||
|
|
||||||
|
## [0.1.0] — 2025-06-05
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Initial public release** of the TradingAgents multi-agent trading
|
||||||
|
framework: market / sentiment / news / fundamentals analysts; bull and bear
|
||||||
|
researchers; trader; aggressive, conservative, and neutral risk debaters;
|
||||||
|
portfolio manager. LangGraph orchestration, yfinance data, per-agent
|
||||||
|
BM25 memory, single-provider OpenAI integration, interactive CLI.
|
||||||
|
|
||||||
|
[0.2.4]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.3...v0.2.4
|
||||||
|
[0.2.3]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.2...v0.2.3
|
||||||
|
[0.2.2]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.1...v0.2.2
|
||||||
|
[0.2.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.2.0...v0.2.1
|
||||||
|
[0.2.0]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.1...v0.2.0
|
||||||
|
[0.1.1]: https://github.com/TauricResearch/TradingAgents/compare/v0.1.0...v0.1.1
|
||||||
|
[0.1.0]: https://github.com/TauricResearch/TradingAgents/releases/tag/v0.1.0
|
||||||
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||||
|
|
||||||
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
COPY . .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
COPY --from=builder /opt/venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
RUN useradd --create-home appuser \
|
||||||
|
&& install -d -m 0755 -o appuser -g appuser /home/appuser/.tradingagents
|
||||||
|
USER appuser
|
||||||
|
WORKDIR /home/appuser/app
|
||||||
|
|
||||||
|
COPY --from=builder --chown=appuser:appuser /build .
|
||||||
|
|
||||||
|
ENTRYPOINT ["tradingagents"]
|
||||||
138
README.md
138
README.md
@@ -11,10 +11,40 @@
|
|||||||
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
<a href="https://github.com/TauricResearch/" target="_blank"><img alt="Community" src="https://img.shields.io/badge/Join_GitHub_Community-TauricResearch-14C290?logo=discourse"/></a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=de">Deutsch</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=es">Español</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=fr">français</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ja">日本語</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ko">한국어</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=pt">Português</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=ru">Русский</a> |
|
||||||
|
<a href="https://www.readme-i18n.com/TauricResearch/TradingAgents?lang=zh">中文</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||||
|
|
||||||
|
## News
|
||||||
|
- [2026-05] **TradingAgents v0.2.5** released with the grounded Sentiment Analyst, GPT-5.5 etc. model coverage, Qwen/GLM/MiniMax dual-region support, `TRADINGAGENTS_*` env-var configurability with API-key auto-detection, remote Ollama support, non-US alpha benchmarks, and ticker path-traversal hardening. See [CHANGELOG.md](CHANGELOG.md) for the full list.
|
||||||
|
- [2026-04] **TradingAgents v0.2.4** released with structured-output agents (Research Manager, Trader, Portfolio Manager), LangGraph checkpoint resume, persistent decision log, DeepSeek/Qwen/GLM/Azure provider support, Docker, and a Windows UTF-8 encoding fix.
|
||||||
|
- [2026-03] **TradingAgents v0.2.3** released with multi-language support, GPT-5.4 family models, unified model catalog, backtesting date fidelity, and proxy support.
|
||||||
|
- [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability.
|
||||||
|
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
|
||||||
|
- [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon.
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<a href="https://www.star-history.com/#TauricResearch/TradingAgents&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" />
|
||||||
|
<img alt="TradingAgents Star History" src="https://api.star-history.com/svg?repos=TauricResearch/TradingAgents&type=Date" style="width: 80%; height: auto;" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||||
>
|
>
|
||||||
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
||||||
@@ -39,7 +69,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu
|
|||||||
|
|
||||||
### Analyst Team
|
### Analyst Team
|
||||||
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
||||||
- Sentiment Analyst: Analyzes social media and public sentiment using sentiment scoring algorithms to gauge short-term market mood.
|
- Sentiment Analyst: Aggregates news headlines, StockTwits, and Reddit chatter into a single sentiment read to gauge short-term market mood.
|
||||||
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
- News Analyst: Monitors global news and macroeconomic indicators, interpreting the impact of events on market conditions.
|
||||||
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
- Technical Analyst: Utilizes technical indicators (like MACD and RSI) to detect trading patterns and forecast price movements.
|
||||||
|
|
||||||
@@ -58,7 +88,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu
|
|||||||
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Risk Management and Portfolio Manager
|
### Risk Management and Portfolio Manager
|
||||||
@@ -66,7 +96,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu
|
|||||||
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
- The Portfolio Manager approves/rejects the transaction proposal. If approved, the order will be sent to the simulated exchange and executed.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/risk.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Installation and CLI
|
## Installation and CLI
|
||||||
@@ -85,30 +115,61 @@ conda create -n tradingagents python=3.13
|
|||||||
conda activate tradingagents
|
conda activate tradingagents
|
||||||
```
|
```
|
||||||
|
|
||||||
Install dependencies:
|
Install the package and its dependencies:
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
Alternatively, run with Docker:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env # add your API keys
|
||||||
|
docker compose run --rm tradingagents
|
||||||
|
```
|
||||||
|
|
||||||
|
For local models with Ollama:
|
||||||
|
```bash
|
||||||
|
docker compose --profile ollama run --rm tradingagents-ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
### Required APIs
|
### Required APIs
|
||||||
|
|
||||||
You will also need the FinnHub API and EODHD API for financial data. All of our code is implemented with the free tier.
|
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export FINNHUB_API_KEY=$YOUR_FINNHUB_API_KEY
|
export OPENAI_API_KEY=... # OpenAI (GPT)
|
||||||
|
export GOOGLE_API_KEY=... # Google (Gemini)
|
||||||
|
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||||
|
export XAI_API_KEY=... # xAI (Grok)
|
||||||
|
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||||
|
export DASHSCOPE_API_KEY=... # Qwen — International (dashscope-intl.aliyuncs.com)
|
||||||
|
export DASHSCOPE_CN_API_KEY=... # Qwen — China (dashscope.aliyuncs.com)
|
||||||
|
export ZHIPU_API_KEY=... # GLM via Z.AI (international)
|
||||||
|
export ZHIPU_CN_API_KEY=... # GLM via BigModel (China, open.bigmodel.cn)
|
||||||
|
export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io, M2.x, 204K ctx)
|
||||||
|
export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com, M2.x, 204K ctx)
|
||||||
|
export OPENROUTER_API_KEY=... # OpenRouter
|
||||||
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
You will need the OpenAI API for all the agents.
|
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||||
|
|
||||||
|
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
||||||
|
|
||||||
|
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||||
```bash
|
```bash
|
||||||
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
cp .env.example .env
|
||||||
```
|
```
|
||||||
|
|
||||||
### CLI Usage
|
### CLI Usage
|
||||||
|
|
||||||
You can also try out the CLI directly by running:
|
Launch the interactive CLI:
|
||||||
```bash
|
```bash
|
||||||
python -m cli.main
|
tradingagents # installed command
|
||||||
|
python -m cli.main # alternative: run directly from source
|
||||||
```
|
```
|
||||||
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
You will see a screen where you can select your desired tickers, analysis date, LLM provider, research depth, and more.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
@@ -128,7 +189,7 @@ An interface will appear showing results as they load, letting you track the age
|
|||||||
|
|
||||||
### Implementation Details
|
### Implementation Details
|
||||||
|
|
||||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope, international and China endpoints), GLM (Zhipu), MiniMax (global + China), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
|
||||||
|
|
||||||
### Python Usage
|
### Python Usage
|
||||||
|
|
||||||
@@ -136,11 +197,12 @@ To use TradingAgents inside your code, you can import the `tradingagents` module
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
||||||
|
|
||||||
# forward propagate
|
# forward propagate
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
print(decision)
|
print(decision)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -150,29 +212,53 @@ You can also adjust the default configuration to set your own choice of LLMs, de
|
|||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
# Create a custom config
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, qwen-cn, glm, glm-cn, minimax, minimax-cn, openrouter, ollama, azure
|
||||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
||||||
config["online_tools"] = True # Use online tools or cached data
|
config["max_debate_rounds"] = 2
|
||||||
|
|
||||||
# Initialize with custom config
|
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
# forward propagate
|
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
|
||||||
print(decision)
|
print(decision)
|
||||||
```
|
```
|
||||||
|
|
||||||
> For `online_tools`, we recommend enabling them for experimentation, as they provide access to real-time data. The agents' offline tools rely on cached data from our **Tauric TradingDB**, a curated dataset we use for backtesting. We're currently in the process of refining this dataset, and we plan to release it soon alongside our upcoming projects. Stay tuned!
|
See `tradingagents/default_config.py` for all configuration options.
|
||||||
|
|
||||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
## Persistence and Recovery
|
||||||
|
|
||||||
|
TradingAgents persists two kinds of state across runs.
|
||||||
|
|
||||||
|
### Decision log
|
||||||
|
|
||||||
|
The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't.
|
||||||
|
|
||||||
|
Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
|
||||||
|
|
||||||
|
### Checkpoint resume
|
||||||
|
|
||||||
|
Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion.
|
||||||
|
|
||||||
|
Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tradingagents analyze --checkpoint # enable for this run
|
||||||
|
tradingagents analyze --clear-checkpoints # reset before running
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = DEFAULT_CONFIG.copy()
|
||||||
|
config["checkpoint_enabled"] = True
|
||||||
|
ta = TradingAgentsGraph(config=config)
|
||||||
|
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
|
```
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||||
|
|
||||||
|
Past contributions, including code, design feedback, and bug reports, are credited per release in [`CHANGELOG.md`](CHANGELOG.md).
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||||
|
|||||||
51
cli/announcements.py
Normal file
51
cli/announcements.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import getpass
|
||||||
|
import requests
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from cli.config import CLI_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_announcements(url: str = None, timeout: float = None) -> dict:
|
||||||
|
"""Fetch announcements from endpoint. Returns dict with announcements and settings."""
|
||||||
|
endpoint = url or CLI_CONFIG["announcements_url"]
|
||||||
|
timeout = timeout or CLI_CONFIG["announcements_timeout"]
|
||||||
|
fallback = CLI_CONFIG["announcements_fallback"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(endpoint, timeout=timeout)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return {
|
||||||
|
"announcements": data.get("announcements", [fallback]),
|
||||||
|
"require_attention": data.get("require_attention", False),
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return {
|
||||||
|
"announcements": [fallback],
|
||||||
|
"require_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def display_announcements(console: Console, data: dict) -> None:
|
||||||
|
"""Display announcements panel. Prompts for Enter if require_attention is True."""
|
||||||
|
announcements = data.get("announcements", [])
|
||||||
|
require_attention = data.get("require_attention", False)
|
||||||
|
|
||||||
|
if not announcements:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = "\n".join(announcements)
|
||||||
|
|
||||||
|
panel = Panel(
|
||||||
|
content,
|
||||||
|
border_style="cyan",
|
||||||
|
padding=(1, 2),
|
||||||
|
title="Announcements",
|
||||||
|
)
|
||||||
|
console.print(panel)
|
||||||
|
|
||||||
|
if require_attention:
|
||||||
|
getpass.getpass("Press Enter to continue...")
|
||||||
|
else:
|
||||||
|
console.print()
|
||||||
6
cli/config.py
Normal file
6
cli/config.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
CLI_CONFIG = {
|
||||||
|
# Announcements
|
||||||
|
"announcements_url": "https://api.tauric.ai/v1/announcements",
|
||||||
|
"announcements_timeout": 1.0,
|
||||||
|
"announcements_fallback": "[cyan]For more information, please visit[/cyan] [link=https://github.com/TauricResearch]https://github.com/TauricResearch[/link]",
|
||||||
|
}
|
||||||
1265
cli/main.py
1265
cli/main.py
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,8 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class AnalystType(str, Enum):
|
class AnalystType(str, Enum):
|
||||||
MARKET = "market"
|
MARKET = "market"
|
||||||
|
# Wire value stays "social" for saved-config and string-keyed-caller
|
||||||
|
# back-compat; the user-facing label is "Sentiment Analyst".
|
||||||
SOCIAL = "social"
|
SOCIAL = "social"
|
||||||
NEWS = "news"
|
NEWS = "news"
|
||||||
FUNDAMENTALS = "fundamentals"
|
FUNDAMENTALS = "fundamentals"
|
||||||
|
|||||||
76
cli/stats_handler.py
Normal file
76
cli/stats_handler.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
class StatsCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self.llm_calls = 0
|
||||||
|
self.tool_calls = 0
|
||||||
|
self.tokens_in = 0
|
||||||
|
self.tokens_out = 0
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
prompts: List[str],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment LLM call counter when an LLM starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.llm_calls += 1
|
||||||
|
|
||||||
|
def on_chat_model_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
messages: List[List[Any]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment LLM call counter when a chat model starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.llm_calls += 1
|
||||||
|
|
||||||
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
"""Extract token usage from LLM response."""
|
||||||
|
try:
|
||||||
|
generation = response.generations[0][0]
|
||||||
|
except (IndexError, TypeError):
|
||||||
|
return
|
||||||
|
|
||||||
|
usage_metadata = None
|
||||||
|
if hasattr(generation, "message"):
|
||||||
|
message = generation.message
|
||||||
|
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
|
||||||
|
usage_metadata = message.usage_metadata
|
||||||
|
|
||||||
|
if usage_metadata:
|
||||||
|
with self._lock:
|
||||||
|
self.tokens_in += usage_metadata.get("input_tokens", 0)
|
||||||
|
self.tokens_out += usage_metadata.get("output_tokens", 0)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
|
input_str: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Increment tool call counter when a tool starts."""
|
||||||
|
with self._lock:
|
||||||
|
self.tool_calls += 1
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Return current statistics."""
|
||||||
|
with self._lock:
|
||||||
|
return {
|
||||||
|
"llm_calls": self.llm_calls,
|
||||||
|
"tool_calls": self.tool_calls,
|
||||||
|
"tokens_in": self.tokens_in,
|
||||||
|
"tokens_out": self.tokens_out,
|
||||||
|
}
|
||||||
429
cli/utils.py
429
cli/utils.py
@@ -1,11 +1,22 @@
|
|||||||
import questionary
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict
|
from typing import List, Optional, Tuple, Dict
|
||||||
|
|
||||||
|
import questionary
|
||||||
|
from dotenv import find_dotenv, set_key
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
||||||
|
|
||||||
ANALYST_ORDER = [
|
ANALYST_ORDER = [
|
||||||
("Market Analyst", AnalystType.MARKET),
|
("Market Analyst", AnalystType.MARKET),
|
||||||
("Social Media Analyst", AnalystType.SOCIAL),
|
("Sentiment Analyst", AnalystType.SOCIAL),
|
||||||
("News Analyst", AnalystType.NEWS),
|
("News Analyst", AnalystType.NEWS),
|
||||||
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
||||||
]
|
]
|
||||||
@@ -14,7 +25,7 @@ ANALYST_ORDER = [
|
|||||||
def get_ticker() -> str:
|
def get_ticker() -> str:
|
||||||
"""Prompt the user to enter a ticker symbol."""
|
"""Prompt the user to enter a ticker symbol."""
|
||||||
ticker = questionary.text(
|
ticker = questionary.text(
|
||||||
"Enter the ticker symbol to analyze:",
|
f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):",
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.",
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
[
|
[
|
||||||
@@ -28,6 +39,11 @@ def get_ticker() -> str:
|
|||||||
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
return normalize_ticker_symbol(ticker)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_ticker_symbol(ticker: str) -> str:
|
||||||
|
"""Normalize ticker input while preserving exchange suffixes."""
|
||||||
return ticker.strip().upper()
|
return ticker.strip().upper()
|
||||||
|
|
||||||
|
|
||||||
@@ -122,61 +138,126 @@ def select_research_depth() -> int:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def select_shallow_thinking_agent() -> str:
|
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
||||||
|
"""Fetch available models from the OpenRouter API."""
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
|
||||||
|
resp.raise_for_status()
|
||||||
|
models = resp.json().get("data", [])
|
||||||
|
return [(m.get("name") or m["id"], m["id"]) for m in models]
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def select_openrouter_model() -> str:
|
||||||
|
"""Select an OpenRouter model from the newest available, or enter a custom ID."""
|
||||||
|
models = _fetch_openrouter_models()
|
||||||
|
|
||||||
|
choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
|
||||||
|
choices.append(questionary.Choice("Custom model ID", value="custom"))
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
"Select OpenRouter Model (latest available):",
|
||||||
|
choices=choices,
|
||||||
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:magenta noinherit"),
|
||||||
|
("highlighted", "fg:magenta noinherit"),
|
||||||
|
("pointer", "fg:magenta noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice is None or choice == "custom":
|
||||||
|
return questionary.text(
|
||||||
|
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
return choice
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_custom_model_id() -> str:
|
||||||
|
"""Prompt user to type a custom model ID."""
|
||||||
|
return questionary.text(
|
||||||
|
"Enter model ID:",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _select_model(provider: str, mode: str) -> str:
|
||||||
|
"""Select a model for the given provider and mode (quick/deep)."""
|
||||||
|
if provider.lower() == "openrouter":
|
||||||
|
return select_openrouter_model()
|
||||||
|
|
||||||
|
if provider.lower() == "azure":
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter Azure deployment name ({mode}-thinking):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(display, value=value)
|
||||||
|
for display, value in get_model_options(provider, mode)
|
||||||
|
],
|
||||||
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("selected", "fg:magenta noinherit"),
|
||||||
|
("highlighted", "fg:magenta noinherit"),
|
||||||
|
("pointer", "fg:magenta noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice is None:
|
||||||
|
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if choice == "custom":
|
||||||
|
return _prompt_custom_model_id()
|
||||||
|
|
||||||
|
return choice
|
||||||
|
|
||||||
|
|
||||||
|
def select_shallow_thinking_agent(provider) -> str:
|
||||||
"""Select shallow thinking llm engine using an interactive selection."""
|
"""Select shallow thinking llm engine using an interactive selection."""
|
||||||
|
return _select_model(provider, "quick")
|
||||||
# Define shallow thinking llm engine options with their corresponding model names
|
|
||||||
SHALLOW_AGENT_OPTIONS = [
|
|
||||||
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
|
|
||||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
|
||||||
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
|
||||||
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
|
||||||
]
|
|
||||||
|
|
||||||
choice = questionary.select(
|
|
||||||
"Select Your [Quick-Thinking LLM Engine]:",
|
|
||||||
choices=[
|
|
||||||
questionary.Choice(display, value=value)
|
|
||||||
for display, value in SHALLOW_AGENT_OPTIONS
|
|
||||||
],
|
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
|
||||||
style=questionary.Style(
|
|
||||||
[
|
|
||||||
("selected", "fg:magenta noinherit"),
|
|
||||||
("highlighted", "fg:magenta noinherit"),
|
|
||||||
("pointer", "fg:magenta noinherit"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
).ask()
|
|
||||||
|
|
||||||
if choice is None:
|
|
||||||
console.print(
|
|
||||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
|
||||||
)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return choice
|
|
||||||
|
|
||||||
|
|
||||||
def select_deep_thinking_agent() -> str:
|
def select_deep_thinking_agent(provider) -> str:
|
||||||
"""Select deep thinking llm engine using an interactive selection."""
|
"""Select deep thinking llm engine using an interactive selection."""
|
||||||
|
return _select_model(provider, "deep")
|
||||||
|
|
||||||
# Define deep thinking llm engine options with their corresponding model names
|
def select_llm_provider() -> tuple[str, str | None]:
|
||||||
DEEP_AGENT_OPTIONS = [
|
"""Select the LLM provider and its API endpoint."""
|
||||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
# Ollama users can point at a remote ollama-serve via OLLAMA_BASE_URL
|
||||||
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
# (convention from the broader Ollama ecosystem); falls back to the
|
||||||
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
# localhost default when unset.
|
||||||
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
|
ollama_url = os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
|
||||||
("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"),
|
# (display_name, provider_key, base_url)
|
||||||
("o3 - Full advanced reasoning model", "o3"),
|
PROVIDERS = [
|
||||||
("o1 - Premier reasoning and problem-solving model", "o1"),
|
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||||
|
("Google", "google", None),
|
||||||
|
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||||
|
("xAI", "xai", "https://api.x.ai/v1"),
|
||||||
|
("DeepSeek", "deepseek", "https://api.deepseek.com"),
|
||||||
|
("Qwen", "qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
|
||||||
|
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
|
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
||||||
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
("Azure OpenAI", "azure", None),
|
||||||
|
("Ollama", "ollama", ollama_url),
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select Your [Deep-Thinking LLM Engine]:",
|
"Select your LLM Provider:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=value)
|
questionary.Choice(display, value=(provider_key, url))
|
||||||
for display, value in DEEP_AGENT_OPTIONS
|
for display, provider_key, url in PROVIDERS
|
||||||
],
|
],
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
@@ -189,7 +270,255 @@ def select_deep_thinking_agent() -> str:
|
|||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
provider, url = choice
|
||||||
|
return provider, url
|
||||||
|
|
||||||
|
|
||||||
|
def ask_openai_reasoning_effort() -> str:
|
||||||
|
"""Ask for OpenAI reasoning effort level."""
|
||||||
|
choices = [
|
||||||
|
questionary.Choice("Medium (Default)", "medium"),
|
||||||
|
questionary.Choice("High (More thorough)", "high"),
|
||||||
|
questionary.Choice("Low (Faster)", "low"),
|
||||||
|
]
|
||||||
|
return questionary.select(
|
||||||
|
"Select Reasoning Effort:",
|
||||||
|
choices=choices,
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_anthropic_effort() -> str | None:
|
||||||
|
"""Ask for Anthropic effort level.
|
||||||
|
|
||||||
|
Controls token usage and response thoroughness on Claude 4.5 / 4.6 / 4.7
|
||||||
|
models. The API also accepts "max"; we expose low/medium/high as the
|
||||||
|
common selection range.
|
||||||
|
"""
|
||||||
|
return questionary.select(
|
||||||
|
"Select Effort Level:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("High (recommended)", "high"),
|
||||||
|
questionary.Choice("Medium (balanced)", "medium"),
|
||||||
|
questionary.Choice("Low (faster, cheaper)", "low"),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_gemini_thinking_config() -> str | None:
|
||||||
|
"""Ask for Gemini thinking configuration.
|
||||||
|
|
||||||
|
Returns thinking_level: "high" or "minimal".
|
||||||
|
Client maps to appropriate API param based on model series.
|
||||||
|
"""
|
||||||
|
return questionary.select(
|
||||||
|
"Select Thinking Mode:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("Enable Thinking (recommended)", "high"),
|
||||||
|
questionary.Choice("Minimal/Disable Thinking", "minimal"),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:green noinherit"),
|
||||||
|
("highlighted", "fg:green noinherit"),
|
||||||
|
("pointer", "fg:green noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_glm_region() -> tuple[str, str]:
|
||||||
|
"""Ask which GLM platform (Z.AI international vs BigModel China) to use.
|
||||||
|
|
||||||
|
Zhipu serves the same GLM models under two brands with separate
|
||||||
|
accounts; keys aren't interchangeable. Returns (provider_key, backend_url).
|
||||||
|
"""
|
||||||
|
return questionary.select(
|
||||||
|
"Select GLM platform:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(
|
||||||
|
"Z.AI — api.z.ai (international, uses ZHIPU_API_KEY)",
|
||||||
|
value=("glm", "https://api.z.ai/api/paas/v4/"),
|
||||||
|
),
|
||||||
|
questionary.Choice(
|
||||||
|
"BigModel — open.bigmodel.cn (China, uses ZHIPU_CN_API_KEY)",
|
||||||
|
value=("glm-cn", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_qwen_region() -> tuple[str, str]:
|
||||||
|
"""Ask which Qwen region (international vs China) to use.
|
||||||
|
|
||||||
|
Alibaba DashScope exposes two endpoints with separate accounts —
|
||||||
|
a key from one region does NOT authenticate against the other
|
||||||
|
(fixes #758). Returns (provider_key, backend_url).
|
||||||
|
"""
|
||||||
|
return questionary.select(
|
||||||
|
"Select Qwen region:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(
|
||||||
|
"International — dashscope-intl.aliyuncs.com (uses DASHSCOPE_API_KEY)",
|
||||||
|
value=("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
|
||||||
|
),
|
||||||
|
questionary.Choice(
|
||||||
|
"China — dashscope.aliyuncs.com (uses DASHSCOPE_CN_API_KEY)",
|
||||||
|
value=("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_minimax_region() -> tuple[str, str]:
|
||||||
|
"""Ask which MiniMax region (global vs China) to use.
|
||||||
|
|
||||||
|
MiniMax exposes two endpoints with separate accounts — a key from
|
||||||
|
one region does NOT authenticate against the other. Returns
|
||||||
|
(provider_key, backend_url).
|
||||||
|
"""
|
||||||
|
return questionary.select(
|
||||||
|
"Select MiniMax region:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(
|
||||||
|
"Global — api.minimax.io (uses MINIMAX_API_KEY)",
|
||||||
|
value=("minimax", "https://api.minimax.io/v1"),
|
||||||
|
),
|
||||||
|
questionary.Choice(
|
||||||
|
"China — api.minimaxi.com (uses MINIMAX_CN_API_KEY)",
|
||||||
|
value=("minimax-cn", "https://api.minimaxi.com/v1"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_ollama_endpoint(url: str) -> None:
|
||||||
|
"""Show the resolved Ollama endpoint after provider selection.
|
||||||
|
|
||||||
|
Surfaces three things the user benefits from seeing before model
|
||||||
|
selection: which URL we'll actually hit, where it came from
|
||||||
|
(\`OLLAMA_BASE_URL\` vs default), and a soft warning if the URL is
|
||||||
|
missing the scheme/port that ollama-serve expects. The warning is
|
||||||
|
advisory only — we don't reject malformed input, since the user may
|
||||||
|
be doing something deliberately unusual (e.g. a reverse-proxy path).
|
||||||
|
"""
|
||||||
|
from_env = os.environ.get("OLLAMA_BASE_URL")
|
||||||
|
origin = " (from OLLAMA_BASE_URL)" if from_env and from_env == url else ""
|
||||||
|
console.print(f"[green]✓ Using Ollama at {url}{origin}[/green]")
|
||||||
|
|
||||||
|
if not url.startswith(("http://", "https://")):
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Note: {url!r} is missing a scheme. "
|
||||||
|
f"Ollama-serve typically expects a URL like "
|
||||||
|
f"http://<host>:11434/v1.[/yellow]"
|
||||||
|
)
|
||||||
|
elif ":11434" not in url and "://localhost" not in url and "://127.0.0.1" not in url:
|
||||||
|
# Soft hint when the port differs from the ollama-serve default
|
||||||
|
# and the host isn't local (where users sometimes proxy on :80).
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Note: {url!r} doesn't include port 11434. "
|
||||||
|
f"Make sure your remote ollama-serve listens on the port "
|
||||||
|
f"shown above.[/yellow]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_api_key(provider: str) -> Optional[str]:
|
||||||
|
"""Make sure the API key for `provider` is available in the environment.
|
||||||
|
|
||||||
|
If the env var is already set, returns its value untouched. Otherwise
|
||||||
|
interactively prompts the user, persists the value to the project's
|
||||||
|
.env file via python-dotenv's set_key (creating .env if needed), and
|
||||||
|
exports it into os.environ so the current process picks it up.
|
||||||
|
|
||||||
|
Returns None for providers that do not require a key (e.g. ollama)
|
||||||
|
and for providers not found in the canonical mapping.
|
||||||
|
"""
|
||||||
|
env_var = get_api_key_env(provider)
|
||||||
|
if env_var is None:
|
||||||
|
return None # ollama / unknown — no key check possible
|
||||||
|
|
||||||
|
existing = os.environ.get(env_var)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"\n[yellow]{env_var} is not set in your environment.[/yellow]"
|
||||||
|
)
|
||||||
|
key = questionary.password(
|
||||||
|
f"Paste your {env_var} (will be saved to .env):",
|
||||||
|
style=questionary.Style([
|
||||||
|
("text", "fg:cyan"),
|
||||||
|
("highlighted", "noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
if not key:
|
||||||
|
console.print(
|
||||||
|
f"[red]Skipped. API calls will fail until {env_var} is set.[/red]"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env")
|
||||||
|
Path(env_path).touch(exist_ok=True)
|
||||||
|
set_key(env_path, env_var, key)
|
||||||
|
os.environ[env_var] = key
|
||||||
|
console.print(f"[green]Saved {env_var} to {env_path}[/green]")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def ask_output_language() -> str:
|
||||||
|
"""Ask for report output language."""
|
||||||
|
choice = questionary.select(
|
||||||
|
"Select Output Language:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("English (default)", "English"),
|
||||||
|
questionary.Choice("Chinese (中文)", "Chinese"),
|
||||||
|
questionary.Choice("Japanese (日本語)", "Japanese"),
|
||||||
|
questionary.Choice("Korean (한국어)", "Korean"),
|
||||||
|
questionary.Choice("Hindi (हिन्दी)", "Hindi"),
|
||||||
|
questionary.Choice("Spanish (Español)", "Spanish"),
|
||||||
|
questionary.Choice("Portuguese (Português)", "Portuguese"),
|
||||||
|
questionary.Choice("French (Français)", "French"),
|
||||||
|
questionary.Choice("German (Deutsch)", "German"),
|
||||||
|
questionary.Choice("Arabic (العربية)", "Arabic"),
|
||||||
|
questionary.Choice("Russian (Русский)", "Russian"),
|
||||||
|
questionary.Choice("Custom language", "custom"),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:yellow noinherit"),
|
||||||
|
("highlighted", "fg:yellow noinherit"),
|
||||||
|
("pointer", "fg:yellow noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice == "custom":
|
||||||
|
return questionary.text(
|
||||||
|
"Enter language name (e.g. Turkish, Vietnamese, Thai, Indonesian):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a language name.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|||||||
35
docker-compose.yml
Normal file
35
docker-compose.yml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
services:
|
||||||
|
tradingagents:
|
||||||
|
build: .
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
|
tty: true
|
||||||
|
stdin_open: true
|
||||||
|
|
||||||
|
ollama:
|
||||||
|
image: ollama/ollama:latest
|
||||||
|
volumes:
|
||||||
|
- ollama_data:/root/.ollama
|
||||||
|
profiles:
|
||||||
|
- ollama
|
||||||
|
|
||||||
|
tradingagents-ollama:
|
||||||
|
build: .
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
- LLM_PROVIDER=ollama
|
||||||
|
volumes:
|
||||||
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
|
depends_on:
|
||||||
|
- ollama
|
||||||
|
tty: true
|
||||||
|
stdin_open: true
|
||||||
|
profiles:
|
||||||
|
- ollama
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
tradingagents_data:
|
||||||
|
ollama_data:
|
||||||
10
main.py
10
main.py
@@ -1,12 +1,12 @@
|
|||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
# Create a custom config
|
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
||||||
|
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
||||||
|
# so users can switch models or endpoints purely via .env without
|
||||||
|
# editing this script. Override individual keys here only when you
|
||||||
|
# want a hard-coded value that should ignore the environment.
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
|
||||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
|
||||||
config["online_tools"] = True # Increase debate rounds
|
|
||||||
|
|
||||||
# Initialize with custom config
|
# Initialize with custom config
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
|||||||
54
pyproject.toml
Normal file
54
pyproject.toml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "tradingagents"
|
||||||
|
version = "0.2.5"
|
||||||
|
description = "TradingAgents: Multi-Agents LLM Financial Trading Framework"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"langchain-core>=0.3.81",
|
||||||
|
"backtrader>=1.9.78.123",
|
||||||
|
"langchain-anthropic>=0.3.15",
|
||||||
|
"langchain-experimental>=0.3.4",
|
||||||
|
"langchain-google-genai>=4.0.0",
|
||||||
|
"langchain-openai>=0.3.23",
|
||||||
|
"langgraph>=0.4.8",
|
||||||
|
"langgraph-checkpoint-sqlite>=2.0.0",
|
||||||
|
"pandas>=2.3.0",
|
||||||
|
"parsel>=1.10.0",
|
||||||
|
"pytz>=2025.2",
|
||||||
|
"questionary>=2.1.0",
|
||||||
|
"redis>=6.2.0",
|
||||||
|
"requests>=2.32.4",
|
||||||
|
"rich>=14.0.0",
|
||||||
|
"typer>=0.21.0",
|
||||||
|
"setuptools>=80.9.0",
|
||||||
|
"stockstats>=0.6.5",
|
||||||
|
"tqdm>=4.67.1",
|
||||||
|
"typing-extensions>=4.14.0",
|
||||||
|
"yfinance>=0.2.63",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
tradingagents = "cli.main:app"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
include = ["tradingagents*", "cli*"]
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
cli = ["static/*"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-ra --strict-markers"
|
||||||
|
markers = [
|
||||||
|
"unit: fast isolated unit tests",
|
||||||
|
"integration: tests requiring external services",
|
||||||
|
"smoke: quick sanity-check tests",
|
||||||
|
]
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
]
|
||||||
@@ -1,24 +1 @@
|
|||||||
typing-extensions
|
.
|
||||||
langchain-openai
|
|
||||||
langchain-experimental
|
|
||||||
pandas
|
|
||||||
yfinance
|
|
||||||
praw
|
|
||||||
feedparser
|
|
||||||
stockstats
|
|
||||||
eodhd
|
|
||||||
langgraph
|
|
||||||
chromadb
|
|
||||||
setuptools
|
|
||||||
backtrader
|
|
||||||
akshare
|
|
||||||
tushare
|
|
||||||
finnhub-python
|
|
||||||
parsel
|
|
||||||
requests
|
|
||||||
tqdm
|
|
||||||
pytz
|
|
||||||
redis
|
|
||||||
chainlit
|
|
||||||
rich
|
|
||||||
questionary
|
|
||||||
|
|||||||
176
scripts/smoke_structured_output.py
Normal file
176
scripts/smoke_structured_output.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""End-to-end smoke for structured-output agents against a real LLM provider.
|
||||||
|
|
||||||
|
Runs the three decision-making agents (Research Manager, Trader, Portfolio
|
||||||
|
Manager) directly with their structured-output bindings and prints the
|
||||||
|
typed Pydantic instance + the rendered markdown for each. Use this to
|
||||||
|
verify a provider's native structured-output mode (json_schema for
|
||||||
|
OpenAI / xAI / DeepSeek / Qwen / GLM, response_schema for Gemini, tool-use
|
||||||
|
for Anthropic) returns clean instances on the schemas we ship.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
OPENAI_API_KEY=... python scripts/smoke_structured_output.py openai
|
||||||
|
GOOGLE_API_KEY=... python scripts/smoke_structured_output.py google
|
||||||
|
ANTHROPIC_API_KEY=... python scripts/smoke_structured_output.py anthropic
|
||||||
|
DEEPSEEK_API_KEY=... python scripts/smoke_structured_output.py deepseek
|
||||||
|
|
||||||
|
The script does NOT call propagate(), to keep the surface tight and the
|
||||||
|
cost low — it exercises only the three structured-output calls we just
|
||||||
|
added, plus the heuristic SignalProcessor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_DEFAULTS = {
|
||||||
|
"openai": ("gpt-5.4-mini", None),
|
||||||
|
"google": ("gemini-2.5-flash", None),
|
||||||
|
"anthropic": ("claude-sonnet-4-6", None),
|
||||||
|
"deepseek": ("deepseek-chat", None),
|
||||||
|
"qwen": ("qwen-plus", None),
|
||||||
|
"glm": ("glm-5", None),
|
||||||
|
"xai": ("grok-4", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal but realistic state for the three agents.
|
||||||
|
DEBATE_HISTORY = """
|
||||||
|
Bull Analyst: NVDA's data-center revenue grew 60% YoY last quarter, driven by
|
||||||
|
Blackwell ramp; sovereign AI deals with multiple governments add a $40B+
|
||||||
|
multi-year tailwind. Margins remain above peer average.
|
||||||
|
|
||||||
|
Bear Analyst: Concentration risk is real — top three customers are >40% of
|
||||||
|
revenue. Any pause in hyperscaler capex would compress the multiple. China
|
||||||
|
export restrictions still cap a meaningful portion of demand.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rm_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"history": DEBATE_HISTORY,
|
||||||
|
"bull_history": "Bull Analyst: NVDA's data-center revenue grew 60% YoY...",
|
||||||
|
"bear_history": "Bear Analyst: Concentration risk is real...",
|
||||||
|
"current_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trader_state(investment_plan: str):
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_plan": investment_plan,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pm_state(investment_plan: str, trader_plan: str):
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"past_context": "",
|
||||||
|
"risk_debate_state": {
|
||||||
|
"history": "Aggressive: lean in. Conservative: trim. Neutral: balanced sizing.",
|
||||||
|
"aggressive_history": "Aggressive: ...",
|
||||||
|
"conservative_history": "Conservative: ...",
|
||||||
|
"neutral_history": "Neutral: ...",
|
||||||
|
"judge_decision": "",
|
||||||
|
"current_aggressive_response": "",
|
||||||
|
"current_conservative_response": "",
|
||||||
|
"current_neutral_response": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
"market_report": "Market report.",
|
||||||
|
"sentiment_report": "Sentiment report.",
|
||||||
|
"news_report": "News report.",
|
||||||
|
"fundamentals_report": "Fundamentals report.",
|
||||||
|
"investment_plan": investment_plan,
|
||||||
|
"trader_investment_plan": trader_plan,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _print_section(title: str, content: str) -> None:
|
||||||
|
bar = "=" * 70
|
||||||
|
print(f"\n{bar}\n{title}\n{bar}\n{content}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("provider", choices=list(PROVIDER_DEFAULTS.keys()))
|
||||||
|
parser.add_argument("--deep-model", default=None, help="Override deep_think_llm")
|
||||||
|
parser.add_argument("--quick-model", default=None, help="Override quick_think_llm")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
default_model, _ = PROVIDER_DEFAULTS[args.provider]
|
||||||
|
deep_model = args.deep_model or default_model
|
||||||
|
quick_model = args.quick_model or default_model
|
||||||
|
|
||||||
|
print(f"Provider: {args.provider}")
|
||||||
|
print(f"Deep model: {deep_model}")
|
||||||
|
print(f"Quick model: {quick_model}")
|
||||||
|
|
||||||
|
# Build the LLM clients via the framework's factory.
|
||||||
|
deep_client = create_llm_client(provider=args.provider, model=deep_model)
|
||||||
|
quick_client = create_llm_client(provider=args.provider, model=quick_model)
|
||||||
|
deep_llm = deep_client.get_llm()
|
||||||
|
quick_llm = quick_client.get_llm()
|
||||||
|
|
||||||
|
# 1) Research Manager
|
||||||
|
rm = create_research_manager(deep_llm)
|
||||||
|
rm_result = rm(_make_rm_state())
|
||||||
|
investment_plan = rm_result["investment_plan"]
|
||||||
|
_print_section("[1] Research Manager — investment_plan", investment_plan)
|
||||||
|
|
||||||
|
# 2) Trader (consumes RM's plan)
|
||||||
|
trader = create_trader(quick_llm)
|
||||||
|
trader_result = trader(_make_trader_state(investment_plan))
|
||||||
|
trader_plan = trader_result["trader_investment_plan"]
|
||||||
|
_print_section("[2] Trader — trader_investment_plan", trader_plan)
|
||||||
|
|
||||||
|
# 3) Portfolio Manager (consumes both)
|
||||||
|
pm = create_portfolio_manager(deep_llm)
|
||||||
|
pm_result = pm(_make_pm_state(investment_plan, trader_plan))
|
||||||
|
final_decision = pm_result["final_trade_decision"]
|
||||||
|
_print_section("[3] Portfolio Manager — final_trade_decision", final_decision)
|
||||||
|
|
||||||
|
# 4) SignalProcessor extracts the rating with zero LLM calls.
|
||||||
|
sp = SignalProcessor()
|
||||||
|
rating = sp.process_signal(final_decision)
|
||||||
|
_print_section("[4] SignalProcessor → rating", rating)
|
||||||
|
|
||||||
|
# 5) Lightweight checks: each rendered output should carry the expected
|
||||||
|
# section headers so downstream consumers (memory log, CLI display,
|
||||||
|
# saved reports) keep working.
|
||||||
|
checks = [
|
||||||
|
("Research Manager", investment_plan, ["**Recommendation**:"]),
|
||||||
|
("Trader", trader_plan, ["**Action**:", "FINAL TRANSACTION PROPOSAL:"]),
|
||||||
|
("Portfolio Manager", final_decision, ["**Rating**:", "**Executive Summary**:", "**Investment Thesis**:"]),
|
||||||
|
]
|
||||||
|
print("\n" + "=" * 70 + "\nStructure checks\n" + "=" * 70)
|
||||||
|
failures = 0
|
||||||
|
for name, text, required in checks:
|
||||||
|
for marker in required:
|
||||||
|
ok = marker in text
|
||||||
|
print(f" {'PASS' if ok else 'FAIL'} {name}: contains {marker!r}")
|
||||||
|
failures += int(not ok)
|
||||||
|
|
||||||
|
print()
|
||||||
|
if failures:
|
||||||
|
print(f"Smoke FAILED: {failures} structure check(s) missing.")
|
||||||
|
return 1
|
||||||
|
print("Smoke PASSED: structured output → rendered markdown chain works for", args.provider)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
43
setup.py
43
setup.py
@@ -1,43 +0,0 @@
|
|||||||
"""
|
|
||||||
Setup script for the TradingAgents package.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from setuptools import setup, find_packages
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name="tradingagents",
|
|
||||||
version="0.1.0",
|
|
||||||
description="Multi-Agents LLM Financial Trading Framework",
|
|
||||||
author="TradingAgents Team",
|
|
||||||
author_email="yijia.xiao@cs.ucla.edu",
|
|
||||||
url="https://github.com/TauricResearch",
|
|
||||||
packages=find_packages(),
|
|
||||||
install_requires=[
|
|
||||||
"langchain>=0.1.0",
|
|
||||||
"langchain-openai>=0.0.2",
|
|
||||||
"langchain-experimental>=0.0.40",
|
|
||||||
"langgraph>=0.0.20",
|
|
||||||
"numpy>=1.24.0",
|
|
||||||
"pandas>=2.0.0",
|
|
||||||
"praw>=7.7.0",
|
|
||||||
"stockstats>=0.5.4",
|
|
||||||
"yfinance>=0.2.31",
|
|
||||||
"typer>=0.9.0",
|
|
||||||
"rich>=13.0.0",
|
|
||||||
"questionary>=2.0.1",
|
|
||||||
],
|
|
||||||
python_requires=">=3.10",
|
|
||||||
entry_points={
|
|
||||||
"console_scripts": [
|
|
||||||
"tradingagents=cli.main:app",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 3 - Alpha",
|
|
||||||
"Intended Audience :: Financial and Trading Industry",
|
|
||||||
"License :: OSI Approved :: Apache Software License",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Topic :: Office/Business :: Financial :: Investment",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
11
test.py
Normal file
11
test.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import time
|
||||||
|
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||||
|
|
||||||
|
print("Testing optimized implementation with 30-day lookback:")
|
||||||
|
start_time = time.time()
|
||||||
|
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
print(f"Execution time: {end_time - start_time:.2f} seconds")
|
||||||
|
print(f"Result length: {len(result)} characters")
|
||||||
|
print(result)
|
||||||
46
tests/conftest.py
Normal file
46
tests/conftest.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Shared pytest fixtures that prevent CI hangs when API keys are absent."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
for marker in ("unit", "integration", "smoke"):
|
||||||
|
config.addinivalue_line("markers", f"{marker}: {marker}-level tests")
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_ENV_VARS = (
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"GOOGLE_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"XAI_API_KEY",
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"DASHSCOPE_API_KEY",
|
||||||
|
"DASHSCOPE_CN_API_KEY",
|
||||||
|
"ZHIPU_API_KEY",
|
||||||
|
"ZHIPU_CN_API_KEY",
|
||||||
|
"MINIMAX_API_KEY",
|
||||||
|
"MINIMAX_CN_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"AZURE_OPENAI_API_KEY",
|
||||||
|
"ALPHA_VANTAGE_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _dummy_api_keys(monkeypatch):
|
||||||
|
for env_var in _API_KEY_ENV_VARS:
|
||||||
|
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_llm_client():
|
||||||
|
client = MagicMock()
|
||||||
|
client.get_llm.return_value = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"tradingagents.llm_clients.factory.create_llm_client",
|
||||||
|
return_value=client,
|
||||||
|
):
|
||||||
|
yield client
|
||||||
149
tests/test_api_key_env.py
Normal file
149
tests/test_api_key_env.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Mapping coverage -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_every_select_llm_provider_choice_has_an_entry():
|
||||||
|
"""select_llm_provider() must not present a provider the mapping doesn't know about."""
|
||||||
|
# Mirrors the dropdown order in cli/utils.select_llm_provider so the two
|
||||||
|
# stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn)
|
||||||
|
# are reached via the secondary region prompt, so they must also be present.
|
||||||
|
expected = {
|
||||||
|
"openai", "google", "anthropic", "xai", "deepseek",
|
||||||
|
"qwen", "qwen-cn",
|
||||||
|
"glm", "glm-cn",
|
||||||
|
"minimax", "minimax-cn",
|
||||||
|
"openrouter", "azure", "ollama",
|
||||||
|
}
|
||||||
|
assert expected.issubset(PROVIDER_API_KEY_ENV.keys())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider,env_var",
|
||||||
|
[
|
||||||
|
("openai", "OPENAI_API_KEY"),
|
||||||
|
("anthropic", "ANTHROPIC_API_KEY"),
|
||||||
|
("google", "GOOGLE_API_KEY"),
|
||||||
|
("azure", "AZURE_OPENAI_API_KEY"),
|
||||||
|
("xai", "XAI_API_KEY"),
|
||||||
|
("deepseek", "DEEPSEEK_API_KEY"),
|
||||||
|
("qwen", "DASHSCOPE_API_KEY"),
|
||||||
|
("qwen-cn", "DASHSCOPE_CN_API_KEY"),
|
||||||
|
("glm", "ZHIPU_API_KEY"),
|
||||||
|
("glm-cn", "ZHIPU_CN_API_KEY"),
|
||||||
|
("minimax", "MINIMAX_API_KEY"),
|
||||||
|
("minimax-cn", "MINIMAX_CN_API_KEY"),
|
||||||
|
("openrouter", "OPENROUTER_API_KEY"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_known_providers_resolve(provider, env_var):
|
||||||
|
assert get_api_key_env(provider) == env_var
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_has_no_key():
|
||||||
|
assert get_api_key_env("ollama") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_provider_returns_none():
|
||||||
|
assert get_api_key_env("not-a-real-provider") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_insensitive_lookup():
|
||||||
|
assert get_api_key_env("OpenAI") == "OPENAI_API_KEY"
|
||||||
|
assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- ensure_api_key behavior ---------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_utils(monkeypatch):
|
||||||
|
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||||
|
import importlib
|
||||||
|
import cli.utils as cli_utils_module
|
||||||
|
return importlib.reload(cli_utils_module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_returns_existing(monkeypatch, cli_utils):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set")
|
||||||
|
result = cli_utils.ensure_api_key("openai")
|
||||||
|
assert result == "sk-already-set"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils):
|
||||||
|
# Even with no env var set, ollama should not prompt and should return None.
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
with patch.object(cli_utils, "questionary") as mock_q:
|
||||||
|
result = cli_utils.ensure_api_key("ollama")
|
||||||
|
assert result is None
|
||||||
|
mock_q.password.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils):
|
||||||
|
with patch.object(cli_utils, "questionary") as mock_q:
|
||||||
|
result = cli_utils.ensure_api_key("totally-fake-provider")
|
||||||
|
assert result is None
|
||||||
|
mock_q.password.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""When key is missing, user-pasted value must be written to .env AND os.environ."""
|
||||||
|
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
result = cli_utils.ensure_api_key("deepseek")
|
||||||
|
|
||||||
|
assert result == "sk-deepseek-test"
|
||||||
|
assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test"
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
assert env_file.exists()
|
||||||
|
assert "DEEPSEEK_API_KEY" in env_file.read_text()
|
||||||
|
assert "sk-deepseek-test" in env_file.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""Empty prompt response (user cancelled) must not write to .env."""
|
||||||
|
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
result = cli_utils.ensure_api_key("xai")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert "XAI_API_KEY" not in os.environ
|
||||||
|
# .env may or may not exist depending on find_dotenv's walk, but if it
|
||||||
|
# does it must not contain the key.
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
if env_file.exists():
|
||||||
|
assert "XAI_API_KEY" not in env_file.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""An existing .env with other keys must be preserved on writeback."""
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n")
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
cli_utils.ensure_api_key("openrouter")
|
||||||
|
|
||||||
|
content = env_file.read_text()
|
||||||
|
assert "OPENAI_API_KEY" in content and "sk-existing" in content
|
||||||
|
assert "OTHER=value" in content
|
||||||
|
assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content
|
||||||
107
tests/test_capabilities.py
Normal file
107
tests/test_capabilities.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""Unit tests for the LLM capability table."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.capabilities import (
|
||||||
|
ModelCapabilities,
|
||||||
|
get_capabilities,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestExactIdMatches:
|
||||||
|
def test_deepseek_chat_supports_tool_choice(self):
|
||||||
|
caps = get_capabilities("deepseek-chat")
|
||||||
|
assert caps.supports_tool_choice is True
|
||||||
|
|
||||||
|
def test_deepseek_reasoner_rejects_tool_choice(self):
|
||||||
|
caps = get_capabilities("deepseek-reasoner")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
assert caps.requires_reasoning_content_roundtrip is True
|
||||||
|
|
||||||
|
def test_deepseek_v4_flash_rejects_tool_choice(self):
|
||||||
|
caps = get_capabilities("deepseek-v4-flash")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
assert caps.requires_reasoning_content_roundtrip is True
|
||||||
|
|
||||||
|
def test_deepseek_v4_pro_rejects_tool_choice(self):
|
||||||
|
caps = get_capabilities("deepseek-v4-pro")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
assert caps.requires_reasoning_content_roundtrip is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestPatternMatches:
|
||||||
|
"""Forward-compat regex patterns catch unknown DeepSeek and MiniMax variants."""
|
||||||
|
|
||||||
|
def test_future_deepseek_v5_inherits_thinking_quirks(self):
|
||||||
|
caps = get_capabilities("deepseek-v5-flash")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
assert caps.requires_reasoning_content_roundtrip is True
|
||||||
|
|
||||||
|
def test_future_deepseek_v9_inherits_thinking_quirks(self):
|
||||||
|
caps = get_capabilities("deepseek-v9-anything")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_reasoner_variant_inherits_thinking_quirks(self):
|
||||||
|
caps = get_capabilities("deepseek-reasoner-pro")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_future_minimax_m3_inherits_thinking_quirks(self):
|
||||||
|
caps = get_capabilities("MiniMax-M3")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_future_minimax_m4_highspeed_inherits_thinking_quirks(self):
|
||||||
|
caps = get_capabilities("MiniMax-M4-highspeed")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMinimaxExactMatches:
|
||||||
|
"""MiniMax M2.x models reject langchain's function-spec dict tool_choice
|
||||||
|
(official API enum: none/auto only)."""
|
||||||
|
|
||||||
|
def test_m2_7_rejects_tool_choice(self):
|
||||||
|
caps = get_capabilities("MiniMax-M2.7")
|
||||||
|
assert caps.supports_tool_choice is False
|
||||||
|
assert caps.supports_json_mode is False # only MiniMax-Text-01 supports json_object
|
||||||
|
|
||||||
|
def test_m2_7_highspeed_rejects_tool_choice(self):
|
||||||
|
assert get_capabilities("MiniMax-M2.7-highspeed").supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_m2_1_rejects_tool_choice(self):
|
||||||
|
assert get_capabilities("MiniMax-M2.1").supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_m2_base_rejects_tool_choice(self):
|
||||||
|
assert get_capabilities("MiniMax-M2").supports_tool_choice is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestDefault:
|
||||||
|
"""Unknown / non-DeepSeek models get the permissive default."""
|
||||||
|
|
||||||
|
def test_gpt_default(self):
|
||||||
|
caps = get_capabilities("gpt-4.1")
|
||||||
|
assert caps.supports_tool_choice is True
|
||||||
|
assert caps.preferred_structured_method == "function_calling"
|
||||||
|
|
||||||
|
def test_grok_default(self):
|
||||||
|
caps = get_capabilities("grok-4-0709")
|
||||||
|
assert caps.supports_tool_choice is True
|
||||||
|
|
||||||
|
def test_unknown_model_default(self):
|
||||||
|
caps = get_capabilities("totally-made-up-model-id")
|
||||||
|
assert caps.supports_tool_choice is True
|
||||||
|
|
||||||
|
def test_exact_match_precedes_pattern(self):
|
||||||
|
"""deepseek-chat must NOT match the v\\d regex."""
|
||||||
|
caps = get_capabilities("deepseek-chat")
|
||||||
|
assert caps.supports_tool_choice is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_capabilities_dataclass_is_frozen():
|
||||||
|
"""Capability rows are immutable so they can be safely shared."""
|
||||||
|
caps = get_capabilities("deepseek-chat")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
caps.supports_tool_choice = False # type: ignore[misc]
|
||||||
147
tests/test_checkpoint_resume.py
Normal file
147
tests/test_checkpoint_resume.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
from langgraph.graph import END, StateGraph
|
||||||
|
|
||||||
|
from tradingagents.graph.checkpointer import (
|
||||||
|
checkpoint_step,
|
||||||
|
clear_checkpoint,
|
||||||
|
get_checkpointer,
|
||||||
|
has_checkpoint,
|
||||||
|
thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mutable flag to simulate crash on first run
|
||||||
|
_should_crash = False
|
||||||
|
|
||||||
|
|
||||||
|
class _SimpleState(TypedDict):
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
def _node_a(state: _SimpleState) -> dict:
|
||||||
|
return {"count": state["count"] + 1}
|
||||||
|
|
||||||
|
|
||||||
|
def _node_b(state: _SimpleState) -> dict:
|
||||||
|
if _should_crash:
|
||||||
|
raise RuntimeError("simulated mid-analysis crash")
|
||||||
|
return {"count": state["count"] + 10}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_graph() -> StateGraph:
|
||||||
|
builder = StateGraph(_SimpleState)
|
||||||
|
builder.add_node("analyst", _node_a)
|
||||||
|
builder.add_node("trader", _node_b)
|
||||||
|
builder.set_entry_point("analyst")
|
||||||
|
builder.add_edge("analyst", "trader")
|
||||||
|
builder.add_edge("trader", END)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointResume(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdir = tempfile.mkdtemp()
|
||||||
|
self.ticker = "TEST"
|
||||||
|
self.date = "2026-04-20"
|
||||||
|
|
||||||
|
def test_crash_and_resume(self):
|
||||||
|
"""Crash at 'trader' node, then resume from checkpoint."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
tid = thread_id(self.ticker, self.date)
|
||||||
|
cfg = {"configurable": {"thread_id": tid}}
|
||||||
|
|
||||||
|
# Run 1: crash at trader node
|
||||||
|
_should_crash = True
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
# Checkpoint should exist at step 1 (analyst completed)
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
|
||||||
|
self.assertEqual(step, 1)
|
||||||
|
|
||||||
|
# Run 2: resume — trader succeeds this time
|
||||||
|
_should_crash = False
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke(None, config=cfg)
|
||||||
|
|
||||||
|
# analyst added 1, trader added 10 → 11
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
def test_clear_checkpoint_allows_fresh_start(self):
|
||||||
|
"""After clearing, the graph starts from scratch."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
tid = thread_id(self.ticker, self.date)
|
||||||
|
cfg = {"configurable": {"thread_id": tid}}
|
||||||
|
|
||||||
|
# Create a checkpoint by crashing
|
||||||
|
_should_crash = True
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# Clear it
|
||||||
|
clear_checkpoint(self.tmpdir, self.ticker, self.date)
|
||||||
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# Fresh run succeeds from scratch
|
||||||
|
_should_crash = False
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke({"count": 0}, config=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_date_starts_fresh(self):
|
||||||
|
"""A different date must NOT resume from an existing checkpoint."""
|
||||||
|
global _should_crash
|
||||||
|
builder = _build_graph()
|
||||||
|
date2 = "2026-04-21"
|
||||||
|
|
||||||
|
# Run with date1 — crash to leave a checkpoint
|
||||||
|
_should_crash = True
|
||||||
|
tid1 = thread_id(self.ticker, self.date)
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
|
||||||
|
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
# date2 should have no checkpoint
|
||||||
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
|
||||||
|
|
||||||
|
# Run with date2 — should start fresh and succeed
|
||||||
|
_should_crash = False
|
||||||
|
tid2 = thread_id(self.ticker, date2)
|
||||||
|
self.assertNotEqual(tid1, tid2)
|
||||||
|
|
||||||
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
|
||||||
|
|
||||||
|
# Fresh run: analyst +1, trader +10 = 11
|
||||||
|
self.assertEqual(result["count"], 11)
|
||||||
|
|
||||||
|
# Original date checkpoint still exists (untouched)
|
||||||
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
61
tests/test_dataflows_config.py
Normal file
61
tests/test_dataflows_config.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""Config isolation: get/set must not leak nested-dict references."""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows.config import get_config, set_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class DataflowsConfigIsolationTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
set_config(copy.deepcopy(default_config.DEFAULT_CONFIG))
|
||||||
|
|
||||||
|
def test_get_config_returns_deep_copy(self):
|
||||||
|
cfg = get_config()
|
||||||
|
cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||||
|
cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance")
|
||||||
|
self.assertNotIn("get_stock_data", fresh["tool_vendors"])
|
||||||
|
|
||||||
|
def test_set_config_does_not_alias_caller_nested_dicts(self):
|
||||||
|
custom = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
custom["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||||
|
custom["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||||
|
|
||||||
|
set_config(custom)
|
||||||
|
|
||||||
|
custom["data_vendors"]["core_stock_apis"] = "yfinance"
|
||||||
|
custom["tool_vendors"]["get_stock_data"] = "yfinance"
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||||
|
|
||||||
|
def test_partial_nested_update_preserves_existing_defaults(self):
|
||||||
|
set_config(
|
||||||
|
{
|
||||||
|
"data_vendors": {
|
||||||
|
"core_stock_apis": "alpha_vantage",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance")
|
||||||
|
self.assertEqual(fresh["data_vendors"]["news_data"], "yfinance")
|
||||||
|
|
||||||
|
def test_nested_dict_updates_merge_one_level_deep(self):
|
||||||
|
set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}})
|
||||||
|
set_config({"tool_vendors": {"get_news": "alpha_vantage"}})
|
||||||
|
|
||||||
|
fresh = get_config()
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||||
|
self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage")
|
||||||
240
tests/test_deepseek_reasoning.py
Normal file
240
tests/test_deepseek_reasoning.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
|
||||||
|
|
||||||
|
Two pieces verified:
|
||||||
|
|
||||||
|
1. ``reasoning_content`` is captured on receive into the AIMessage's
|
||||||
|
``additional_kwargs`` and re-attached on send so DeepSeek's API
|
||||||
|
sees the same value across turns.
|
||||||
|
2. ``with_structured_output`` consults the capability table and
|
||||||
|
suppresses ``tool_choice`` for models that reject it (V4 + reasoner),
|
||||||
|
matching DeepSeek's official tool-calling pattern at
|
||||||
|
https://api-docs.deepseek.com/guides/tool_calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.prompt_values import ChatPromptValue
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.openai_client import (
|
||||||
|
DeepSeekChatOpenAI,
|
||||||
|
NormalizedChatOpenAI,
|
||||||
|
_input_to_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||||
|
# (Gemini bot review note: non-list inputs must also work)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestInputToMessages:
|
||||||
|
def test_list_input_returned_as_is(self):
|
||||||
|
msgs = [HumanMessage(content="hi")]
|
||||||
|
assert _input_to_messages(msgs) is msgs
|
||||||
|
|
||||||
|
def test_chat_prompt_value_unwrapped(self):
|
||||||
|
msgs = [HumanMessage(content="hi")]
|
||||||
|
prompt_value = ChatPromptValue(messages=msgs)
|
||||||
|
assert _input_to_messages(prompt_value) == msgs
|
||||||
|
|
||||||
|
def test_string_input_yields_empty_list(self):
|
||||||
|
# A bare string isn't a message-bearing input; the caller's normal
|
||||||
|
# langchain conversion happens upstream of _get_request_payload.
|
||||||
|
assert _input_to_messages("hello") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reasoning content propagation across turns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestDeepSeekReasoningContent:
|
||||||
|
def _client(self):
|
||||||
|
os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
|
||||||
|
return DeepSeekChatOpenAI(
|
||||||
|
model="deepseek-v4-flash",
|
||||||
|
api_key="placeholder",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_capture_on_receive(self):
|
||||||
|
"""When the response carries reasoning_content, it lands on the
|
||||||
|
AIMessage's additional_kwargs so the next turn can echo it back."""
|
||||||
|
client = self._client()
|
||||||
|
result = client._create_chat_result(
|
||||||
|
{
|
||||||
|
"model": "deepseek-v4-flash",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Plan: buy NVDA.",
|
||||||
|
"reasoning_content": "Step 1: trend is up. Step 2: ...",
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ai = result.generations[0].message
|
||||||
|
assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."
|
||||||
|
|
||||||
|
def test_propagate_on_send(self):
|
||||||
|
"""When an outgoing AIMessage carries reasoning_content, the request
|
||||||
|
payload echoes it on the corresponding message dict."""
|
||||||
|
client = self._client()
|
||||||
|
prior = AIMessage(
|
||||||
|
content="Plan",
|
||||||
|
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||||
|
)
|
||||||
|
new_user = HumanMessage(content="Refine.")
|
||||||
|
payload = client._get_request_payload([prior, new_user])
|
||||||
|
# Find the assistant message in the payload
|
||||||
|
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||||
|
assert assistant_dicts, "assistant message missing from outgoing payload"
|
||||||
|
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||||
|
|
||||||
|
def test_propagate_through_chat_prompt_value(self):
|
||||||
|
"""Gemini bot review note: non-list inputs (ChatPromptValue) must
|
||||||
|
also propagate reasoning_content."""
|
||||||
|
client = self._client()
|
||||||
|
prior = AIMessage(
|
||||||
|
content="Plan",
|
||||||
|
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||||
|
)
|
||||||
|
prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
|
||||||
|
payload = client._get_request_payload(prompt_value)
|
||||||
|
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||||
|
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Capability-driven structured output: tool_choice suppressed for V4 + reasoner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _bound_kwargs(runnable):
|
||||||
|
"""Extract bind() kwargs from a with_structured_output result."""
|
||||||
|
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
|
||||||
|
return getattr(first, "kwargs", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestStructuredOutputCapabilityDispatch:
|
||||||
|
"""DeepSeek V4 and reasoner reject the tool_choice parameter
|
||||||
|
(official guide: api-docs.deepseek.com/guides/tool_calls passes
|
||||||
|
tools=[...] without tool_choice). Verify the capability dispatch
|
||||||
|
suppresses tool_choice for those models and sends it for chat."""
|
||||||
|
|
||||||
|
class _Sample(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
def _client(self, model):
|
||||||
|
return DeepSeekChatOpenAI(
|
||||||
|
model=model, api_key="placeholder", base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chat_sends_tool_choice(self):
|
||||||
|
bound = self._client("deepseek-chat").with_structured_output(self._Sample)
|
||||||
|
assert _bound_kwargs(bound).get("tool_choice") is not None
|
||||||
|
|
||||||
|
def test_reasoner_suppresses_tool_choice(self):
|
||||||
|
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
|
||||||
|
# tool_choice is either absent or explicitly None — both are valid
|
||||||
|
# signals that langchain's bind_tools will skip the parameter.
|
||||||
|
assert _bound_kwargs(bound).get("tool_choice") in (None, ...) or \
|
||||||
|
"tool_choice" not in _bound_kwargs(bound)
|
||||||
|
|
||||||
|
def test_v4_flash_suppresses_tool_choice(self):
|
||||||
|
bound = self._client("deepseek-v4-flash").with_structured_output(self._Sample)
|
||||||
|
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||||
|
"tool_choice" not in _bound_kwargs(bound)
|
||||||
|
|
||||||
|
def test_v4_pro_suppresses_tool_choice(self):
|
||||||
|
bound = self._client("deepseek-v4-pro").with_structured_output(self._Sample)
|
||||||
|
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||||
|
"tool_choice" not in _bound_kwargs(bound)
|
||||||
|
|
||||||
|
def test_future_v_variant_via_regex(self):
|
||||||
|
"""Forward-compat: unknown deepseek-v\\d-* IDs inherit V4 quirks."""
|
||||||
|
bound = self._client("deepseek-v5-hypothetical").with_structured_output(self._Sample)
|
||||||
|
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||||
|
"tool_choice" not in _bound_kwargs(bound)
|
||||||
|
|
||||||
|
def test_schema_is_still_bound_as_tool(self):
|
||||||
|
"""tool_choice is suppressed, but the schema is still bound as a tool —
|
||||||
|
exactly matching DeepSeek's official tool-calling examples."""
|
||||||
|
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
|
||||||
|
kwargs = _bound_kwargs(bound)
|
||||||
|
tools = kwargs.get("tools", [])
|
||||||
|
assert any(
|
||||||
|
t.get("function", {}).get("name") == "_Sample" for t in tools
|
||||||
|
), f"schema not bound as a tool: {tools}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Live API: structured output round-trips against the real DeepSeek backend
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _has_real_deepseek_key():
|
||||||
|
key = os.environ.get("DEEPSEEK_API_KEY", "")
|
||||||
|
return bool(key) and key != "placeholder"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not _has_real_deepseek_key(),
|
||||||
|
reason="DEEPSEEK_API_KEY not set (or placeholder); skipping live API call",
|
||||||
|
)
|
||||||
|
class TestDeepSeekLiveStructuredOutput:
|
||||||
|
"""End-to-end: a real DeepSeek V4-flash call returns a typed instance.
|
||||||
|
|
||||||
|
Verifies the no-tool_choice path doesn't trigger the 400 reported in
|
||||||
|
issue #678 and that the structured-output binding still parses to a
|
||||||
|
Pydantic instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class _Pick(BaseModel):
|
||||||
|
action: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
def test_v4_flash_returns_structured_output(self):
|
||||||
|
client = DeepSeekChatOpenAI(
|
||||||
|
model="deepseek-v4-flash",
|
||||||
|
api_key=os.environ["DEEPSEEK_API_KEY"],
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
bound = client.with_structured_output(self._Pick)
|
||||||
|
result = bound.invoke(
|
||||||
|
"Pick BUY or SELL or HOLD for a tech stock with strong earnings. "
|
||||||
|
"Confidence is a float between 0 and 1."
|
||||||
|
)
|
||||||
|
assert isinstance(result, self._Pick)
|
||||||
|
assert result.action in {"BUY", "SELL", "HOLD"}
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestBaseClassIsolation:
|
||||||
|
def test_normalized_does_not_propagate_reasoning_content(self):
|
||||||
|
"""The general-purpose NormalizedChatOpenAI must not carry
|
||||||
|
DeepSeek-specific behaviour. Only the subclass does."""
|
||||||
|
assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
|
||||||
|
NormalizedChatOpenAI._get_request_payload
|
||||||
|
is NormalizedChatOpenAI.__bases__[0]._get_request_payload
|
||||||
|
)
|
||||||
98
tests/test_env_overrides.py
Normal file
98
tests/test_env_overrides.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Tests for TRADINGAGENTS_* env-var overlay onto DEFAULT_CONFIG."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.default_config as default_config_module
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_with_env(monkeypatch, **overrides):
|
||||||
|
"""Set/clear env vars then reload default_config to re-evaluate DEFAULT_CONFIG."""
|
||||||
|
for key in list(default_config_module._ENV_OVERRIDES):
|
||||||
|
monkeypatch.delenv(key, raising=False)
|
||||||
|
for key, val in overrides.items():
|
||||||
|
monkeypatch.setenv(key, val)
|
||||||
|
return importlib.reload(default_config_module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_env_uses_built_in_defaults(monkeypatch):
|
||||||
|
dc = _reload_with_env(monkeypatch)
|
||||||
|
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
||||||
|
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.4"
|
||||||
|
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini"
|
||||||
|
assert dc.DEFAULT_CONFIG["backend_url"] is None
|
||||||
|
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
||||||
|
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_overrides(monkeypatch):
|
||||||
|
dc = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
TRADINGAGENTS_LLM_PROVIDER="google",
|
||||||
|
TRADINGAGENTS_DEEP_THINK_LLM="gemini-3-pro-preview",
|
||||||
|
TRADINGAGENTS_QUICK_THINK_LLM="gemini-3-flash-preview",
|
||||||
|
TRADINGAGENTS_LLM_BACKEND_URL="https://example.invalid/v1",
|
||||||
|
TRADINGAGENTS_OUTPUT_LANGUAGE="Chinese",
|
||||||
|
)
|
||||||
|
assert dc.DEFAULT_CONFIG["llm_provider"] == "google"
|
||||||
|
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gemini-3-pro-preview"
|
||||||
|
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gemini-3-flash-preview"
|
||||||
|
assert dc.DEFAULT_CONFIG["backend_url"] == "https://example.invalid/v1"
|
||||||
|
assert dc.DEFAULT_CONFIG["output_language"] == "Chinese"
|
||||||
|
|
||||||
|
|
||||||
|
def test_int_coercion(monkeypatch):
|
||||||
|
dc = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
TRADINGAGENTS_MAX_DEBATE_ROUNDS="3",
|
||||||
|
TRADINGAGENTS_MAX_RISK_ROUNDS="2",
|
||||||
|
)
|
||||||
|
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 3
|
||||||
|
assert isinstance(dc.DEFAULT_CONFIG["max_debate_rounds"], int)
|
||||||
|
assert dc.DEFAULT_CONFIG["max_risk_discuss_rounds"] == 2
|
||||||
|
assert isinstance(dc.DEFAULT_CONFIG["max_risk_discuss_rounds"], int)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"raw,expected",
|
||||||
|
[
|
||||||
|
("true", True), ("True", True), ("1", True), ("yes", True), ("on", True),
|
||||||
|
("false", False), ("False", False), ("0", False), ("no", False), ("off", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bool_coercion(monkeypatch, raw, expected):
|
||||||
|
dc = _reload_with_env(monkeypatch, TRADINGAGENTS_CHECKPOINT_ENABLED=raw)
|
||||||
|
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_env_value_is_passthrough(monkeypatch):
|
||||||
|
"""Empty TRADINGAGENTS_* values must not clobber the built-in default."""
|
||||||
|
dc = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
TRADINGAGENTS_LLM_PROVIDER="",
|
||||||
|
TRADINGAGENTS_MAX_DEBATE_ROUNDS="",
|
||||||
|
)
|
||||||
|
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
||||||
|
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_int_raises(monkeypatch):
|
||||||
|
"""Garbage int values should surface a ValueError at import, not silently misconfigure."""
|
||||||
|
monkeypatch.setenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", "not-a-number")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
importlib.reload(default_config_module)
|
||||||
|
# Restore module state for subsequent tests in this process
|
||||||
|
monkeypatch.delenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", raising=False)
|
||||||
|
importlib.reload(default_config_module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_env_var_is_ignored(monkeypatch):
|
||||||
|
"""Env vars outside _ENV_OVERRIDES must not bleed into DEFAULT_CONFIG."""
|
||||||
|
dc = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
TRADINGAGENTS_NONEXISTENT_KEY="oops",
|
||||||
|
)
|
||||||
|
assert "nonexistent_key" not in dc.DEFAULT_CONFIG
|
||||||
31
tests/test_google_api_key.py
Normal file
31
tests/test_google_api_key.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.google_client import GoogleClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
||||||
|
"""Verify GoogleClient accepts unified api_key parameter."""
|
||||||
|
|
||||||
|
@patch("tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI")
|
||||||
|
def test_api_key_handling(self, mock_chat):
|
||||||
|
test_cases = [
|
||||||
|
("unified api_key is mapped", {"api_key": "test-key-123"}, "test-key-123"),
|
||||||
|
("legacy google_api_key still works", {"google_api_key": "legacy-key-456"}, "legacy-key-456"),
|
||||||
|
("unified api_key takes precedence", {"api_key": "unified", "google_api_key": "legacy"}, "unified"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for msg, kwargs, expected_key in test_cases:
|
||||||
|
with self.subTest(msg=msg):
|
||||||
|
mock_chat.reset_mock()
|
||||||
|
client = GoogleClient("gemini-2.5-flash", **kwargs)
|
||||||
|
client.get_llm()
|
||||||
|
call_kwargs = mock_chat.call_args[1]
|
||||||
|
self.assertEqual(call_kwargs.get("google_api_key"), expected_key)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
860
tests/test_memory_log.py
Normal file
860
tests/test_memory_log.py
Normal file
@@ -0,0 +1,860 @@
|
|||||||
|
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||||
|
from tradingagents.graph.reflection import Reflector
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
from tradingagents.graph.propagation import Propagator
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
|
|
||||||
|
_SEP = TradingMemoryLog._SEPARATOR
|
||||||
|
|
||||||
|
DECISION_BUY = "Rating: Buy\nEnter at $189-192, 6% portfolio cap."
|
||||||
|
DECISION_OVERWEIGHT = (
|
||||||
|
"Rating: Overweight\n"
|
||||||
|
"Executive Summary: Moderate position, await confirmation.\n"
|
||||||
|
"Investment Thesis: Strong fundamentals but near-term headwinds."
|
||||||
|
)
|
||||||
|
DECISION_SELL = "Rating: Sell\nExit position immediately."
|
||||||
|
DECISION_NO_RATING = (
|
||||||
|
"Executive Summary: Complex situation with multiple competing factors.\n"
|
||||||
|
"Investment Thesis: No clear directional signal at this time."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_log(tmp_path, filename="trading_memory.md"):
|
||||||
|
config = {"memory_log_path": str(tmp_path / filename)}
|
||||||
|
return TradingMemoryLog(config)
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_completed(tmp_path, ticker, date, decision_text, reflection_text, filename="trading_memory.md"):
|
||||||
|
"""Write a completed entry directly to file, bypassing the API."""
|
||||||
|
entry = (
|
||||||
|
f"[{date} | {ticker} | Buy | +1.0% | +0.5% | 5d]\n\n"
|
||||||
|
f"DECISION:\n{decision_text}\n\n"
|
||||||
|
f"REFLECTION:\n{reflection_text}"
|
||||||
|
+ _SEP
|
||||||
|
)
|
||||||
|
with open(tmp_path / filename, "a", encoding="utf-8") as f:
|
||||||
|
f.write(entry)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_entry(log, ticker, date, decision, reflection="Good call."):
|
||||||
|
"""Store a decision then immediately resolve it via the API."""
|
||||||
|
log.store_decision(ticker, date, decision)
|
||||||
|
log.update_with_outcome(ticker, date, 0.05, 0.02, 5, reflection)
|
||||||
|
|
||||||
|
|
||||||
|
def _price_df(prices):
|
||||||
|
"""Minimal DataFrame matching yfinance .history() output shape."""
|
||||||
|
return pd.DataFrame({"Close": prices})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pm_state(past_context=""):
|
||||||
|
"""Minimal AgentState dict for portfolio_manager_node."""
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"past_context": past_context,
|
||||||
|
"risk_debate_state": {
|
||||||
|
"history": "Risk debate history.",
|
||||||
|
"aggressive_history": "",
|
||||||
|
"conservative_history": "",
|
||||||
|
"neutral_history": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"current_aggressive_response": "",
|
||||||
|
"current_conservative_response": "",
|
||||||
|
"current_neutral_response": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
"market_report": "Market report.",
|
||||||
|
"sentiment_report": "Sentiment report.",
|
||||||
|
"news_report": "News report.",
|
||||||
|
"fundamentals_report": "Fundamentals report.",
|
||||||
|
"investment_plan": "Research plan.",
|
||||||
|
"trader_investment_plan": "Trader plan.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_pm_llm(captured: dict, decision: PortfolioDecision | None = None):
|
||||||
|
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||||
|
prompt and returns a real PortfolioDecision (so render_pm_decision works).
|
||||||
|
"""
|
||||||
|
if decision is None:
|
||||||
|
decision = PortfolioDecision(
|
||||||
|
rating=PortfolioRating.HOLD,
|
||||||
|
executive_summary="Hold the position; await catalyst.",
|
||||||
|
investment_thesis="Balanced view; neither side carried the debate.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or decision
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core: storage and read path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTradingMemoryLogCore:
|
||||||
|
|
||||||
|
def test_store_creates_file(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert not (tmp_path / "trading_memory.md").exists()
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert (tmp_path / "trading_memory.md").exists()
|
||||||
|
|
||||||
|
def test_store_appends_not_overwrites(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 2
|
||||||
|
assert entries[0]["ticker"] == "NVDA"
|
||||||
|
assert entries[1]["ticker"] == "AAPL"
|
||||||
|
|
||||||
|
def test_store_decision_idempotent(self, tmp_path):
|
||||||
|
"""Calling store_decision twice with same (ticker, date) stores only one entry."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert len(log.load_entries()) == 1
|
||||||
|
|
||||||
|
def test_batch_update_resolves_multiple_entries(self, tmp_path):
|
||||||
|
"""batch_update_with_outcomes resolves multiple pending entries in one write."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-01-12", DECISION_SELL)
|
||||||
|
|
||||||
|
updates = [
|
||||||
|
{"ticker": "NVDA", "trade_date": "2026-01-05",
|
||||||
|
"raw_return": 0.05, "alpha_return": 0.02, "holding_days": 5,
|
||||||
|
"reflection": "First correct."},
|
||||||
|
{"ticker": "NVDA", "trade_date": "2026-01-12",
|
||||||
|
"raw_return": -0.03, "alpha_return": -0.01, "holding_days": 5,
|
||||||
|
"reflection": "Second correct."},
|
||||||
|
]
|
||||||
|
log.batch_update_with_outcomes(updates)
|
||||||
|
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 2
|
||||||
|
assert all(not e["pending"] for e in entries)
|
||||||
|
assert entries[0]["reflection"] == "First correct."
|
||||||
|
assert entries[1]["reflection"] == "Second correct."
|
||||||
|
|
||||||
|
def test_pending_tag_format(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | pending]" in text
|
||||||
|
|
||||||
|
# Rating parsing
|
||||||
|
|
||||||
|
def test_rating_parsed_buy(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
def test_rating_parsed_overweight(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Overweight"
|
||||||
|
|
||||||
|
def test_rating_fallback_hold(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Hold"
|
||||||
|
|
||||||
|
def test_rating_priority_over_prose(self, tmp_path):
|
||||||
|
"""'Rating: X' label wins even when an opposing rating word appears earlier in prose."""
|
||||||
|
decision = (
|
||||||
|
"The sell thesis is weak. The hold case is marginal.\n\n"
|
||||||
|
"Rating: Buy\n\n"
|
||||||
|
"Executive Summary: Strong fundamentals support the position."
|
||||||
|
)
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
# Delimiter robustness
|
||||||
|
|
||||||
|
def test_decision_with_markdown_separator(self, tmp_path):
|
||||||
|
"""LLM decision containing '---' must not corrupt the entry."""
|
||||||
|
decision = "Rating: Buy\n\n---\n\nRisk: elevated volatility."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert "Risk: elevated volatility" in entries[0]["decision"]
|
||||||
|
|
||||||
|
# load_entries
|
||||||
|
|
||||||
|
def test_load_entries_empty_file(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert log.load_entries() == []
|
||||||
|
|
||||||
|
def test_load_entries_single(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["date"] == "2026-01-10"
|
||||||
|
assert e["ticker"] == "NVDA"
|
||||||
|
assert e["rating"] == "Buy"
|
||||||
|
assert e["pending"] is True
|
||||||
|
assert e["raw"] is None
|
||||||
|
|
||||||
|
def test_load_entries_multiple(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
assert [e["ticker"] for e in entries] == ["NVDA", "AAPL", "MSFT"]
|
||||||
|
|
||||||
|
def test_decision_content_preserved(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries()[0]["decision"] == DECISION_BUY.strip()
|
||||||
|
|
||||||
|
# get_pending_entries
|
||||||
|
|
||||||
|
def test_get_pending_returns_pending_only(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA.", "Correct.")
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
pending = log.get_pending_entries()
|
||||||
|
assert len(pending) == 1
|
||||||
|
assert pending[0]["ticker"] == "NVDA"
|
||||||
|
assert pending[0]["date"] == "2026-01-10"
|
||||||
|
|
||||||
|
# get_past_context
|
||||||
|
|
||||||
|
def test_get_past_context_empty(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
def test_get_past_context_pending_excluded(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
def test_get_past_context_same_ticker(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA — AI capex thesis intact.", "Directionally correct.")
|
||||||
|
ctx = log.get_past_context("NVDA")
|
||||||
|
assert "Past analyses of NVDA" in ctx
|
||||||
|
assert "Buy NVDA" in ctx
|
||||||
|
|
||||||
|
def test_get_past_context_cross_ticker(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_seed_completed(tmp_path, "AAPL", "2026-01-05", "Buy AAPL — Services growth.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA")
|
||||||
|
assert "Recent cross-ticker lessons" in ctx
|
||||||
|
assert "Past analyses of NVDA" not in ctx
|
||||||
|
|
||||||
|
def test_n_same_limit_respected(self, tmp_path):
|
||||||
|
"""Only the n_same most recent same-ticker entries are included."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(6):
|
||||||
|
_seed_completed(tmp_path, "NVDA", f"2026-01-{i+1:02d}", f"Buy entry {i}.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA", n_same=5)
|
||||||
|
assert "Buy entry 0" not in ctx
|
||||||
|
assert "Buy entry 5" in ctx
|
||||||
|
|
||||||
|
def test_n_cross_limit_respected(self, tmp_path):
|
||||||
|
"""Only the n_cross most recent cross-ticker entries are included."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i, ticker in enumerate(["AAPL", "MSFT", "GOOG", "META"]):
|
||||||
|
_seed_completed(tmp_path, ticker, f"2026-01-{i+1:02d}", f"Buy {ticker}.", "Correct.")
|
||||||
|
ctx = log.get_past_context("NVDA", n_cross=3)
|
||||||
|
assert "AAPL" not in ctx
|
||||||
|
assert "META" in ctx
|
||||||
|
|
||||||
|
# No-op when config is None
|
||||||
|
|
||||||
|
def test_no_log_path_is_noop(self):
|
||||||
|
log = TradingMemoryLog(config=None)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
assert log.load_entries() == []
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
|
||||||
|
# Rotation: opt-in cap on resolved entries
|
||||||
|
|
||||||
|
def test_rotation_disabled_by_default(self, tmp_path):
|
||||||
|
"""Without max_entries, all resolved entries are kept."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(7):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
assert len(log.load_entries()) == 7
|
||||||
|
|
||||||
|
def test_rotation_prunes_oldest_resolved(self, tmp_path):
|
||||||
|
"""When max_entries is set and exceeded, oldest resolved entries are pruned."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 3,
|
||||||
|
})
|
||||||
|
# Resolve 5 entries; rotation should keep only the 3 most recent.
|
||||||
|
for i in range(5):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
# Confirm the OLDEST were dropped, not the newest.
|
||||||
|
dates = [e["date"] for e in entries]
|
||||||
|
assert dates == ["2026-01-03", "2026-01-04", "2026-01-05"]
|
||||||
|
|
||||||
|
def test_rotation_never_prunes_pending(self, tmp_path):
|
||||||
|
"""Pending entries (unresolved) are kept regardless of the cap."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 2,
|
||||||
|
})
|
||||||
|
# 3 resolved + 2 pending. With cap=2, only 2 resolved survive; both pending stay.
|
||||||
|
for i in range(3):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Resolved {i}.")
|
||||||
|
log.store_decision("NVDA", "2026-02-01", DECISION_BUY)
|
||||||
|
log.store_decision("NVDA", "2026-02-02", DECISION_OVERWEIGHT)
|
||||||
|
# Trigger rotation by resolving one more entry — pending entries must stay.
|
||||||
|
_resolve_entry(log, "NVDA", "2026-01-04", DECISION_BUY, "Resolved 3.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
pending = [e for e in entries if e["pending"]]
|
||||||
|
resolved = [e for e in entries if not e["pending"]]
|
||||||
|
assert len(pending) == 2, "pending entries must never be pruned"
|
||||||
|
assert len(resolved) == 2, f"expected 2 resolved after rotation, got {len(resolved)}"
|
||||||
|
|
||||||
|
def test_rotation_under_cap_is_noop(self, tmp_path):
|
||||||
|
"""No rotation when resolved count <= max_entries."""
|
||||||
|
log = TradingMemoryLog({
|
||||||
|
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||||
|
"memory_log_max_entries": 10,
|
||||||
|
})
|
||||||
|
for i in range(3):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
assert len(log.load_entries()) == 3
|
||||||
|
|
||||||
|
# Rating parsing: markdown bold and numbered list formats
|
||||||
|
|
||||||
|
def test_rating_parsed_from_bold_markdown(self, tmp_path):
|
||||||
|
"""**Rating**: Buy — markdown bold around the label must not prevent parsing."""
|
||||||
|
decision = "**Rating**: Buy\nEnter at $190."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
def test_rating_parsed_from_bold_value(self, tmp_path):
|
||||||
|
"""Rating: **Sell** — markdown bold around the value must not prevent parsing."""
|
||||||
|
decision = "Rating: **Sell**\nExit immediately."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Sell"
|
||||||
|
|
||||||
|
def test_rating_label_wins_over_prose_with_markdown(self, tmp_path):
|
||||||
|
"""Rating: **Sell** must win even when prose contains a conflicting rating word."""
|
||||||
|
decision = (
|
||||||
|
"The buy thesis is weakened by guidance.\n"
|
||||||
|
"Rating: **Sell**\n"
|
||||||
|
"Exit before earnings."
|
||||||
|
)
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Sell"
|
||||||
|
|
||||||
|
def test_rating_parsed_from_numbered_list(self, tmp_path):
|
||||||
|
"""1. Rating: Buy — numbered list prefix must not prevent parsing."""
|
||||||
|
decision = "1. Rating: Buy\nEnter at $190."
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", decision)
|
||||||
|
assert log.load_entries()[0]["rating"] == "Buy"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deferred reflection: update_with_outcome, Reflector, _fetch_returns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDeferredReflection:
|
||||||
|
|
||||||
|
# update_with_outcome
|
||||||
|
|
||||||
|
def test_update_replaces_pending_tag(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | pending]" not in text
|
||||||
|
assert "+4.2%" in text
|
||||||
|
assert "+2.1%" in text
|
||||||
|
assert "5d" in text
|
||||||
|
|
||||||
|
def test_update_appends_reflection(self, tmp_path):
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["pending"] is False
|
||||||
|
assert e["reflection"] == "Momentum confirmed."
|
||||||
|
assert e["decision"] == DECISION_BUY.strip()
|
||||||
|
|
||||||
|
def test_update_preserves_other_entries(self, tmp_path):
|
||||||
|
"""Only the matching entry is modified; all other entries remain unchanged."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.store_decision("AAPL", "2026-01-11", "Rating: Hold\nHold AAPL.")
|
||||||
|
log.store_decision("MSFT", "2026-01-12", DECISION_SELL)
|
||||||
|
log.update_with_outcome("AAPL", "2026-01-11", 0.01, -0.01, 5, "Neutral result.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 3
|
||||||
|
nvda, aapl, msft = entries
|
||||||
|
assert nvda["ticker"] == "NVDA" and nvda["pending"] is True
|
||||||
|
assert aapl["ticker"] == "AAPL" and aapl["pending"] is False
|
||||||
|
assert aapl["reflection"] == "Neutral result."
|
||||||
|
assert msft["ticker"] == "MSFT" and msft["pending"] is True
|
||||||
|
|
||||||
|
def test_update_atomic_write(self, tmp_path):
|
||||||
|
"""A pre-existing .tmp file is overwritten; the log is correctly updated."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
stale_tmp = tmp_path / "trading_memory.tmp"
|
||||||
|
stale_tmp.write_text("GARBAGE CONTENT — should be overwritten", encoding="utf-8")
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Correct.")
|
||||||
|
assert not stale_tmp.exists()
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["reflection"] == "Correct."
|
||||||
|
assert entries[0]["pending"] is False
|
||||||
|
|
||||||
|
def test_update_noop_when_no_log_path(self):
|
||||||
|
log = TradingMemoryLog(config=None)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.05, 0.02, 5, "Reflection")
|
||||||
|
|
||||||
|
def test_formatting_roundtrip_after_update(self, tmp_path):
|
||||||
|
"""All fields intact and blank line between tag and DECISION preserved after update."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
e = entries[0]
|
||||||
|
assert e["pending"] is False
|
||||||
|
assert e["decision"] == DECISION_BUY.strip()
|
||||||
|
assert e["reflection"] == "Momentum confirmed."
|
||||||
|
assert e["raw"] == "+4.2%"
|
||||||
|
assert e["alpha"] == "+2.1%"
|
||||||
|
assert e["holding"] == "5d"
|
||||||
|
raw_text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||||
|
assert "[2026-01-10 | NVDA | Buy | +4.2% | +2.1% | 5d]\n\nDECISION:" in raw_text
|
||||||
|
|
||||||
|
# Reflector.reflect_on_final_decision
|
||||||
|
|
||||||
|
def test_reflect_on_final_decision_returns_llm_output(self):
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "Directionally correct. Thesis confirmed."
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
result = reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_BUY, raw_return=0.042, alpha_return=0.021
|
||||||
|
)
|
||||||
|
assert result == "Directionally correct. Thesis confirmed."
|
||||||
|
mock_llm.invoke.assert_called_once()
|
||||||
|
|
||||||
|
def test_reflect_on_final_decision_includes_returns_in_prompt(self):
|
||||||
|
"""Return figures are present in the human message sent to the LLM."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "Incorrect call."
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_SELL, raw_return=-0.08, alpha_return=-0.05
|
||||||
|
)
|
||||||
|
messages = mock_llm.invoke.call_args[0][0]
|
||||||
|
human_content = next(content for role, content in messages if role == "human")
|
||||||
|
assert "-8.0%" in human_content
|
||||||
|
assert "-5.0%" in human_content
|
||||||
|
assert "Exit position immediately." in human_content
|
||||||
|
|
||||||
|
# TradingAgentsGraph._fetch_returns
|
||||||
|
|
||||||
|
def test_fetch_returns_valid_ticker(self):
|
||||||
|
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||||
|
spy_prices = [400.0, 402.0, 404.0, 403.0, 405.0, 406.0]
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
def _make_ticker(sym):
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||||
|
return m
|
||||||
|
mock_ticker_cls.side_effect = _make_ticker
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||||
|
assert raw is not None and alpha is not None and days is not None
|
||||||
|
assert isinstance(raw, float) and isinstance(alpha, float) and isinstance(days, int)
|
||||||
|
assert days == 5
|
||||||
|
|
||||||
|
def test_fetch_returns_too_recent(self):
|
||||||
|
"""Only 1 data point available → returns (None, None, None), no crash."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df([100.0])
|
||||||
|
mock_ticker_cls.return_value = m
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-04-19")
|
||||||
|
assert raw is None and alpha is None and days is None
|
||||||
|
|
||||||
|
def test_fetch_returns_delisted(self):
|
||||||
|
"""Empty DataFrame → returns (None, None, None), no crash."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = pd.DataFrame({"Close": []})
|
||||||
|
mock_ticker_cls.return_value = m
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "XXXXXFAKE", "2026-01-10")
|
||||||
|
assert raw is None and alpha is None and days is None
|
||||||
|
|
||||||
|
def test_fetch_returns_spy_shorter_than_stock(self):
|
||||||
|
"""SPY having fewer rows than the stock must not raise IndexError."""
|
||||||
|
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||||
|
spy_prices = [400.0, 402.0, 403.0]
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||||
|
def _make_ticker(sym):
|
||||||
|
m = MagicMock()
|
||||||
|
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||||
|
return m
|
||||||
|
mock_ticker_cls.side_effect = _make_ticker
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||||
|
assert raw is not None and alpha is not None and days is not None
|
||||||
|
assert days == 2
|
||||||
|
|
||||||
|
# TradingAgentsGraph._resolve_benchmark — picks index for alpha calc
|
||||||
|
|
||||||
|
def test_resolve_benchmark_explicit_override(self):
|
||||||
|
"""config['benchmark_ticker'] wins for every ticker."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {
|
||||||
|
"benchmark_ticker": "QQQ",
|
||||||
|
"benchmark_map": {"": "SPY", ".T": "^N225"},
|
||||||
|
}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "QQQ"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "QQQ"
|
||||||
|
|
||||||
|
def test_resolve_benchmark_suffix_map(self):
|
||||||
|
"""Known suffixes route to their regional index."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {
|
||||||
|
"benchmark_ticker": None,
|
||||||
|
"benchmark_map": {
|
||||||
|
".T": "^N225", ".HK": "^HSI", ".NS": "^NSEI",
|
||||||
|
".L": "^FTSE", ".TO": "^GSPTSE", ".AX": "^AXJO",
|
||||||
|
".BO": "^BSESN", "": "SPY",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "^N225"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "0700.HK") == "^HSI"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "RELIANCE.NS") == "^NSEI"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AZN.L") == "^FTSE"
|
||||||
|
|
||||||
|
def test_resolve_benchmark_us_ticker_defaults_to_spy(self):
|
||||||
|
"""US tickers (no dotted suffix) take the empty-suffix entry."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {
|
||||||
|
"benchmark_ticker": None,
|
||||||
|
"benchmark_map": {"": "SPY", ".T": "^N225"},
|
||||||
|
}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "SPY"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AAPL") == "SPY"
|
||||||
|
|
||||||
|
def test_resolve_benchmark_unknown_suffix_falls_back(self):
|
||||||
|
"""Unrecognised suffix (BRK.B, FAKE.XX) falls back to SPY."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {
|
||||||
|
"benchmark_ticker": None,
|
||||||
|
"benchmark_map": {"": "SPY", ".T": "^N225"},
|
||||||
|
}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "FAKE.XX") == "SPY"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "BRK.B") == "SPY"
|
||||||
|
|
||||||
|
def test_resolve_benchmark_case_insensitive(self):
|
||||||
|
"""Suffix matching is case-insensitive so 7203.t resolves like 7203.T."""
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {
|
||||||
|
"benchmark_ticker": None,
|
||||||
|
"benchmark_map": {".T": "^N225", "": "SPY"},
|
||||||
|
}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.t") == "^N225"
|
||||||
|
|
||||||
|
def test_reflector_includes_benchmark_in_label(self):
|
||||||
|
"""benchmark_name appears in the prompt label, not 'SPY' hardcoded."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "Directionally correct."
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_BUY,
|
||||||
|
raw_return=0.05,
|
||||||
|
alpha_return=0.02,
|
||||||
|
benchmark_name="^N225",
|
||||||
|
)
|
||||||
|
messages = mock_llm.invoke.call_args[0][0]
|
||||||
|
human_content = next(content for role, content in messages if role == "human")
|
||||||
|
assert "Alpha vs ^N225:" in human_content
|
||||||
|
assert "Alpha vs SPY:" not in human_content
|
||||||
|
|
||||||
|
def test_reflector_defaults_to_spy_for_unupdated_callers(self):
|
||||||
|
"""Default benchmark_name keeps the SPY label for legacy callers."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.invoke.return_value.content = "ok"
|
||||||
|
reflector = Reflector(mock_llm)
|
||||||
|
reflector.reflect_on_final_decision(
|
||||||
|
final_decision=DECISION_BUY,
|
||||||
|
raw_return=0.05,
|
||||||
|
alpha_return=0.02,
|
||||||
|
)
|
||||||
|
messages = mock_llm.invoke.call_args[0][0]
|
||||||
|
human_content = next(content for role, content in messages if role == "human")
|
||||||
|
assert "Alpha vs SPY:" in human_content
|
||||||
|
|
||||||
|
# TradingAgentsGraph._resolve_pending_entries
|
||||||
|
|
||||||
|
def test_resolve_skips_other_tickers(self, tmp_path):
|
||||||
|
"""Pending AAPL entry is not resolved when the run is for NVDA."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("AAPL", "2026-01-10", DECISION_BUY)
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.memory_log = log
|
||||||
|
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||||
|
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||||
|
mock_graph._fetch_returns.assert_not_called()
|
||||||
|
assert len(log.get_pending_entries()) == 1
|
||||||
|
|
||||||
|
def test_resolve_marks_entry_completed(self, tmp_path):
|
||||||
|
"""After resolve, get_pending_entries() is empty and the entry has a REFLECTION."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
mock_reflector = MagicMock()
|
||||||
|
mock_reflector.reflect_on_final_decision.return_value = "Momentum confirmed."
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.memory_log = log
|
||||||
|
mock_graph.reflector = mock_reflector
|
||||||
|
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||||
|
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||||
|
assert log.get_pending_entries() == []
|
||||||
|
entries = log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["pending"] is False
|
||||||
|
assert entries[0]["reflection"] == "Momentum confirmed."
|
||||||
|
assert "+5.0%" in entries[0]["raw"]
|
||||||
|
assert "+2.0%" in entries[0]["alpha"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Portfolio Manager injection: past_context in state and prompt
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPortfolioManagerInjection:
|
||||||
|
|
||||||
|
# past_context in initial state
|
||||||
|
|
||||||
|
def test_past_context_in_initial_state(self):
|
||||||
|
propagator = Propagator()
|
||||||
|
state = propagator.create_initial_state("NVDA", "2026-01-10", past_context="some context")
|
||||||
|
assert "past_context" in state
|
||||||
|
assert state["past_context"] == "some context"
|
||||||
|
|
||||||
|
def test_past_context_defaults_to_empty(self):
|
||||||
|
propagator = Propagator()
|
||||||
|
state = propagator.create_initial_state("NVDA", "2026-01-10")
|
||||||
|
assert state["past_context"] == ""
|
||||||
|
|
||||||
|
# PM prompt
|
||||||
|
|
||||||
|
def test_pm_prompt_includes_past_context(self):
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_pm_llm(captured)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
state = _make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call.")
|
||||||
|
pm_node(state)
|
||||||
|
assert "Lessons from prior decisions and outcomes" in captured["prompt"]
|
||||||
|
assert "Great call." in captured["prompt"]
|
||||||
|
|
||||||
|
def test_pm_no_past_context_no_section(self):
|
||||||
|
"""PM prompt omits the lessons section entirely when past_context is empty."""
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_pm_llm(captured)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
state = _make_pm_state(past_context="")
|
||||||
|
pm_node(state)
|
||||||
|
assert "Lessons from prior decisions" not in captured["prompt"]
|
||||||
|
|
||||||
|
def test_pm_returns_rendered_markdown_with_rating(self):
|
||||||
|
"""The structured PortfolioDecision is rendered to markdown that
|
||||||
|
downstream consumers (memory log, signal processor, CLI display)
|
||||||
|
can parse without any extra LLM call."""
|
||||||
|
captured = {}
|
||||||
|
decision = PortfolioDecision(
|
||||||
|
rating=PortfolioRating.OVERWEIGHT,
|
||||||
|
executive_summary="Build position gradually over the next two weeks.",
|
||||||
|
investment_thesis="AI capex cycle remains intact; institutional flows constructive.",
|
||||||
|
price_target=215.0,
|
||||||
|
time_horizon="3-6 months",
|
||||||
|
)
|
||||||
|
llm = _structured_pm_llm(captured, decision)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
result = pm_node(_make_pm_state())
|
||||||
|
md = result["final_trade_decision"]
|
||||||
|
assert "**Rating**: Overweight" in md
|
||||||
|
assert "**Executive Summary**: Build position gradually" in md
|
||||||
|
assert "**Investment Thesis**: AI capex cycle" in md
|
||||||
|
assert "**Price Target**: 215.0" in md
|
||||||
|
assert "**Time Horizon**: 3-6 months" in md
|
||||||
|
|
||||||
|
def test_pm_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
"""If a provider does not support with_structured_output, the agent
|
||||||
|
falls back to a plain invoke and returns whatever prose the model
|
||||||
|
produced, so the pipeline never blocks."""
|
||||||
|
plain_response = "**Rating**: Sell\n\nExit ahead of guidance."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
pm_node = create_portfolio_manager(llm)
|
||||||
|
result = pm_node(_make_pm_state())
|
||||||
|
assert result["final_trade_decision"] == plain_response
|
||||||
|
|
||||||
|
# get_past_context ordering and limits
|
||||||
|
|
||||||
|
def test_same_ticker_prioritised(self, tmp_path):
|
||||||
|
"""Same-ticker entries in same-ticker section; cross-ticker entries in cross-ticker section."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_resolve_entry(log, "NVDA", "2026-01-05", DECISION_BUY, "Momentum confirmed.")
|
||||||
|
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued.")
|
||||||
|
result = log.get_past_context("NVDA")
|
||||||
|
assert "Past analyses of NVDA" in result
|
||||||
|
assert "Recent cross-ticker lessons" in result
|
||||||
|
same_block, cross_block = result.split("Recent cross-ticker lessons")
|
||||||
|
assert "NVDA" in same_block
|
||||||
|
assert "AAPL" in cross_block
|
||||||
|
|
||||||
|
def test_cross_ticker_reflection_only(self, tmp_path):
|
||||||
|
"""Cross-ticker entries show only the REFLECTION text, not the full DECISION."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued correction.")
|
||||||
|
result = log.get_past_context("NVDA")
|
||||||
|
assert "Overvalued correction." in result
|
||||||
|
assert "Exit position immediately." not in result
|
||||||
|
|
||||||
|
def test_n_same_limit_respected(self, tmp_path):
|
||||||
|
"""More than 5 same-ticker completed entries → only 5 injected."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
for i in range(7):
|
||||||
|
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||||
|
result = log.get_past_context("NVDA", n_same=5)
|
||||||
|
lessons_present = sum(1 for i in range(7) if f"Lesson {i}." in result)
|
||||||
|
assert lessons_present == 5
|
||||||
|
|
||||||
|
def test_n_cross_limit_respected(self, tmp_path):
|
||||||
|
"""More than 3 cross-ticker completed entries → only 3 injected."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
tickers = ["AAPL", "MSFT", "TSLA", "AMZN", "GOOG"]
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
_resolve_entry(log, ticker, f"2026-01-{i+1:02d}", DECISION_BUY, f"{ticker} lesson.")
|
||||||
|
result = log.get_past_context("NVDA", n_cross=3)
|
||||||
|
cross_count = sum(result.count(f"{t} lesson.") for t in tickers)
|
||||||
|
assert cross_count == 3
|
||||||
|
|
||||||
|
# Full A→B→C integration cycle
|
||||||
|
|
||||||
|
def test_full_cycle_store_resolve_inject(self, tmp_path):
|
||||||
|
"""store pending → resolve with outcome → past_context non-empty for PM."""
|
||||||
|
log = make_log(tmp_path)
|
||||||
|
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||||
|
assert len(log.get_pending_entries()) == 1
|
||||||
|
assert log.get_past_context("NVDA") == ""
|
||||||
|
log.update_with_outcome("NVDA", "2026-01-05", 0.05, 0.02, 5, "Correct call.")
|
||||||
|
assert log.get_pending_entries() == []
|
||||||
|
past_ctx = log.get_past_context("NVDA")
|
||||||
|
assert past_ctx != ""
|
||||||
|
assert "NVDA" in past_ctx
|
||||||
|
assert "Correct call." in past_ctx
|
||||||
|
assert "DECISION:" in past_ctx
|
||||||
|
assert "REFLECTION:" in past_ctx
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy removal: BM25 / FinancialSituationMemory fully gone
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLegacyRemoval:
|
||||||
|
|
||||||
|
def test_financial_situation_memory_removed(self):
|
||||||
|
"""FinancialSituationMemory must not be importable from the memory module."""
|
||||||
|
import tradingagents.agents.utils.memory as m
|
||||||
|
assert not hasattr(m, "FinancialSituationMemory")
|
||||||
|
|
||||||
|
def test_bm25_not_imported(self):
|
||||||
|
"""rank_bm25 must not be present in the memory module namespace."""
|
||||||
|
import tradingagents.agents.utils.memory as m
|
||||||
|
assert not hasattr(m, "BM25Okapi")
|
||||||
|
|
||||||
|
def test_reflect_and_remember_removed(self):
|
||||||
|
"""TradingAgentsGraph must not expose reflect_and_remember."""
|
||||||
|
assert not hasattr(TradingAgentsGraph, "reflect_and_remember")
|
||||||
|
|
||||||
|
def test_portfolio_manager_no_memory_param(self):
|
||||||
|
"""create_portfolio_manager accepts only llm; passing memory= raises TypeError."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
create_portfolio_manager(mock_llm)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
create_portfolio_manager(mock_llm, memory=MagicMock())
|
||||||
|
|
||||||
|
def test_full_pipeline_no_regression(self, tmp_path):
|
||||||
|
"""propagate() completes and stores the decision after the redesign."""
|
||||||
|
import functools
|
||||||
|
|
||||||
|
fake_state = {
|
||||||
|
"final_trade_decision": "Rating: Buy\nBuy NVDA.",
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"trade_date": "2026-01-10",
|
||||||
|
"market_report": "",
|
||||||
|
"sentiment_report": "",
|
||||||
|
"news_report": "",
|
||||||
|
"fundamentals_report": "",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"bull_history": "", "bear_history": "", "history": "",
|
||||||
|
"current_response": "", "judge_decision": "",
|
||||||
|
},
|
||||||
|
"investment_plan": "",
|
||||||
|
"trader_investment_plan": "",
|
||||||
|
"risk_debate_state": {
|
||||||
|
"aggressive_history": "", "conservative_history": "",
|
||||||
|
"neutral_history": "", "history": "", "judge_decision": "",
|
||||||
|
"current_aggressive_response": "", "current_conservative_response": "",
|
||||||
|
"current_neutral_response": "", "count": 1, "latest_speaker": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_graph.memory_log = TradingMemoryLog({"memory_log_path": str(tmp_path / "mem.md")})
|
||||||
|
mock_graph.log_states_dict = {}
|
||||||
|
mock_graph.debug = False
|
||||||
|
mock_graph.config = {"results_dir": str(tmp_path)}
|
||||||
|
mock_graph.graph.invoke.return_value = fake_state
|
||||||
|
mock_graph.propagator.create_initial_state.return_value = fake_state
|
||||||
|
mock_graph.propagator.get_graph_args.return_value = {}
|
||||||
|
mock_graph.signal_processor.process_signal.return_value = "Buy"
|
||||||
|
# Bind the real _run_graph so propagate's call to self._run_graph executes
|
||||||
|
# the actual write path instead of the auto-MagicMock.
|
||||||
|
mock_graph._run_graph = functools.partial(
|
||||||
|
TradingAgentsGraph._run_graph, mock_graph
|
||||||
|
)
|
||||||
|
TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
|
||||||
|
entries = mock_graph.memory_log.load_entries()
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["ticker"] == "NVDA"
|
||||||
|
assert entries[0]["pending"] is True
|
||||||
73
tests/test_minimax.py
Normal file
73
tests/test_minimax.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Tests for MinimaxChatOpenAI quirks.
|
||||||
|
|
||||||
|
Verifies the subclass injects ``reasoning_split=True`` into outgoing
|
||||||
|
requests so M2.x reasoning models put their <think> block into
|
||||||
|
``reasoning_details`` instead of polluting ``message.content``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.openai_client import MinimaxChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def _client(model: str = "MiniMax-M2.7"):
|
||||||
|
os.environ.setdefault("MINIMAX_API_KEY", "placeholder")
|
||||||
|
return MinimaxChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
api_key="placeholder",
|
||||||
|
base_url="https://api.minimax.io/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMinimaxReasoningSplit:
|
||||||
|
def test_request_payload_sets_reasoning_split(self):
|
||||||
|
payload = _client()._get_request_payload([HumanMessage(content="hi")])
|
||||||
|
assert payload.get("reasoning_split") is True
|
||||||
|
|
||||||
|
def test_caller_supplied_reasoning_split_is_preserved(self):
|
||||||
|
"""If the user explicitly sets reasoning_split, don't override it
|
||||||
|
(setdefault semantics — caller wins)."""
|
||||||
|
client = _client()
|
||||||
|
payload = client._get_request_payload(
|
||||||
|
[HumanMessage(content="hi")],
|
||||||
|
reasoning_split=False,
|
||||||
|
)
|
||||||
|
# langchain may or may not surface that kwarg into the payload;
|
||||||
|
# what matters is we don't blindly overwrite a non-default value
|
||||||
|
# the caller passed. setdefault leaves an existing value alone.
|
||||||
|
assert payload.get("reasoning_split") in (False, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMinimaxStructuredOutputDispatch:
|
||||||
|
"""M2.x models route through the capability table — tool_choice is
|
||||||
|
suppressed but the schema is still bound as a tool."""
|
||||||
|
|
||||||
|
class _Pick(BaseModel):
|
||||||
|
action: str
|
||||||
|
|
||||||
|
def _bound_kwargs(self, runnable):
|
||||||
|
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
|
||||||
|
return getattr(first, "kwargs", {})
|
||||||
|
|
||||||
|
def test_m2_7_suppresses_tool_choice(self):
|
||||||
|
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
|
||||||
|
kwargs = self._bound_kwargs(bound)
|
||||||
|
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
|
||||||
|
|
||||||
|
def test_m2_7_highspeed_suppresses_tool_choice(self):
|
||||||
|
bound = _client("MiniMax-M2.7-highspeed").with_structured_output(self._Pick)
|
||||||
|
kwargs = self._bound_kwargs(bound)
|
||||||
|
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
|
||||||
|
|
||||||
|
def test_schema_still_bound_as_tool(self):
|
||||||
|
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
|
||||||
|
tools = self._bound_kwargs(bound).get("tools", [])
|
||||||
|
assert any(
|
||||||
|
t.get("function", {}).get("name") == "_Pick" for t in tools
|
||||||
|
), f"schema not bound: {tools}"
|
||||||
55
tests/test_model_validation.py
Normal file
55
tests/test_model_validation.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLLMClient(BaseLLMClient):
|
||||||
|
def __init__(self, provider: str, model: str):
|
||||||
|
self.provider = provider
|
||||||
|
super().__init__(model)
|
||||||
|
|
||||||
|
def get_llm(self):
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
return object()
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
return validate_model(self.provider, self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class ModelValidationTests(unittest.TestCase):
|
||||||
|
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||||
|
for provider, models in get_known_models().items():
|
||||||
|
if provider in ("ollama", "openrouter"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
with self.subTest(provider=provider, model=model):
|
||||||
|
self.assertTrue(validate_model(provider, model))
|
||||||
|
|
||||||
|
def test_unknown_model_emits_warning_for_strict_provider(self):
|
||||||
|
client = DummyLLMClient("openai", "not-a-real-openai-model")
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
self.assertEqual(len(caught), 1)
|
||||||
|
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
||||||
|
self.assertIn("openai", str(caught[0].message))
|
||||||
|
|
||||||
|
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
|
||||||
|
for provider in ("openrouter", "ollama"):
|
||||||
|
client = DummyLLMClient(provider, "custom-model-name")
|
||||||
|
|
||||||
|
with self.subTest(provider=provider):
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
self.assertEqual(caught, [])
|
||||||
167
tests/test_ollama_base_url.py
Normal file
167
tests/test_ollama_base_url.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""Tests for OLLAMA_BASE_URL env-var override across CLI and client paths."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---- openai_client side: _resolve_provider_base_url -----------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_client():
|
||||||
|
import tradingagents.llm_clients.openai_client as mod
|
||||||
|
return importlib.reload(mod)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_returns_env_when_set(monkeypatch):
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||||
|
"""Setting the env AFTER module import must still take effect."""
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
mod = _reload_client()
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
||||||
|
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
|
||||||
|
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_get_llm_picks_up_env(monkeypatch):
|
||||||
|
"""End-to-end: OllamaClient.get_llm() respects OLLAMA_BASE_URL."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://my-ollama:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
client = mod.OpenAIClient(model="llama3.1", provider="ollama")
|
||||||
|
llm = client.get_llm()
|
||||||
|
assert "my-ollama" in str(llm.openai_api_base)
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_base_url_overrides_env(monkeypatch):
|
||||||
|
"""An explicit base_url passed to the client wins over the env var."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://env-set:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
client = mod.OpenAIClient(
|
||||||
|
model="llama3.1",
|
||||||
|
provider="ollama",
|
||||||
|
base_url="http://explicit:11434/v1",
|
||||||
|
)
|
||||||
|
llm = client.get_llm()
|
||||||
|
assert "explicit" in str(llm.openai_api_base)
|
||||||
|
assert "env-set" not in str(llm.openai_api_base)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- cli.utils side: select_llm_provider dropdown -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_dropdown_uses_env(monkeypatch):
|
||||||
|
"""The Ollama entry in the CLI dropdown must reflect OLLAMA_BASE_URL."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://cli-remote:11434/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
# Reach inside the function via the same env-read it does at call time
|
||||||
|
ollama_url = (
|
||||||
|
__import__("os").environ.get("OLLAMA_BASE_URL")
|
||||||
|
or "http://localhost:11434/v1"
|
||||||
|
)
|
||||||
|
assert ollama_url == "http://cli-remote:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_dropdown_default_when_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
ollama_url = (
|
||||||
|
__import__("os").environ.get("OLLAMA_BASE_URL")
|
||||||
|
or "http://localhost:11434/v1"
|
||||||
|
)
|
||||||
|
assert ollama_url == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- confirm_ollama_endpoint UX -------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_shows_default(monkeypatch, capsys):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://localhost:11434/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "http://localhost:11434/v1" in out
|
||||||
|
assert "OLLAMA_BASE_URL" not in out # not from env
|
||||||
|
assert "Note" not in out # no warnings for the canonical default
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_marks_env_origin(monkeypatch, capsys):
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host:11434/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://remote-host:11434/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "http://remote-host:11434/v1" in out
|
||||||
|
assert "OLLAMA_BASE_URL" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_warns_on_missing_scheme(monkeypatch, capsys):
|
||||||
|
"""If user sets OLLAMA_BASE_URL=0.0.0.128, advise on the expected shape."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "0.0.0.128")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("0.0.0.128")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "missing a scheme" in out
|
||||||
|
assert "http://<host>:11434/v1" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_warns_on_non_default_port_remote(monkeypatch, capsys):
|
||||||
|
"""A remote host with no :11434 gets a soft hint about port mismatch."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://remote-host/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "port 11434" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_quiet_on_local_no_port(monkeypatch, capsys):
|
||||||
|
"""Local host without port shouldn't trigger the remote-port hint."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://localhost/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Note" not in out # localhost is fine without explicit port
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_model_labels_no_local_suffix():
|
||||||
|
"""Labels should no longer claim '(local)' since the endpoint is dynamic."""
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
for mode in ("quick", "deep"):
|
||||||
|
labels = [label for label, _ in get_model_options("ollama", mode)]
|
||||||
|
assert all("local" not in label for label in labels), labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_offers_custom_model_id():
|
||||||
|
"""Ollama users with custom-pulled models can pick 'Custom model ID'."""
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
for mode in ("quick", "deep"):
|
||||||
|
entries = get_model_options("ollama", mode)
|
||||||
|
values = [v for _, v in entries]
|
||||||
|
assert "custom" in values, f"Ollama {mode!r} missing 'custom' option: {entries}"
|
||||||
|
# Custom option is last so it doesn't push the curated defaults off-screen
|
||||||
|
assert values[-1] == "custom", f"'custom' should be last entry: {values}"
|
||||||
52
tests/test_safe_ticker_component.py
Normal file
52
tests/test_safe_ticker_component.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Tests for the ticker path-component validator that blocks directory traversal."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSafeTickerComponent(unittest.TestCase):
|
||||||
|
def test_accepts_common_ticker_formats(self):
|
||||||
|
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
|
||||||
|
self.assertEqual(safe_ticker_component(ticker), ticker)
|
||||||
|
|
||||||
|
def test_rejects_path_separators(self):
|
||||||
|
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_null_byte_and_whitespace(self):
|
||||||
|
for bad in ("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_empty_or_non_string(self):
|
||||||
|
for bad in ("", None, 123, b"AAPL"):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_rejects_overlong_input(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component("A" * 33)
|
||||||
|
|
||||||
|
def test_rejects_dot_only_values(self):
|
||||||
|
# '.' and '..' pass the regex but traverse when used as a path
|
||||||
|
# component (e.g. ``Path(results_dir) / ticker / "logs"``).
|
||||||
|
for bad in (".", "..", "...", "...."):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
safe_ticker_component(bad)
|
||||||
|
|
||||||
|
def test_traversal_string_does_not_escape_join(self):
|
||||||
|
"""Sanity: sanitized values stay within base when joined."""
|
||||||
|
base = os.path.realpath("/tmp/cache")
|
||||||
|
ticker = safe_ticker_component("AAPL")
|
||||||
|
joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
|
||||||
|
self.assertTrue(joined.startswith(base + os.sep))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
90
tests/test_signal_processing.py
Normal file
90
tests/test_signal_processing.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for the shared rating heuristic and the SignalProcessor adapter.
|
||||||
|
|
||||||
|
The Portfolio Manager produces a typed PortfolioDecision via structured
|
||||||
|
output and renders it to markdown that always contains a ``**Rating**: X``
|
||||||
|
header. The deterministic heuristic in ``tradingagents.agents.utils.rating``
|
||||||
|
is therefore sufficient to extract the rating downstream — no second LLM
|
||||||
|
call is needed — and SignalProcessor is now a thin adapter that delegates
|
||||||
|
to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||||
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Heuristic parser
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestParseRating:
|
||||||
|
def test_explicit_label_buy(self):
|
||||||
|
assert parse_rating("Rating: Buy\nReasoning here.") == "Buy"
|
||||||
|
|
||||||
|
def test_explicit_label_overweight(self):
|
||||||
|
assert parse_rating("Rating: Overweight\nDetails.") == "Overweight"
|
||||||
|
|
||||||
|
def test_explicit_label_with_markdown_bold_value(self):
|
||||||
|
# Regression: Rating: **Sell** — markdown around the value.
|
||||||
|
assert parse_rating("Rating: **Sell**\nExit immediately.") == "Sell"
|
||||||
|
|
||||||
|
def test_explicit_label_with_markdown_bold_label(self):
|
||||||
|
assert parse_rating("**Rating**: Underweight\nTrim exposure.") == "Underweight"
|
||||||
|
|
||||||
|
def test_rendered_pm_markdown_shape(self):
|
||||||
|
# The exact shape produced by render_pm_decision must always parse.
|
||||||
|
text = (
|
||||||
|
"**Rating**: Buy\n\n"
|
||||||
|
"**Executive Summary**: Enter at $189-192, 6% portfolio cap.\n\n"
|
||||||
|
"**Investment Thesis**: AI capex cycle intact; institutional flows constructive."
|
||||||
|
)
|
||||||
|
assert parse_rating(text) == "Buy"
|
||||||
|
|
||||||
|
def test_explicit_label_wins_over_prose_with_markdown(self):
|
||||||
|
text = (
|
||||||
|
"The buy thesis is weakened by guidance.\n"
|
||||||
|
"Rating: **Sell**\n"
|
||||||
|
"Exit before earnings."
|
||||||
|
)
|
||||||
|
assert parse_rating(text) == "Sell"
|
||||||
|
|
||||||
|
def test_no_rating_returns_default(self):
|
||||||
|
assert parse_rating("No clear directional signal at this time.") == "Hold"
|
||||||
|
|
||||||
|
def test_no_rating_custom_default(self):
|
||||||
|
assert parse_rating("Plain prose.", default="Underweight") == "Underweight"
|
||||||
|
|
||||||
|
def test_all_five_tiers_recognised(self):
|
||||||
|
for r in RATINGS_5_TIER:
|
||||||
|
assert parse_rating(f"Rating: {r}") == r
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SignalProcessor: thin adapter over the heuristic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSignalProcessor:
|
||||||
|
def test_returns_rating_from_pm_markdown(self):
|
||||||
|
sp = SignalProcessor()
|
||||||
|
md = "**Rating**: Overweight\n\n**Executive Summary**: Build gradually."
|
||||||
|
assert sp.process_signal(md) == "Overweight"
|
||||||
|
|
||||||
|
def test_makes_no_llm_calls(self):
|
||||||
|
"""SignalProcessor must not invoke the LLM it was constructed with —
|
||||||
|
the rating is parseable from the rendered PM markdown directly."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
sp = SignalProcessor(llm)
|
||||||
|
sp.process_signal("Rating: Buy\nDetails.")
|
||||||
|
llm.invoke.assert_not_called()
|
||||||
|
llm.with_structured_output.assert_not_called()
|
||||||
|
|
||||||
|
def test_default_when_no_rating_present(self):
|
||||||
|
sp = SignalProcessor()
|
||||||
|
assert sp.process_signal("Plain prose without a recommendation.") == "Hold"
|
||||||
232
tests/test_structured_agents.py
Normal file
232
tests/test_structured_agents.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""Tests for structured-output agents (Trader and Research Manager).
|
||||||
|
|
||||||
|
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
||||||
|
(which exercises the full memory-log → PM injection cycle). This file
|
||||||
|
covers the parallel schemas, render functions, and graceful-fallback
|
||||||
|
behavior we added for the Trader and Research Manager so all three
|
||||||
|
decision-making agents share the same shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
|
from tradingagents.agents.schemas import (
|
||||||
|
PortfolioRating,
|
||||||
|
ResearchPlan,
|
||||||
|
TraderAction,
|
||||||
|
TraderProposal,
|
||||||
|
render_research_plan,
|
||||||
|
render_trader_proposal,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Render functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderTraderProposal:
|
||||||
|
def test_minimal_required_fields(self):
|
||||||
|
p = TraderProposal(action=TraderAction.HOLD, reasoning="Balanced setup; no edge.")
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "**Action**: Hold" in md
|
||||||
|
assert "**Reasoning**: Balanced setup; no edge." in md
|
||||||
|
# The trailing FINAL TRANSACTION PROPOSAL line is preserved for the
|
||||||
|
# analyst stop-signal text and any external code that greps for it.
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in md
|
||||||
|
|
||||||
|
def test_optional_fields_included_when_present(self):
|
||||||
|
p = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="Strong technicals + fundamentals.",
|
||||||
|
entry_price=189.5,
|
||||||
|
stop_loss=178.0,
|
||||||
|
position_sizing="6% of portfolio",
|
||||||
|
)
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "**Action**: Buy" in md
|
||||||
|
assert "**Entry Price**: 189.5" in md
|
||||||
|
assert "**Stop Loss**: 178.0" in md
|
||||||
|
assert "**Position Sizing**: 6% of portfolio" in md
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in md
|
||||||
|
|
||||||
|
def test_optional_fields_omitted_when_absent(self):
|
||||||
|
p = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.")
|
||||||
|
md = render_trader_proposal(p)
|
||||||
|
assert "Entry Price" not in md
|
||||||
|
assert "Stop Loss" not in md
|
||||||
|
assert "Position Sizing" not in md
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **SELL**" in md
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderResearchPlan:
|
||||||
|
def test_required_fields(self):
|
||||||
|
p = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.OVERWEIGHT,
|
||||||
|
rationale="Bull case carried; tailwinds intact.",
|
||||||
|
strategic_actions="Build position over two weeks; cap at 5%.",
|
||||||
|
)
|
||||||
|
md = render_research_plan(p)
|
||||||
|
assert "**Recommendation**: Overweight" in md
|
||||||
|
assert "**Rationale**: Bull case carried" in md
|
||||||
|
assert "**Strategic Actions**: Build position" in md
|
||||||
|
|
||||||
|
def test_all_5_tier_ratings_render(self):
|
||||||
|
for rating in PortfolioRating:
|
||||||
|
p = ResearchPlan(
|
||||||
|
recommendation=rating,
|
||||||
|
rationale="r",
|
||||||
|
strategic_actions="s",
|
||||||
|
)
|
||||||
|
md = render_research_plan(p)
|
||||||
|
assert f"**Recommendation**: {rating.value}" in md
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trader agent: structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trader_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None):
|
||||||
|
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||||
|
prompt and returns a real TraderProposal so render_trader_proposal works.
|
||||||
|
"""
|
||||||
|
if proposal is None:
|
||||||
|
proposal = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="Strong setup.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or proposal
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTraderAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
proposal = TraderProposal(
|
||||||
|
action=TraderAction.BUY,
|
||||||
|
reasoning="AI capex cycle intact; institutional flows constructive.",
|
||||||
|
entry_price=189.5,
|
||||||
|
stop_loss=178.0,
|
||||||
|
position_sizing="6% of portfolio",
|
||||||
|
)
|
||||||
|
llm = _structured_trader_llm(captured, proposal)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
result = trader(_make_trader_state())
|
||||||
|
plan = result["trader_investment_plan"]
|
||||||
|
assert "**Action**: Buy" in plan
|
||||||
|
assert "**Entry Price**: 189.5" in plan
|
||||||
|
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in plan
|
||||||
|
# The same rendered markdown is also added to messages for downstream agents.
|
||||||
|
assert plan in result["messages"][0].content
|
||||||
|
|
||||||
|
def test_prompt_includes_investment_plan(self):
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_trader_llm(captured)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
trader(_make_trader_state())
|
||||||
|
# The investment plan is in the user message of the captured prompt.
|
||||||
|
prompt = captured["prompt"]
|
||||||
|
assert any("Proposed Investment Plan" in m["content"] for m in prompt)
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain_response = (
|
||||||
|
"**Action**: Sell\n\nGuidance cut hits margins.\n\n"
|
||||||
|
"FINAL TRANSACTION PROPOSAL: **SELL**"
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
trader = create_trader(llm)
|
||||||
|
result = trader(_make_trader_state())
|
||||||
|
assert result["trader_investment_plan"] == plain_response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Research Manager agent: structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rm_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"history": "Bull and bear arguments here.",
|
||||||
|
"bull_history": "Bull says...",
|
||||||
|
"bear_history": "Bear says...",
|
||||||
|
"current_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"count": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None):
|
||||||
|
if plan is None:
|
||||||
|
plan = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.HOLD,
|
||||||
|
rationale="Balanced view across both sides.",
|
||||||
|
strategic_actions="Hold current position; reassess after earnings.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or plan
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestResearchManagerAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
plan = ResearchPlan(
|
||||||
|
recommendation=PortfolioRating.OVERWEIGHT,
|
||||||
|
rationale="Bull case is stronger; AI tailwind intact.",
|
||||||
|
strategic_actions="Build position gradually over two weeks.",
|
||||||
|
)
|
||||||
|
llm = _structured_rm_llm(captured, plan)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
result = rm(_make_rm_state())
|
||||||
|
ip = result["investment_plan"]
|
||||||
|
assert "**Recommendation**: Overweight" in ip
|
||||||
|
assert "**Rationale**: Bull case" in ip
|
||||||
|
assert "**Strategic Actions**: Build position" in ip
|
||||||
|
|
||||||
|
def test_prompt_uses_5_tier_rating_scale(self):
|
||||||
|
"""The RM prompt must list all five tiers so the schema enum matches user expectations."""
|
||||||
|
captured = {}
|
||||||
|
llm = _structured_rm_llm(captured)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
rm(_make_rm_state())
|
||||||
|
prompt = captured["prompt"]
|
||||||
|
for tier in ("Buy", "Overweight", "Hold", "Underweight", "Sell"):
|
||||||
|
assert f"**{tier}**" in prompt, f"missing {tier} in prompt"
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||||
|
rm = create_research_manager(llm)
|
||||||
|
result = rm(_make_rm_state())
|
||||||
|
assert result["investment_plan"] == plain_response
|
||||||
21
tests/test_ticker_symbol_handling.py
Normal file
21
tests/test_ticker_symbol_handling.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cli.utils import normalize_ticker_symbol
|
||||||
|
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TickerSymbolHandlingTests(unittest.TestCase):
|
||||||
|
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
||||||
|
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
||||||
|
|
||||||
|
def test_build_instrument_context_mentions_exact_symbol(self):
|
||||||
|
context = build_instrument_context("7203.T")
|
||||||
|
self.assertIn("7203.T", context)
|
||||||
|
self.assertIn("exchange suffix", context)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
38
tradingagents/__init__.py
Normal file
38
tradingagents/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
||||||
|
# (and every llm_clients consumer) sees the user's keys regardless of
|
||||||
|
# which entry point started the process. find_dotenv(usecwd=True) walks
|
||||||
|
# from the CWD, so the installed `tradingagents` console script picks up
|
||||||
|
# the project's .env instead of stepping up from site-packages.
|
||||||
|
# load_dotenv defaults to override=False, so it never clobbers values
|
||||||
|
# the caller has already exported.
|
||||||
|
try:
|
||||||
|
from dotenv import find_dotenv, load_dotenv
|
||||||
|
|
||||||
|
load_dotenv(find_dotenv(usecwd=True))
|
||||||
|
load_dotenv(find_dotenv(".env.enterprise", usecwd=True), override=False)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# langchain-core 1.3.3 calls surface_langchain_deprecation_warnings() in
|
||||||
|
# its own __init__, which prepends default-action filters for its
|
||||||
|
# subclassed warning categories. To suppress a specific warning we must
|
||||||
|
# install our filter AFTER langchain-core has installed its own, so import
|
||||||
|
# it first. The package is a guaranteed transitive dep via langgraph.
|
||||||
|
try:
|
||||||
|
import langchain_core # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
||||||
|
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
||||||
|
# warning from langchain-core 1.3.3 on every interpreter start. The fix
|
||||||
|
# is already merged upstream (langchain-ai/langgraph#7743, 2026-05-08)
|
||||||
|
# and will arrive in the next langgraph-checkpoint release. Remove this
|
||||||
|
# block (and the langchain_core preload above) when we bump past it.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"The default value of `allowed_objects`.*",
|
||||||
|
category=PendingDeprecationWarning,
|
||||||
|
)
|
||||||
@@ -1,27 +1,27 @@
|
|||||||
from .utils.agent_utils import Toolkit, create_msg_delete
|
from .utils.agent_utils import create_msg_delete
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
from .utils.memory import FinancialSituationMemory
|
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
from .analysts.news_analyst import create_news_analyst
|
from .analysts.news_analyst import create_news_analyst
|
||||||
from .analysts.social_media_analyst import create_social_media_analyst
|
from .analysts.sentiment_analyst import (
|
||||||
|
create_sentiment_analyst,
|
||||||
|
create_social_media_analyst, # deprecated alias kept for back-compat
|
||||||
|
)
|
||||||
|
|
||||||
from .researchers.bear_researcher import create_bear_researcher
|
from .researchers.bear_researcher import create_bear_researcher
|
||||||
from .researchers.bull_researcher import create_bull_researcher
|
from .researchers.bull_researcher import create_bull_researcher
|
||||||
|
|
||||||
from .risk_mgmt.aggresive_debator import create_risky_debator
|
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
||||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
from .risk_mgmt.conservative_debator import create_conservative_debator
|
||||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||||
|
|
||||||
from .managers.research_manager import create_research_manager
|
from .managers.research_manager import create_research_manager
|
||||||
from .managers.risk_manager import create_risk_manager
|
from .managers.portfolio_manager import create_portfolio_manager
|
||||||
|
|
||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FinancialSituationMemory",
|
|
||||||
"Toolkit",
|
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"create_msg_delete",
|
"create_msg_delete",
|
||||||
"InvestDebateState",
|
"InvestDebateState",
|
||||||
@@ -33,9 +33,10 @@ __all__ = [
|
|||||||
"create_market_analyst",
|
"create_market_analyst",
|
||||||
"create_neutral_debator",
|
"create_neutral_debator",
|
||||||
"create_news_analyst",
|
"create_news_analyst",
|
||||||
"create_risky_debator",
|
"create_aggressive_debator",
|
||||||
"create_risk_manager",
|
"create_portfolio_manager",
|
||||||
"create_safe_debator",
|
"create_conservative_debator",
|
||||||
"create_social_media_analyst",
|
"create_sentiment_analyst",
|
||||||
|
"create_social_media_analyst", # deprecated; will be removed in a future version
|
||||||
"create_trader",
|
"create_trader",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,28 +1,33 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
import json
|
build_instrument_context,
|
||||||
|
get_balance_sheet,
|
||||||
|
get_cashflow,
|
||||||
|
get_fundamentals,
|
||||||
|
get_income_statement,
|
||||||
|
get_insider_transactions,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
|
||||||
def create_fundamentals_analyst(llm, toolkit):
|
def create_fundamentals_analyst(llm):
|
||||||
def fundamentals_analyst_node(state):
|
def fundamentals_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
company_name = state["company_of_interest"]
|
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
tools = [
|
||||||
tools = [toolkit.get_fundamentals_openai]
|
get_fundamentals,
|
||||||
else:
|
get_balance_sheet,
|
||||||
tools = [
|
get_cashflow,
|
||||||
toolkit.get_finnhub_company_insider_sentiment,
|
get_income_statement,
|
||||||
toolkit.get_finnhub_company_insider_transactions,
|
]
|
||||||
toolkit.get_simfin_balance_sheet,
|
|
||||||
toolkit.get_simfin_cashflow,
|
|
||||||
toolkit.get_simfin_income_stmt,
|
|
||||||
]
|
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ " Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
|
||||||
|
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
|
||||||
|
+ get_language_instruction(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
@@ -36,7 +41,7 @@ def create_fundamentals_analyst(llm, toolkit):
|
|||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
]
|
||||||
@@ -45,15 +50,20 @@ def create_fundamentals_analyst(llm, toolkit):
|
|||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(ticker=ticker)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
|
report = ""
|
||||||
|
|
||||||
|
if len(result.tool_calls) == 0:
|
||||||
|
report = result.content
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"fundamentals_report": result.content,
|
"fundamentals_report": report,
|
||||||
}
|
}
|
||||||
|
|
||||||
return fundamentals_analyst_node
|
return fundamentals_analyst_node
|
||||||
|
|||||||
@@ -1,25 +1,23 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
import json
|
build_instrument_context,
|
||||||
|
get_indicators,
|
||||||
|
get_language_instruction,
|
||||||
|
get_stock_data,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
|
||||||
def create_market_analyst(llm, toolkit):
|
def create_market_analyst(llm):
|
||||||
|
|
||||||
def market_analyst_node(state):
|
def market_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
company_name = state["company_of_interest"]
|
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
tools = [
|
||||||
tools = [
|
get_stock_data,
|
||||||
toolkit.get_YFin_data_online,
|
get_indicators,
|
||||||
toolkit.get_stockstats_indicators_report_online,
|
]
|
||||||
]
|
|
||||||
else:
|
|
||||||
tools = [
|
|
||||||
toolkit.get_YFin_data,
|
|
||||||
toolkit.get_stockstats_indicators_report,
|
|
||||||
]
|
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||||
@@ -46,8 +44,9 @@ Volatility Indicators:
|
|||||||
Volume-Based Indicators:
|
Volume-Based Indicators:
|
||||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||||
|
|
||||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
@@ -61,7 +60,7 @@ Volume-Based Indicators:
|
|||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
]
|
||||||
@@ -70,15 +69,20 @@ Volume-Based Indicators:
|
|||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(ticker=ticker)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
|
report = ""
|
||||||
|
|
||||||
|
if len(result.tool_calls) == 0:
|
||||||
|
report = result.content
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"market_report": result.content,
|
"market_report": report,
|
||||||
}
|
}
|
||||||
|
|
||||||
return market_analyst_node
|
return market_analyst_node
|
||||||
|
|||||||
@@ -1,25 +1,27 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
import time
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
import json
|
build_instrument_context,
|
||||||
|
get_global_news,
|
||||||
|
get_language_instruction,
|
||||||
|
get_news,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
|
||||||
def create_news_analyst(llm, toolkit):
|
def create_news_analyst(llm):
|
||||||
def news_analyst_node(state):
|
def news_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
ticker = state["company_of_interest"]
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
tools = [
|
||||||
tools = [toolkit.get_global_news_openai, toolkit.get_google_news]
|
get_news,
|
||||||
else:
|
get_global_news,
|
||||||
tools = [
|
]
|
||||||
toolkit.get_finnhub_news,
|
|
||||||
toolkit.get_reddit_news,
|
|
||||||
toolkit.get_google_news,
|
|
||||||
]
|
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
@@ -33,7 +35,7 @@ def create_news_analyst(llm, toolkit):
|
|||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||||
),
|
),
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
]
|
]
|
||||||
@@ -42,14 +44,19 @@ def create_news_analyst(llm, toolkit):
|
|||||||
prompt = prompt.partial(system_message=system_message)
|
prompt = prompt.partial(system_message=system_message)
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(ticker=ticker)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
|
report = ""
|
||||||
|
|
||||||
|
if len(result.tool_calls) == 0:
|
||||||
|
report = result.content
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"news_report": result.content,
|
"news_report": report,
|
||||||
}
|
}
|
||||||
|
|
||||||
return news_analyst_node
|
return news_analyst_node
|
||||||
|
|||||||
184
tradingagents/agents/analysts/sentiment_analyst.py
Normal file
184
tradingagents/agents/analysts/sentiment_analyst.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Sentiment analyst — multi-source sentiment analysis for a target ticker.
|
||||||
|
|
||||||
|
Previously named ``social_media_analyst``. Renamed and redesigned because
|
||||||
|
the old version had a prompt that demanded social-media analysis but the
|
||||||
|
only tool available was Yahoo Finance news — which led LLMs to fabricate
|
||||||
|
Reddit/X/StockTwits content under prompt pressure (verified live).
|
||||||
|
|
||||||
|
The redesigned agent pre-fetches three complementary data sources before
|
||||||
|
the LLM is invoked and injects them into the prompt as structured blocks:
|
||||||
|
|
||||||
|
1. News headlines — Yahoo Finance (institutional framing)
|
||||||
|
2. StockTwits messages — retail-trader posts indexed by cashtag, with
|
||||||
|
user-labeled Bullish/Bearish sentiment tags
|
||||||
|
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
||||||
|
|
||||||
|
The agent does not use tool-calling; the data is in the prompt from
|
||||||
|
turn 0. The LLM produces the sentiment report in a single invocation.
|
||||||
|
|
||||||
|
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
get_language_instruction,
|
||||||
|
get_news,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
||||||
|
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _seven_days_back(trade_date: str) -> str:
|
||||||
|
return (datetime.strptime(trade_date, "%Y-%m-%d") - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
|
||||||
|
def create_sentiment_analyst(llm):
|
||||||
|
"""Create a sentiment analyst node for the trading graph.
|
||||||
|
|
||||||
|
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
||||||
|
prompt as structured blocks, and produces a sentiment report in a
|
||||||
|
single LLM call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sentiment_analyst_node(state):
|
||||||
|
ticker = state["company_of_interest"]
|
||||||
|
end_date = state["trade_date"]
|
||||||
|
start_date = _seven_days_back(end_date)
|
||||||
|
instrument_context = build_instrument_context(ticker)
|
||||||
|
|
||||||
|
# Pre-fetch all three sources. Each fetcher degrades gracefully and
|
||||||
|
# returns a string (no exceptions surface from here), so the LLM
|
||||||
|
# always sees something — either real data or a clear placeholder.
|
||||||
|
news_block = get_news.func(ticker, start_date, end_date)
|
||||||
|
stocktwits_block = fetch_stocktwits_messages(ticker, limit=30)
|
||||||
|
reddit_block = fetch_reddit_posts(ticker)
|
||||||
|
|
||||||
|
system_message = _build_system_message(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
news_block=news_block,
|
||||||
|
stocktwits_block=stocktwits_block,
|
||||||
|
reddit_block=reddit_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"system",
|
||||||
|
"You are a helpful AI assistant, collaborating with other assistants."
|
||||||
|
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||||
|
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||||
|
"\n{system_message}\n"
|
||||||
|
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||||
|
),
|
||||||
|
MessagesPlaceholder(variable_name="messages"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = prompt.partial(system_message=system_message)
|
||||||
|
prompt = prompt.partial(current_date=end_date)
|
||||||
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
|
# No bind_tools — the data is already in the prompt; a single LLM
|
||||||
|
# call produces the report directly.
|
||||||
|
chain = prompt | llm
|
||||||
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [result],
|
||||||
|
"sentiment_report": result.content,
|
||||||
|
}
|
||||||
|
|
||||||
|
return sentiment_analyst_node
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_message(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
news_block: str,
|
||||||
|
stocktwits_block: str,
|
||||||
|
reddit_block: str,
|
||||||
|
) -> str:
|
||||||
|
"""Assemble the sentiment-analyst system message with structured data blocks."""
|
||||||
|
return f"""You are a financial market sentiment analyst. Your task is to produce a comprehensive sentiment report for {ticker} covering the period from {start_date} to {end_date}, drawing on three complementary data sources that have already been collected for you.
|
||||||
|
|
||||||
|
## Data sources (pre-fetched, in this prompt)
|
||||||
|
|
||||||
|
### News headlines — Yahoo Finance, past 7 days
|
||||||
|
Institutional framing. Fact-driven, slower-moving signal.
|
||||||
|
|
||||||
|
<start_of_news>
|
||||||
|
{news_block}
|
||||||
|
<end_of_news>
|
||||||
|
|
||||||
|
### StockTwits messages — retail-trader social platform indexed by cashtag
|
||||||
|
Fast-moving signal. Each message carries a user-labeled sentiment tag (Bullish / Bearish / no-label) plus the message body.
|
||||||
|
|
||||||
|
<start_of_stocktwits>
|
||||||
|
{stocktwits_block}
|
||||||
|
<end_of_stocktwits>
|
||||||
|
|
||||||
|
### Reddit posts — r/wallstreetbets, r/stocks, r/investing (past 7 days)
|
||||||
|
Community discussion. Engagement signal via upvote score and comment count. Subreddit character matters (r/wallstreetbets is often contrarian/exuberant; r/stocks more measured; r/investing longer-term).
|
||||||
|
|
||||||
|
<start_of_reddit>
|
||||||
|
{reddit_block}
|
||||||
|
<end_of_reddit>
|
||||||
|
|
||||||
|
## How to analyze this data (best practices)
|
||||||
|
|
||||||
|
1. **Read the StockTwits Bullish/Bearish ratio as a leading retail-sentiment signal.** A 70/30 bullish/bearish split is moderately bullish; ≥90/10 may indicate over-extension and contrarian risk; 50/50 is uncertainty. Sample size matters — base rates on the actual message count, not percentages alone.
|
||||||
|
|
||||||
|
2. **Look for cross-source divergences.** If news framing is bearish but StockTwits is overwhelmingly bullish, that mismatch is itself a signal — it can mean retail is leaning into a thesis the news flow hasn't caught up to (or vice versa, that retail is chasing while institutions are cautious).
|
||||||
|
|
||||||
|
3. **Weight Reddit posts by engagement.** A 400-upvote / 200-comment thread reflects community attention; a 3-upvote post is noise. Read the body excerpts for context — the title alone often misleads.
|
||||||
|
|
||||||
|
4. **Distinguish opinion from event.** A news headline ("Nvidia announces $500M Corning deal") is an event; a StockTwits post ("buying NVDA, this is going to moon") is opinion. Both are inputs but should be weighted differently in your conclusions.
|
||||||
|
|
||||||
|
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
||||||
|
|
||||||
|
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
|
||||||
|
|
||||||
|
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
||||||
|
|
||||||
|
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
Produce a sentiment report covering, in order:
|
||||||
|
|
||||||
|
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
|
||||||
|
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
|
||||||
|
3. **Divergences, alignments, and key narratives** across sources.
|
||||||
|
4. **Catalysts and risks** surfaced by the data.
|
||||||
|
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
|
||||||
|
|
||||||
|
{get_language_instruction()}"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backwards-compatibility shim
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def create_social_media_analyst(llm):
|
||||||
|
"""Deprecated alias for :func:`create_sentiment_analyst`.
|
||||||
|
|
||||||
|
Kept so existing code that imports ``create_social_media_analyst``
|
||||||
|
continues to work.
|
||||||
|
|
||||||
|
.. deprecated::
|
||||||
|
Import :func:`create_sentiment_analyst` directly instead.
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
"create_social_media_analyst is deprecated and will be removed in a "
|
||||||
|
"future version. Use create_sentiment_analyst instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return create_sentiment_analyst(llm)
|
||||||
@@ -1,55 +1,23 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
"""Backwards-compatibility shim for the renamed module.
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
The agent is now ``sentiment_analyst`` and aggregates Yahoo Finance news,
|
||||||
|
StockTwits cashtag streams, and Reddit posts into a single sentiment
|
||||||
|
report. Import from ``tradingagents.agents.analysts.sentiment_analyst``
|
||||||
|
going forward; this module will be removed in a future release.
|
||||||
|
|
||||||
def create_social_media_analyst(llm, toolkit):
|
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||||
def social_media_analyst_node(state):
|
"""
|
||||||
current_date = state["trade_date"]
|
|
||||||
ticker = state["company_of_interest"]
|
|
||||||
company_name = state["company_of_interest"]
|
|
||||||
|
|
||||||
if toolkit.config["online_tools"]:
|
import warnings as _warnings
|
||||||
tools = [toolkit.get_stock_news_openai]
|
|
||||||
else:
|
|
||||||
tools = [
|
|
||||||
toolkit.get_reddit_stock_info,
|
|
||||||
]
|
|
||||||
|
|
||||||
system_message = (
|
from tradingagents.agents.analysts.sentiment_analyst import ( # noqa: F401
|
||||||
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
create_sentiment_analyst,
|
||||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
create_social_media_analyst,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
_warnings.warn(
|
||||||
[
|
"tradingagents.agents.analysts.social_media_analyst is deprecated. "
|
||||||
(
|
"Import from tradingagents.agents.analysts.sentiment_analyst instead.",
|
||||||
"system",
|
DeprecationWarning,
|
||||||
"You are a helpful AI assistant, collaborating with other assistants."
|
stacklevel=2,
|
||||||
" Use the provided tools to progress towards answering the question."
|
)
|
||||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
|
||||||
" will help where you left off. Execute what you can to make progress."
|
|
||||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
|
||||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
|
||||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
|
||||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
|
||||||
),
|
|
||||||
MessagesPlaceholder(variable_name="messages"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = prompt.partial(system_message=system_message)
|
|
||||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
|
||||||
prompt = prompt.partial(current_date=current_date)
|
|
||||||
prompt = prompt.partial(ticker=ticker)
|
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"messages": [result],
|
|
||||||
"sentiment_report": result.content,
|
|
||||||
}
|
|
||||||
|
|
||||||
return social_media_analyst_node
|
|
||||||
|
|||||||
92
tradingagents/agents/managers/portfolio_manager.py
Normal file
92
tradingagents/agents/managers/portfolio_manager.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Portfolio Manager: synthesises the risk-analyst debate into the final decision.
|
||||||
|
|
||||||
|
Uses LangChain's ``with_structured_output`` so the LLM produces a typed
|
||||||
|
``PortfolioDecision`` directly, in a single call. The result is rendered
|
||||||
|
back to markdown for storage in ``final_trade_decision`` so memory log,
|
||||||
|
CLI display, and saved reports continue to consume the same shape they do
|
||||||
|
today. When a provider does not expose structured output, the agent falls
|
||||||
|
back gracefully to free-text generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_portfolio_manager(llm):
|
||||||
|
structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager")
|
||||||
|
|
||||||
|
def portfolio_manager_node(state) -> dict:
|
||||||
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
|
|
||||||
|
history = state["risk_debate_state"]["history"]
|
||||||
|
risk_debate_state = state["risk_debate_state"]
|
||||||
|
research_plan = state["investment_plan"]
|
||||||
|
trader_plan = state["trader_investment_plan"]
|
||||||
|
|
||||||
|
past_context = state.get("past_context", "")
|
||||||
|
lessons_line = (
|
||||||
|
f"- Lessons from prior decisions and outcomes:\n{past_context}\n"
|
||||||
|
if past_context
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
|
||||||
|
|
||||||
|
{instrument_context}
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Rating Scale** (use exactly one):
|
||||||
|
- **Buy**: Strong conviction to enter or add to position
|
||||||
|
- **Overweight**: Favorable outlook, gradually increase exposure
|
||||||
|
- **Hold**: Maintain current position, no action needed
|
||||||
|
- **Underweight**: Reduce exposure, take partial profits
|
||||||
|
- **Sell**: Exit position or avoid entry
|
||||||
|
|
||||||
|
**Context:**
|
||||||
|
- Research Manager's investment plan: **{research_plan}**
|
||||||
|
- Trader's transaction proposal: **{trader_plan}**
|
||||||
|
{lessons_line}
|
||||||
|
**Risk Analysts Debate History:**
|
||||||
|
{history}
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
|
||||||
|
|
||||||
|
final_trade_decision = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
prompt,
|
||||||
|
render_pm_decision,
|
||||||
|
"Portfolio Manager",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_risk_debate_state = {
|
||||||
|
"judge_decision": final_trade_decision,
|
||||||
|
"history": risk_debate_state["history"],
|
||||||
|
"aggressive_history": risk_debate_state["aggressive_history"],
|
||||||
|
"conservative_history": risk_debate_state["conservative_history"],
|
||||||
|
"neutral_history": risk_debate_state["neutral_history"],
|
||||||
|
"latest_speaker": "Judge",
|
||||||
|
"current_aggressive_response": risk_debate_state["current_aggressive_response"],
|
||||||
|
"current_conservative_response": risk_debate_state["current_conservative_response"],
|
||||||
|
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||||
|
"count": risk_debate_state["count"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"risk_debate_state": new_risk_debate_state,
|
||||||
|
"final_trade_decision": final_trade_decision,
|
||||||
|
}
|
||||||
|
|
||||||
|
return portfolio_manager_node
|
||||||
@@ -1,55 +1,67 @@
|
|||||||
import time
|
"""Research Manager: turns the bull/bear debate into a structured investment plan for the trader."""
|
||||||
import json
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_research_manager(llm, memory):
|
def create_research_manager(llm):
|
||||||
|
structured_llm = bind_structured(llm, ResearchPlan, "Research Manager")
|
||||||
|
|
||||||
def research_manager_node(state) -> dict:
|
def research_manager_node(state) -> dict:
|
||||||
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
history = state["investment_debate_state"].get("history", "")
|
history = state["investment_debate_state"].get("history", "")
|
||||||
market_research_report = state["market_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
news_report = state["news_report"]
|
|
||||||
fundamentals_report = state["fundamentals_report"]
|
|
||||||
|
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
prompt = f"""As the Research Manager and debate facilitator, your role is to critically evaluate this round of debate and deliver a clear, actionable investment plan for the trader.
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
{instrument_context}
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
---
|
||||||
|
|
||||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
**Rating Scale** (use exactly one):
|
||||||
|
- **Buy**: Strong conviction in the bull thesis; recommend taking or growing the position
|
||||||
|
- **Overweight**: Constructive view; recommend gradually increasing exposure
|
||||||
|
- **Hold**: Balanced view; recommend maintaining the current position
|
||||||
|
- **Underweight**: Cautious view; recommend trimming exposure
|
||||||
|
- **Sell**: Strong conviction in the bear thesis; recommend exiting or avoiding the position
|
||||||
|
|
||||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
Commit to a clear stance whenever the debate's strongest arguments warrant one; reserve Hold for situations where the evidence on both sides is genuinely balanced.
|
||||||
|
|
||||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
---
|
||||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
|
||||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
|
||||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
|
||||||
|
|
||||||
Here are your past reflections on mistakes:
|
**Debate History:**
|
||||||
\"{past_memory_str}\"
|
{history}""" + get_language_instruction()
|
||||||
|
|
||||||
Here is the debate:
|
investment_plan = invoke_structured_or_freetext(
|
||||||
Debate History:
|
structured_llm,
|
||||||
{history}"""
|
llm,
|
||||||
response = llm.invoke(prompt)
|
prompt,
|
||||||
|
render_research_plan,
|
||||||
|
"Research Manager",
|
||||||
|
)
|
||||||
|
|
||||||
new_investment_debate_state = {
|
new_investment_debate_state = {
|
||||||
"judge_decision": response.content,
|
"judge_decision": investment_plan,
|
||||||
"history": investment_debate_state.get("history", ""),
|
"history": investment_debate_state.get("history", ""),
|
||||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||||
"current_response": response.content,
|
"current_response": investment_plan,
|
||||||
"count": investment_debate_state["count"],
|
"count": investment_debate_state["count"],
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"investment_debate_state": new_investment_debate_state,
|
"investment_debate_state": new_investment_debate_state,
|
||||||
"investment_plan": response.content,
|
"investment_plan": investment_plan,
|
||||||
}
|
}
|
||||||
|
|
||||||
return research_manager_node
|
return research_manager_node
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_risk_manager(llm, memory):
|
|
||||||
def risk_manager_node(state) -> dict:
|
|
||||||
|
|
||||||
company_name = state["company_of_interest"]
|
|
||||||
|
|
||||||
history = state["risk_debate_state"]["history"]
|
|
||||||
risk_debate_state = state["risk_debate_state"]
|
|
||||||
market_research_report = state["market_report"]
|
|
||||||
news_report = state["news_report"]
|
|
||||||
fundamentals_report = state["news_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
trader_plan = state["investment_plan"]
|
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
|
||||||
|
|
||||||
Guidelines for Decision-Making:
|
|
||||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
|
||||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
|
||||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
|
||||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
|
||||||
|
|
||||||
Deliverables:
|
|
||||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
|
||||||
- Detailed reasoning anchored in the debate and past reflections.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Analysts Debate History:**
|
|
||||||
{history}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
|
||||||
|
|
||||||
new_risk_debate_state = {
|
|
||||||
"judge_decision": response.content,
|
|
||||||
"history": risk_debate_state["history"],
|
|
||||||
"risky_history": risk_debate_state["risky_history"],
|
|
||||||
"safe_history": risk_debate_state["safe_history"],
|
|
||||||
"neutral_history": risk_debate_state["neutral_history"],
|
|
||||||
"latest_speaker": "Judge",
|
|
||||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
|
||||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
|
||||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
|
||||||
"count": risk_debate_state["count"],
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"risk_debate_state": new_risk_debate_state,
|
|
||||||
"final_trade_decision": response.content,
|
|
||||||
}
|
|
||||||
|
|
||||||
return risk_manager_node
|
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
from langchain_core.messages import AIMessage
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_bear_researcher(llm, memory):
|
def create_bear_researcher(llm):
|
||||||
def bear_node(state) -> dict:
|
def bear_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
@@ -15,13 +13,6 @@ def create_bear_researcher(llm, memory):
|
|||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
@@ -40,9 +31,8 @@ Latest world affairs news: {news_report}
|
|||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bull argument: {current_response}
|
Last bull argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.
|
||||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
""" + get_language_instruction()
|
||||||
"""
|
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
from langchain_core.messages import AIMessage
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_bull_researcher(llm, memory):
|
def create_bull_researcher(llm):
|
||||||
def bull_node(state) -> dict:
|
def bull_node(state) -> dict:
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
history = investment_debate_state.get("history", "")
|
history = investment_debate_state.get("history", "")
|
||||||
@@ -15,13 +13,6 @@ def create_bull_researcher(llm, memory):
|
|||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
@@ -38,9 +29,8 @@ Latest world affairs news: {news_report}
|
|||||||
Company fundamentals report: {fundamentals_report}
|
Company fundamentals report: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bear argument: {current_response}
|
Last bear argument: {current_response}
|
||||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
|
||||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
""" + get_language_instruction()
|
||||||
"""
|
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import time
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_risky_debator(llm):
|
def create_aggressive_debator(llm):
|
||||||
def risky_node(state) -> dict:
|
def aggressive_node(state) -> dict:
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
risky_history = risk_debate_state.get("risky_history", "")
|
aggressive_history = risk_debate_state.get("aggressive_history", "")
|
||||||
|
|
||||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
|
||||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
@@ -18,7 +17,7 @@ def create_risky_debator(llm):
|
|||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
prompt = f"""As the Aggressive Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
@@ -28,22 +27,22 @@ Market Research Report: {market_research_report}
|
|||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||||
|
|
||||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Risky Analyst: {response.content}"
|
argument = f"Aggressive Analyst: {response.content}"
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risky_history + "\n" + argument,
|
"aggressive_history": aggressive_history + "\n" + argument,
|
||||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
"conservative_history": risk_debate_state.get("conservative_history", ""),
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Risky",
|
"latest_speaker": "Aggressive",
|
||||||
"current_risky_response": argument,
|
"current_aggressive_response": argument,
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_conservative_response": risk_debate_state.get("current_conservative_response", ""),
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", ""
|
||||||
),
|
),
|
||||||
@@ -52,4 +51,4 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
|||||||
|
|
||||||
return {"risk_debate_state": new_risk_debate_state}
|
return {"risk_debate_state": new_risk_debate_state}
|
||||||
|
|
||||||
return risky_node
|
return aggressive_node
|
||||||
@@ -1,15 +1,13 @@
|
|||||||
from langchain_core.messages import AIMessage
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_safe_debator(llm):
|
def create_conservative_debator(llm):
|
||||||
def safe_node(state) -> dict:
|
def conservative_node(state) -> dict:
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
safe_history = risk_debate_state.get("safe_history", "")
|
conservative_history = risk_debate_state.get("conservative_history", "")
|
||||||
|
|
||||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
|
||||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
@@ -19,34 +17,34 @@ def create_safe_debator(llm):
|
|||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
prompt = f"""As the Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
Your task is to actively counter the arguments of the Aggressive and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||||
|
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||||
|
|
||||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
argument = f"Safe Analyst: {response.content}"
|
argument = f"Conservative Analyst: {response.content}"
|
||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
|
||||||
"safe_history": safe_history + "\n" + argument,
|
"conservative_history": conservative_history + "\n" + argument,
|
||||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||||
"latest_speaker": "Safe",
|
"latest_speaker": "Conservative",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_aggressive_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_aggressive_response", ""
|
||||||
),
|
),
|
||||||
"current_safe_response": argument,
|
"current_conservative_response": argument,
|
||||||
"current_neutral_response": risk_debate_state.get(
|
"current_neutral_response": risk_debate_state.get(
|
||||||
"current_neutral_response", ""
|
"current_neutral_response", ""
|
||||||
),
|
),
|
||||||
@@ -55,4 +53,4 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
|||||||
|
|
||||||
return {"risk_debate_state": new_risk_debate_state}
|
return {"risk_debate_state": new_risk_debate_state}
|
||||||
|
|
||||||
return safe_node
|
return conservative_node
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import time
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def create_neutral_debator(llm):
|
def create_neutral_debator(llm):
|
||||||
@@ -8,8 +7,8 @@ def create_neutral_debator(llm):
|
|||||||
history = risk_debate_state.get("history", "")
|
history = risk_debate_state.get("history", "")
|
||||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||||
|
|
||||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
current_aggressive_response = risk_debate_state.get("current_aggressive_response", "")
|
||||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
current_conservative_response = risk_debate_state.get("current_conservative_response", "")
|
||||||
|
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
@@ -22,15 +21,15 @@ def create_neutral_debator(llm):
|
|||||||
|
|
||||||
{trader_decision}
|
{trader_decision}
|
||||||
|
|
||||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
Your task is to challenge both the Aggressive and Conservative Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||||
|
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
Company Fundamentals Report: {fundamentals_report}
|
Company Fundamentals Report: {fundamentals_report}
|
||||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||||
|
|
||||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" + get_language_instruction()
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
@@ -38,14 +37,14 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
|||||||
|
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
"history": history + "\n" + argument,
|
"history": history + "\n" + argument,
|
||||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
"aggressive_history": risk_debate_state.get("aggressive_history", ""),
|
||||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
"conservative_history": risk_debate_state.get("conservative_history", ""),
|
||||||
"neutral_history": neutral_history + "\n" + argument,
|
"neutral_history": neutral_history + "\n" + argument,
|
||||||
"latest_speaker": "Neutral",
|
"latest_speaker": "Neutral",
|
||||||
"current_risky_response": risk_debate_state.get(
|
"current_aggressive_response": risk_debate_state.get(
|
||||||
"current_risky_response", ""
|
"current_aggressive_response", ""
|
||||||
),
|
),
|
||||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
"current_conservative_response": risk_debate_state.get("current_conservative_response", ""),
|
||||||
"current_neutral_response": argument,
|
"current_neutral_response": argument,
|
||||||
"count": risk_debate_state["count"] + 1,
|
"count": risk_debate_state["count"] + 1,
|
||||||
}
|
}
|
||||||
|
|||||||
228
tradingagents/agents/schemas.py
Normal file
228
tradingagents/agents/schemas.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Pydantic schemas used by agents that produce structured output.
|
||||||
|
|
||||||
|
The framework's primary artifact is still prose: each agent's natural-language
|
||||||
|
reasoning is what users read in the saved markdown reports and what the
|
||||||
|
downstream agents read as context. Structured output is layered onto the
|
||||||
|
three decision-making agents (Research Manager, Trader, Portfolio Manager)
|
||||||
|
so that:
|
||||||
|
|
||||||
|
- Their outputs follow consistent section headers across runs and providers
|
||||||
|
- Each provider's native structured-output mode is used (json_schema for
|
||||||
|
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic)
|
||||||
|
- Schema field descriptions become the model's output instructions, freeing
|
||||||
|
the prompt body to focus on context and the rating-scale guidance
|
||||||
|
- A render helper turns the parsed Pydantic instance back into the same
|
||||||
|
markdown shape the rest of the system already consumes, so display,
|
||||||
|
memory log, and saved reports keep working unchanged
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared rating types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioRating(str, Enum):
|
||||||
|
"""5-tier rating used by the Research Manager and Portfolio Manager."""
|
||||||
|
|
||||||
|
BUY = "Buy"
|
||||||
|
OVERWEIGHT = "Overweight"
|
||||||
|
HOLD = "Hold"
|
||||||
|
UNDERWEIGHT = "Underweight"
|
||||||
|
SELL = "Sell"
|
||||||
|
|
||||||
|
|
||||||
|
class TraderAction(str, Enum):
|
||||||
|
"""3-tier transaction direction used by the Trader.
|
||||||
|
|
||||||
|
The Trader's job is to translate the Research Manager's investment plan
|
||||||
|
into a concrete transaction proposal: should the desk execute a Buy, a
|
||||||
|
Sell, or sit on Hold this round. Position sizing and the nuanced
|
||||||
|
Overweight / Underweight calls happen later at the Portfolio Manager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BUY = "Buy"
|
||||||
|
HOLD = "Hold"
|
||||||
|
SELL = "Sell"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Research Manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchPlan(BaseModel):
|
||||||
|
"""Structured investment plan produced by the Research Manager.
|
||||||
|
|
||||||
|
Hand-off to the Trader: the recommendation pins the directional view,
|
||||||
|
the rationale captures which side of the bull/bear debate carried the
|
||||||
|
argument, and the strategic actions translate that into concrete
|
||||||
|
instructions the trader can execute against.
|
||||||
|
"""
|
||||||
|
|
||||||
|
recommendation: PortfolioRating = Field(
|
||||||
|
description=(
|
||||||
|
"The investment recommendation. Exactly one of Buy / Overweight / "
|
||||||
|
"Hold / Underweight / Sell. Reserve Hold for situations where the "
|
||||||
|
"evidence on both sides is genuinely balanced; otherwise commit to "
|
||||||
|
"the side with the stronger arguments."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
rationale: str = Field(
|
||||||
|
description=(
|
||||||
|
"Conversational summary of the key points from both sides of the "
|
||||||
|
"debate, ending with which arguments led to the recommendation. "
|
||||||
|
"Speak naturally, as if to a teammate."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
strategic_actions: str = Field(
|
||||||
|
description=(
|
||||||
|
"Concrete steps for the trader to implement the recommendation, "
|
||||||
|
"including position sizing guidance consistent with the rating."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_research_plan(plan: ResearchPlan) -> str:
|
||||||
|
"""Render a ResearchPlan to markdown for storage and the trader's prompt context."""
|
||||||
|
return "\n".join([
|
||||||
|
f"**Recommendation**: {plan.recommendation.value}",
|
||||||
|
"",
|
||||||
|
f"**Rationale**: {plan.rationale}",
|
||||||
|
"",
|
||||||
|
f"**Strategic Actions**: {plan.strategic_actions}",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trader
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TraderProposal(BaseModel):
|
||||||
|
"""Structured transaction proposal produced by the Trader.
|
||||||
|
|
||||||
|
The trader reads the Research Manager's investment plan and the analyst
|
||||||
|
reports, then turns them into a concrete transaction: what action to
|
||||||
|
take, the reasoning that justifies it, and the practical levels for
|
||||||
|
entry, stop-loss, and sizing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
action: TraderAction = Field(
|
||||||
|
description="The transaction direction. Exactly one of Buy / Hold / Sell.",
|
||||||
|
)
|
||||||
|
reasoning: str = Field(
|
||||||
|
description=(
|
||||||
|
"The case for this action, anchored in the analysts' reports and "
|
||||||
|
"the research plan. Two to four sentences."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
entry_price: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional entry price target in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
stop_loss: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional stop-loss price in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
position_sizing: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_trader_proposal(proposal: TraderProposal) -> str:
|
||||||
|
"""Render a TraderProposal to markdown.
|
||||||
|
|
||||||
|
The trailing ``FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`` line is
|
||||||
|
preserved for backward compatibility with the analyst stop-signal text
|
||||||
|
and any external code that greps for it.
|
||||||
|
"""
|
||||||
|
parts = [
|
||||||
|
f"**Action**: {proposal.action.value}",
|
||||||
|
"",
|
||||||
|
f"**Reasoning**: {proposal.reasoning}",
|
||||||
|
]
|
||||||
|
if proposal.entry_price is not None:
|
||||||
|
parts.extend(["", f"**Entry Price**: {proposal.entry_price}"])
|
||||||
|
if proposal.stop_loss is not None:
|
||||||
|
parts.extend(["", f"**Stop Loss**: {proposal.stop_loss}"])
|
||||||
|
if proposal.position_sizing:
|
||||||
|
parts.extend(["", f"**Position Sizing**: {proposal.position_sizing}"])
|
||||||
|
parts.extend([
|
||||||
|
"",
|
||||||
|
f"FINAL TRANSACTION PROPOSAL: **{proposal.action.value.upper()}**",
|
||||||
|
])
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Portfolio Manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioDecision(BaseModel):
|
||||||
|
"""Structured output produced by the Portfolio Manager.
|
||||||
|
|
||||||
|
The model fills every field as part of its primary LLM call; no separate
|
||||||
|
extraction pass is required. Field descriptions double as the model's
|
||||||
|
output instructions, so the prompt body only needs to convey context and
|
||||||
|
the rating-scale guidance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rating: PortfolioRating = Field(
|
||||||
|
description=(
|
||||||
|
"The final position rating. Exactly one of Buy / Overweight / Hold / "
|
||||||
|
"Underweight / Sell, picked based on the analysts' debate."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
executive_summary: str = Field(
|
||||||
|
description=(
|
||||||
|
"A concise action plan covering entry strategy, position sizing, "
|
||||||
|
"key risk levels, and time horizon. Two to four sentences."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
investment_thesis: str = Field(
|
||||||
|
description=(
|
||||||
|
"Detailed reasoning anchored in specific evidence from the analysts' "
|
||||||
|
"debate. If prior lessons are referenced in the prompt context, "
|
||||||
|
"incorporate them; otherwise rely solely on the current analysis."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
price_target: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional target price in the instrument's quote currency.",
|
||||||
|
)
|
||||||
|
time_horizon: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional recommended holding period, e.g. '3-6 months'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_pm_decision(decision: PortfolioDecision) -> str:
|
||||||
|
"""Render a PortfolioDecision back to the markdown shape the rest of the system expects.
|
||||||
|
|
||||||
|
Memory log, CLI display, and saved report files all read this markdown,
|
||||||
|
so the rendered output preserves the exact section headers (``**Rating**``,
|
||||||
|
``**Executive Summary**``, ``**Investment Thesis**``) that downstream
|
||||||
|
parsers and the report writers already handle.
|
||||||
|
"""
|
||||||
|
parts = [
|
||||||
|
f"**Rating**: {decision.rating.value}",
|
||||||
|
"",
|
||||||
|
f"**Executive Summary**: {decision.executive_summary}",
|
||||||
|
"",
|
||||||
|
f"**Investment Thesis**: {decision.investment_thesis}",
|
||||||
|
]
|
||||||
|
if decision.price_target is not None:
|
||||||
|
parts.extend(["", f"**Price Target**: {decision.price_target}"])
|
||||||
|
if decision.time_horizon:
|
||||||
|
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
||||||
|
return "\n".join(parts)
|
||||||
@@ -1,42 +1,64 @@
|
|||||||
|
"""Trader: turns the Research Manager's investment plan into a concrete transaction proposal."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import time
|
|
||||||
import json
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_trader(llm, memory):
|
def create_trader(llm):
|
||||||
|
structured_llm = bind_structured(llm, TraderProposal, "Trader")
|
||||||
|
|
||||||
def trader_node(state, name):
|
def trader_node(state, name):
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
|
instrument_context = build_instrument_context(company_name)
|
||||||
investment_plan = state["investment_plan"]
|
investment_plan = state["investment_plan"]
|
||||||
market_research_report = state["market_report"]
|
|
||||||
sentiment_report = state["sentiment_report"]
|
|
||||||
news_report = state["news_report"]
|
|
||||||
fundamentals_report = state["fundamentals_report"]
|
|
||||||
|
|
||||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
|
||||||
|
|
||||||
past_memory_str = ""
|
|
||||||
for i, rec in enumerate(past_memories, 1):
|
|
||||||
past_memory_str += rec["recommendation"] + "\n\n"
|
|
||||||
|
|
||||||
context = {
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
"content": (
|
||||||
|
"You are a trading agent analyzing market data to make investment decisions. "
|
||||||
|
"Based on your analysis, provide a specific recommendation to buy, sell, or hold. "
|
||||||
|
"Anchor your reasoning in the analysts' reports and the research plan."
|
||||||
|
+ get_language_instruction()
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"Based on a comprehensive analysis by a team of analysts, here is an investment "
|
||||||
|
f"plan tailored for {company_name}. {instrument_context} This plan incorporates "
|
||||||
|
f"insights from current technical market trends, macroeconomic indicators, and "
|
||||||
|
f"social media sentiment. Use this plan as a foundation for evaluating your next "
|
||||||
|
f"trading decision.\n\nProposed Investment Plan: {investment_plan}\n\n"
|
||||||
|
f"Leverage these insights to make an informed and strategic decision."
|
||||||
|
),
|
||||||
},
|
},
|
||||||
context,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
result = llm.invoke(messages)
|
trader_plan = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
messages,
|
||||||
|
render_trader_proposal,
|
||||||
|
"Trader",
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [AIMessage(content=trader_plan)],
|
||||||
"trader_investment_plan": result.content,
|
"trader_investment_plan": trader_plan,
|
||||||
"sender": name,
|
"sender": name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
from typing import Annotated, Sequence
|
from typing import Annotated
|
||||||
from datetime import date, timedelta, datetime
|
from typing_extensions import TypedDict
|
||||||
from typing_extensions import TypedDict, Optional
|
from langgraph.graph import MessagesState
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
@@ -23,22 +19,22 @@ class InvestDebateState(TypedDict):
|
|||||||
|
|
||||||
# Risk management team state
|
# Risk management team state
|
||||||
class RiskDebateState(TypedDict):
|
class RiskDebateState(TypedDict):
|
||||||
risky_history: Annotated[
|
aggressive_history: Annotated[
|
||||||
str, "Risky Agent's Conversation history"
|
str, "Aggressive Agent's Conversation history"
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
safe_history: Annotated[
|
conservative_history: Annotated[
|
||||||
str, "Safe Agent's Conversation history"
|
str, "Conservative Agent's Conversation history"
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
neutral_history: Annotated[
|
neutral_history: Annotated[
|
||||||
str, "Neutral Agent's Conversation history"
|
str, "Neutral Agent's Conversation history"
|
||||||
] # Conversation history
|
] # Conversation history
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
history: Annotated[str, "Conversation history"] # Conversation history
|
||||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||||
current_risky_response: Annotated[
|
current_aggressive_response: Annotated[
|
||||||
str, "Latest response by the risky analyst"
|
str, "Latest response by the aggressive analyst"
|
||||||
] # Last response
|
] # Last response
|
||||||
current_safe_response: Annotated[
|
current_conservative_response: Annotated[
|
||||||
str, "Latest response by the safe analyst"
|
str, "Latest response by the conservative analyst"
|
||||||
] # Last response
|
] # Last response
|
||||||
current_neutral_response: Annotated[
|
current_neutral_response: Annotated[
|
||||||
str, "Latest response by the neutral analyst"
|
str, "Latest response by the neutral analyst"
|
||||||
@@ -55,7 +51,7 @@ class AgentState(MessagesState):
|
|||||||
|
|
||||||
# research step
|
# research step
|
||||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
sentiment_report: Annotated[str, "Report from the Sentiment Analyst"]
|
||||||
news_report: Annotated[
|
news_report: Annotated[
|
||||||
str, "Report from the News Researcher of current world affairs"
|
str, "Report from the News Researcher of current world affairs"
|
||||||
]
|
]
|
||||||
@@ -74,3 +70,4 @@ class AgentState(MessagesState):
|
|||||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||||
]
|
]
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||||
|
past_context: Annotated[str, "Memory log context injected at run start (same-ticker decisions + cross-ticker lessons)"]
|
||||||
|
|||||||
@@ -1,411 +1,63 @@
|
|||||||
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||||
from typing import List
|
|
||||||
from typing import Annotated
|
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
||||||
from langchain_core.messages import RemoveMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from datetime import date, timedelta, datetime
|
|
||||||
import functools
|
|
||||||
import pandas as pd
|
|
||||||
import os
|
|
||||||
from dateutil.relativedelta import relativedelta
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
import tradingagents.dataflows.interface as interface
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
|
|
||||||
|
# Import tools from separate utility files
|
||||||
|
from tradingagents.agents.utils.core_stock_tools import (
|
||||||
|
get_stock_data
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.technical_indicators_tools import (
|
||||||
|
get_indicators
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.fundamental_data_tools import (
|
||||||
|
get_fundamentals,
|
||||||
|
get_balance_sheet,
|
||||||
|
get_cashflow,
|
||||||
|
get_income_statement
|
||||||
|
)
|
||||||
|
from tradingagents.agents.utils.news_data_tools import (
|
||||||
|
get_news,
|
||||||
|
get_insider_transactions,
|
||||||
|
get_global_news
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_language_instruction() -> str:
|
||||||
|
"""Return a prompt instruction for the configured output language.
|
||||||
|
|
||||||
|
Returns empty string when English (default), so no extra tokens are used.
|
||||||
|
Applied to every agent whose output reaches the saved report —
|
||||||
|
analysts, researchers, debaters, research manager, trader, and
|
||||||
|
portfolio manager — so a non-English run produces a fully localized
|
||||||
|
report rather than a mix of languages.
|
||||||
|
"""
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
lang = get_config().get("output_language", "English")
|
||||||
|
if lang.strip().lower() == "english":
|
||||||
|
return ""
|
||||||
|
return f" Write your entire response in {lang}."
|
||||||
|
|
||||||
|
|
||||||
|
def build_instrument_context(ticker: str) -> str:
|
||||||
|
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
||||||
|
return (
|
||||||
|
f"The instrument to analyze is `{ticker}`. "
|
||||||
|
"Use this exact ticker in every tool call, report, and recommendation, "
|
||||||
|
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`)."
|
||||||
|
)
|
||||||
|
|
||||||
def create_msg_delete():
|
def create_msg_delete():
|
||||||
def delete_messages(state):
|
def delete_messages(state):
|
||||||
"""To prevent message history from overflowing, regularly clear message history after a stage of the pipeline is done"""
|
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
return {"messages": [RemoveMessage(id=m.id) for m in messages]}
|
|
||||||
|
# Remove all messages
|
||||||
|
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||||
|
|
||||||
|
# Add a minimal placeholder message
|
||||||
|
placeholder = HumanMessage(content="Continue")
|
||||||
|
|
||||||
|
return {"messages": removal_operations + [placeholder]}
|
||||||
|
|
||||||
return delete_messages
|
return delete_messages
|
||||||
|
|
||||||
|
|
||||||
class Toolkit:
|
|
||||||
_config = DEFAULT_CONFIG.copy()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def update_config(cls, config):
|
|
||||||
"""Update the class-level configuration."""
|
|
||||||
cls._config.update(config)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def config(self):
|
|
||||||
"""Access the configuration."""
|
|
||||||
return self._config
|
|
||||||
|
|
||||||
def __init__(self, config=None):
|
|
||||||
if config:
|
|
||||||
self.update_config(config)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_reddit_news(
|
|
||||||
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve global news from Reddit within a specified time frame.
|
|
||||||
Args:
|
|
||||||
curr_date (str): Date you want to get news for in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
|
|
||||||
"""
|
|
||||||
|
|
||||||
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
|
|
||||||
|
|
||||||
return global_news_result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_finnhub_news(
|
|
||||||
ticker: Annotated[
|
|
||||||
str,
|
|
||||||
"Search query of a company, e.g. 'AAPL, TSM, etc.",
|
|
||||||
],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the latest news about a given stock from Finnhub within a date range
|
|
||||||
Args:
|
|
||||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
|
||||||
start_date (str): Start date in yyyy-mm-dd format
|
|
||||||
end_date (str): End date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing news about the company within the date range from start_date to end_date
|
|
||||||
"""
|
|
||||||
|
|
||||||
end_date_str = end_date
|
|
||||||
|
|
||||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
look_back_days = (end_date - start_date).days
|
|
||||||
|
|
||||||
finnhub_news_result = interface.get_finnhub_news(
|
|
||||||
ticker, end_date_str, look_back_days
|
|
||||||
)
|
|
||||||
|
|
||||||
return finnhub_news_result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_reddit_stock_info(
|
|
||||||
ticker: Annotated[
|
|
||||||
str,
|
|
||||||
"Ticker of a company. e.g. AAPL, TSM",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "Current date you want to get news for"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the latest news about a given stock from Reddit, given the current date.
|
|
||||||
Args:
|
|
||||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
|
||||||
curr_date (str): current date in yyyy-mm-dd format to get news for
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the latest news about the company on the given date
|
|
||||||
"""
|
|
||||||
|
|
||||||
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
|
||||||
|
|
||||||
return stock_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_YFin_data(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
|
||||||
Args:
|
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
|
||||||
start_date (str): Start date in yyyy-mm-dd format
|
|
||||||
end_date (str): End date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
|
||||||
"""
|
|
||||||
|
|
||||||
result_data = interface.get_YFin_data(symbol, start_date, end_date)
|
|
||||||
|
|
||||||
return result_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_YFin_data_online(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
|
||||||
Args:
|
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
|
||||||
start_date (str): Start date in yyyy-mm-dd format
|
|
||||||
end_date (str): End date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
|
||||||
"""
|
|
||||||
|
|
||||||
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
|
|
||||||
|
|
||||||
return result_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_stockstats_indicators_report(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
indicator: Annotated[
|
|
||||||
str, "technical indicator to get the analysis and report of"
|
|
||||||
],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
|
||||||
Args:
|
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
|
||||||
indicator (str): Technical indicator to get the analysis and report of
|
|
||||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
|
||||||
look_back_days (int): How many days to look back, default is 30
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
|
||||||
symbol, indicator, curr_date, look_back_days, False
|
|
||||||
)
|
|
||||||
|
|
||||||
return result_stockstats
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_stockstats_indicators_report_online(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
indicator: Annotated[
|
|
||||||
str, "technical indicator to get the analysis and report of"
|
|
||||||
],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
|
||||||
Args:
|
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
|
||||||
indicator (str): Technical indicator to get the analysis and report of
|
|
||||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
|
||||||
look_back_days (int): How many days to look back, default is 30
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
|
||||||
symbol, indicator, curr_date, look_back_days, True
|
|
||||||
)
|
|
||||||
|
|
||||||
return result_stockstats
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_finnhub_company_insider_sentiment(
|
|
||||||
ticker: Annotated[str, "ticker symbol for the company"],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str,
|
|
||||||
"current date of you are trading at, yyyy-mm-dd",
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve insider sentiment information about a company (retrieved from public SEC information) for the past 30 days
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the sentiment in the past 30 days starting at curr_date
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_sentiment = interface.get_finnhub_company_insider_sentiment(
|
|
||||||
ticker, curr_date, 30
|
|
||||||
)
|
|
||||||
|
|
||||||
return data_sentiment
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_finnhub_company_insider_transactions(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str,
|
|
||||||
"current date you are trading at, yyyy-mm-dd",
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve insider transaction information about a company (retrieved from public SEC information) for the past 30 days
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the company's insider transactions/trading information in the past 30 days
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_trans = interface.get_finnhub_company_insider_transactions(
|
|
||||||
ticker, curr_date, 30
|
|
||||||
)
|
|
||||||
|
|
||||||
return data_trans
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_simfin_balance_sheet(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual/quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the most recent balance sheet of a company
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the company's most recent balance sheet
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
|
||||||
|
|
||||||
return data_balance_sheet
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_simfin_cashflow(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual/quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the most recent cash flow statement of a company
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the company's most recent cash flow statement
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
|
|
||||||
|
|
||||||
return data_cashflow
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_simfin_income_stmt(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual/quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the most recent income statement of a company
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the company's most recent income statement
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_income_stmt = interface.get_simfin_income_statements(
|
|
||||||
ticker, freq, curr_date
|
|
||||||
)
|
|
||||||
|
|
||||||
return data_income_stmt
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_google_news(
|
|
||||||
query: Annotated[str, "Query to search with"],
|
|
||||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the latest news from Google News based on a query and date range.
|
|
||||||
Args:
|
|
||||||
query (str): Query to search with
|
|
||||||
curr_date (str): Current date in yyyy-mm-dd format
|
|
||||||
look_back_days (int): How many days to look back
|
|
||||||
Returns:
|
|
||||||
str: A formatted string containing the latest news from Google News based on the query and date range.
|
|
||||||
"""
|
|
||||||
|
|
||||||
google_news_results = interface.get_google_news(query, curr_date, 7)
|
|
||||||
|
|
||||||
return google_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_stock_news_openai(
|
|
||||||
ticker: Annotated[str, "the company's ticker"],
|
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the latest news about a given stock by using OpenAI's news API.
|
|
||||||
Args:
|
|
||||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
|
||||||
curr_date (str): Current date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted string containing the latest news about the company on the given date.
|
|
||||||
"""
|
|
||||||
|
|
||||||
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
|
|
||||||
|
|
||||||
return openai_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_global_news_openai(
|
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
|
|
||||||
Args:
|
|
||||||
curr_date (str): Current date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted string containing the latest macroeconomic news on the given date.
|
|
||||||
"""
|
|
||||||
|
|
||||||
openai_news_results = interface.get_global_news_openai(curr_date)
|
|
||||||
|
|
||||||
return openai_news_results
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@tool
|
|
||||||
def get_fundamentals_openai(
|
|
||||||
ticker: Annotated[str, "the company's ticker"],
|
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve the latest fundamental information about a given stock on a given date by using OpenAI's news API.
|
|
||||||
Args:
|
|
||||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
|
||||||
curr_date (str): Current date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted string containing the latest fundamental information about the company on the given date.
|
|
||||||
"""
|
|
||||||
|
|
||||||
openai_fundamentals_results = interface.get_fundamentals_openai(
|
|
||||||
ticker, curr_date
|
|
||||||
)
|
|
||||||
|
|
||||||
return openai_fundamentals_results
|
|
||||||
|
|||||||
22
tradingagents/agents/utils/core_stock_tools.py
Normal file
22
tradingagents/agents/utils/core_stock_tools.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
from typing import Annotated
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_stock_data(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve stock price data (OHLCV) for a given ticker symbol.
|
||||||
|
Uses the configured core_stock_apis vendor.
|
||||||
|
Args:
|
||||||
|
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||||
|
start_date (str): Start date in yyyy-mm-dd format
|
||||||
|
end_date (str): End date in yyyy-mm-dd format
|
||||||
|
Returns:
|
||||||
|
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_stock_data", symbol, start_date, end_date)
|
||||||
77
tradingagents/agents/utils/fundamental_data_tools.py
Normal file
77
tradingagents/agents/utils/fundamental_data_tools.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
from typing import Annotated
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_fundamentals(
|
||||||
|
ticker: Annotated[str, "ticker symbol"],
|
||||||
|
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve comprehensive fundamental data for a given ticker symbol.
|
||||||
|
Uses the configured fundamental_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
Returns:
|
||||||
|
str: A formatted report containing comprehensive fundamental data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_fundamentals", ticker, curr_date)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_balance_sheet(
|
||||||
|
ticker: Annotated[str, "ticker symbol"],
|
||||||
|
freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve balance sheet data for a given ticker symbol.
|
||||||
|
Uses the configured fundamental_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
freq (str): Reporting frequency: annual/quarterly (default quarterly)
|
||||||
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
Returns:
|
||||||
|
str: A formatted report containing balance sheet data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_balance_sheet", ticker, freq, curr_date)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_cashflow(
|
||||||
|
ticker: Annotated[str, "ticker symbol"],
|
||||||
|
freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve cash flow statement data for a given ticker symbol.
|
||||||
|
Uses the configured fundamental_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
freq (str): Reporting frequency: annual/quarterly (default quarterly)
|
||||||
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
Returns:
|
||||||
|
str: A formatted report containing cash flow statement data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_cashflow", ticker, freq, curr_date)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_income_statement(
|
||||||
|
ticker: Annotated[str, "ticker symbol"],
|
||||||
|
freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve income statement data for a given ticker symbol.
|
||||||
|
Uses the configured fundamental_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
freq (str): Reporting frequency: annual/quarterly (default quarterly)
|
||||||
|
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||||
|
Returns:
|
||||||
|
str: A formatted report containing income statement data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||||
@@ -1,109 +1,300 @@
|
|||||||
import chromadb
|
"""Append-only markdown decision log for TradingAgents."""
|
||||||
from chromadb.config import Settings
|
|
||||||
from openai import OpenAI
|
from typing import List, Optional
|
||||||
import numpy as np
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class TradingMemoryLog:
|
||||||
def __init__(self, name):
|
"""Append-only markdown log of trading decisions and reflections."""
|
||||||
self.client = OpenAI()
|
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
|
||||||
|
|
||||||
def get_embedding(self, text):
|
# HTML comment: cannot appear in LLM prose output, safe as a hard delimiter
|
||||||
"""Get OpenAI embedding for a text"""
|
_SEPARATOR = "\n\n<!-- ENTRY_END -->\n\n"
|
||||||
response = self.client.embeddings.create(
|
# Precompiled patterns — avoids re-compilation on every load_entries() call
|
||||||
model="text-embedding-ada-002", input=text
|
_DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL)
|
||||||
)
|
_REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL)
|
||||||
return response.data[0].embedding
|
|
||||||
|
|
||||||
def add_situations(self, situations_and_advice):
|
def __init__(self, config: dict = None):
|
||||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
cfg = config or {}
|
||||||
|
self._log_path = None
|
||||||
|
path = cfg.get("memory_log_path")
|
||||||
|
if path:
|
||||||
|
self._log_path = Path(path).expanduser()
|
||||||
|
self._log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Optional cap on resolved entries. None disables rotation.
|
||||||
|
self._max_entries = cfg.get("memory_log_max_entries")
|
||||||
|
|
||||||
situations = []
|
# --- Write path (Phase A) ---
|
||||||
advice = []
|
|
||||||
ids = []
|
|
||||||
embeddings = []
|
|
||||||
|
|
||||||
offset = self.situation_collection.count()
|
def store_decision(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
trade_date: str,
|
||||||
|
final_trade_decision: str,
|
||||||
|
) -> None:
|
||||||
|
"""Append pending entry at end of propagate(). No LLM call."""
|
||||||
|
if not self._log_path:
|
||||||
|
return
|
||||||
|
# Idempotency guard: fast raw-text scan instead of full parse
|
||||||
|
if self._log_path.exists():
|
||||||
|
raw = self._log_path.read_text(encoding="utf-8")
|
||||||
|
for line in raw.splitlines():
|
||||||
|
if line.startswith(f"[{trade_date} | {ticker} |") and line.endswith("| pending]"):
|
||||||
|
return
|
||||||
|
rating = parse_rating(final_trade_decision)
|
||||||
|
tag = f"[{trade_date} | {ticker} | {rating} | pending]"
|
||||||
|
entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}"
|
||||||
|
with open(self._log_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(entry)
|
||||||
|
|
||||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
# --- Read path (Phase A) ---
|
||||||
situations.append(situation)
|
|
||||||
advice.append(recommendation)
|
|
||||||
ids.append(str(offset + i))
|
|
||||||
embeddings.append(self.get_embedding(situation))
|
|
||||||
|
|
||||||
self.situation_collection.add(
|
def load_entries(self) -> List[dict]:
|
||||||
documents=situations,
|
"""Parse all entries from log. Returns list of dicts."""
|
||||||
metadatas=[{"recommendation": rec} for rec in advice],
|
if not self._log_path or not self._log_path.exists():
|
||||||
embeddings=embeddings,
|
return []
|
||||||
ids=ids,
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
)
|
raw_entries = [e.strip() for e in text.split(self._SEPARATOR) if e.strip()]
|
||||||
|
entries = []
|
||||||
|
for raw in raw_entries:
|
||||||
|
parsed = self._parse_entry(raw)
|
||||||
|
if parsed:
|
||||||
|
entries.append(parsed)
|
||||||
|
return entries
|
||||||
|
|
||||||
def get_memories(self, current_situation, n_matches=1):
|
def get_pending_entries(self) -> List[dict]:
|
||||||
"""Find matching recommendations using OpenAI embeddings"""
|
"""Return entries with outcome:pending (for Phase B)."""
|
||||||
query_embedding = self.get_embedding(current_situation)
|
return [e for e in self.load_entries() if e.get("pending")]
|
||||||
|
|
||||||
results = self.situation_collection.query(
|
def get_past_context(self, ticker: str, n_same: int = 5, n_cross: int = 3) -> str:
|
||||||
query_embeddings=[query_embedding],
|
"""Return formatted past context string for agent prompt injection."""
|
||||||
n_results=n_matches,
|
entries = [e for e in self.load_entries() if not e.get("pending")]
|
||||||
include=["metadatas", "documents", "distances"],
|
if not entries:
|
||||||
)
|
return ""
|
||||||
|
|
||||||
matched_results = []
|
same, cross = [], []
|
||||||
for i in range(len(results["documents"][0])):
|
for e in reversed(entries):
|
||||||
matched_results.append(
|
if len(same) >= n_same and len(cross) >= n_cross:
|
||||||
{
|
break
|
||||||
"matched_situation": results["documents"][0][i],
|
if e["ticker"] == ticker and len(same) < n_same:
|
||||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
same.append(e)
|
||||||
"similarity_score": 1 - results["distances"][0][i],
|
elif e["ticker"] != ticker and len(cross) < n_cross:
|
||||||
}
|
cross.append(e)
|
||||||
|
|
||||||
|
if not same and not cross:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if same:
|
||||||
|
parts.append(f"Past analyses of {ticker} (most recent first):")
|
||||||
|
parts.extend(self._format_full(e) for e in same)
|
||||||
|
if cross:
|
||||||
|
parts.append("Recent cross-ticker lessons:")
|
||||||
|
parts.extend(self._format_reflection_only(e) for e in cross)
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
# --- Update path (Phase B) ---
|
||||||
|
|
||||||
|
def update_with_outcome(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
trade_date: str,
|
||||||
|
raw_return: float,
|
||||||
|
alpha_return: float,
|
||||||
|
holding_days: int,
|
||||||
|
reflection: str,
|
||||||
|
) -> None:
|
||||||
|
"""Replace pending tag and append REFLECTION section using atomic write.
|
||||||
|
|
||||||
|
Finds the first pending entry matching (trade_date, ticker), updates
|
||||||
|
its tag with return figures, and appends a REFLECTION section. Uses
|
||||||
|
a temp-file + os.replace() so a crash mid-write never corrupts the log.
|
||||||
|
"""
|
||||||
|
if not self._log_path or not self._log_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
|
blocks = text.split(self._SEPARATOR)
|
||||||
|
|
||||||
|
pending_prefix = f"[{trade_date} | {ticker} |"
|
||||||
|
raw_pct = f"{raw_return:+.1%}"
|
||||||
|
alpha_pct = f"{alpha_return:+.1%}"
|
||||||
|
|
||||||
|
updated = False
|
||||||
|
new_blocks = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
new_blocks.append(block)
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
|
||||||
|
if (
|
||||||
|
not updated
|
||||||
|
and tag_line.startswith(pending_prefix)
|
||||||
|
and tag_line.endswith("| pending]")
|
||||||
|
):
|
||||||
|
# Parse rating from the existing pending tag
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
rating = fields[2]
|
||||||
|
new_tag = (
|
||||||
|
f"[{trade_date} | {ticker} | {rating}"
|
||||||
|
f" | {raw_pct} | {alpha_pct} | {holding_days}d]"
|
||||||
|
)
|
||||||
|
rest = "\n".join(lines[1:])
|
||||||
|
new_blocks.append(
|
||||||
|
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{reflection}"
|
||||||
|
)
|
||||||
|
updated = True
|
||||||
|
else:
|
||||||
|
new_blocks.append(block)
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_blocks = self._apply_rotation(new_blocks)
|
||||||
|
new_text = self._SEPARATOR.join(new_blocks)
|
||||||
|
tmp_path = self._log_path.with_suffix(".tmp")
|
||||||
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
|
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
|
||||||
|
"""Apply multiple outcome updates in a single read + atomic write.
|
||||||
|
|
||||||
|
Each element of updates must have keys: ticker, trade_date,
|
||||||
|
raw_return, alpha_return, holding_days, reflection.
|
||||||
|
"""
|
||||||
|
if not self._log_path or not self._log_path.exists() or not updates:
|
||||||
|
return
|
||||||
|
|
||||||
|
text = self._log_path.read_text(encoding="utf-8")
|
||||||
|
blocks = text.split(self._SEPARATOR)
|
||||||
|
|
||||||
|
# Build lookup keyed by (trade_date, ticker) for O(1) dispatch
|
||||||
|
update_map = {(u["trade_date"], u["ticker"]): u for u in updates}
|
||||||
|
|
||||||
|
new_blocks = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
new_blocks.append(block)
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
|
||||||
|
matched = False
|
||||||
|
for (trade_date, ticker), upd in list(update_map.items()):
|
||||||
|
pending_prefix = f"[{trade_date} | {ticker} |"
|
||||||
|
if tag_line.startswith(pending_prefix) and tag_line.endswith("| pending]"):
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
rating = fields[2]
|
||||||
|
raw_pct = f"{upd['raw_return']:+.1%}"
|
||||||
|
alpha_pct = f"{upd['alpha_return']:+.1%}"
|
||||||
|
new_tag = (
|
||||||
|
f"[{trade_date} | {ticker} | {rating}"
|
||||||
|
f" | {raw_pct} | {alpha_pct} | {upd['holding_days']}d]"
|
||||||
|
)
|
||||||
|
rest = "\n".join(lines[1:])
|
||||||
|
new_blocks.append(
|
||||||
|
f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{upd['reflection']}"
|
||||||
|
)
|
||||||
|
del update_map[(trade_date, ticker)]
|
||||||
|
matched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not matched:
|
||||||
|
new_blocks.append(block)
|
||||||
|
|
||||||
|
new_blocks = self._apply_rotation(new_blocks)
|
||||||
|
new_text = self._SEPARATOR.join(new_blocks)
|
||||||
|
tmp_path = self._log_path.with_suffix(".tmp")
|
||||||
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
|
# --- Helpers ---
|
||||||
|
|
||||||
|
def _apply_rotation(self, blocks: List[str]) -> List[str]:
|
||||||
|
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
||||||
|
|
||||||
|
Pending blocks are always kept (they represent unprocessed work).
|
||||||
|
Returns ``blocks`` unchanged when rotation is disabled or under cap.
|
||||||
|
"""
|
||||||
|
if not self._max_entries or self._max_entries <= 0:
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
# Tag each block with (kept, is_resolved) by parsing tag-line markers.
|
||||||
|
decisions = []
|
||||||
|
for block in blocks:
|
||||||
|
stripped = block.strip()
|
||||||
|
if not stripped:
|
||||||
|
decisions.append((block, False))
|
||||||
|
continue
|
||||||
|
tag_line = stripped.splitlines()[0].strip()
|
||||||
|
is_resolved = (
|
||||||
|
tag_line.startswith("[")
|
||||||
|
and tag_line.endswith("]")
|
||||||
|
and not tag_line.endswith("| pending]")
|
||||||
)
|
)
|
||||||
|
decisions.append((block, is_resolved))
|
||||||
|
|
||||||
return matched_results
|
resolved_count = sum(1 for _, r in decisions if r)
|
||||||
|
if resolved_count <= self._max_entries:
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
to_drop = resolved_count - self._max_entries
|
||||||
|
kept: List[str] = []
|
||||||
|
for block, is_resolved in decisions:
|
||||||
|
if is_resolved and to_drop > 0:
|
||||||
|
to_drop -= 1
|
||||||
|
continue
|
||||||
|
kept.append(block)
|
||||||
|
return kept
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def _parse_entry(self, raw: str) -> Optional[dict]:
|
||||||
# Example usage
|
lines = raw.strip().splitlines()
|
||||||
matcher = FinancialSituationMemory()
|
if not lines:
|
||||||
|
return None
|
||||||
|
tag_line = lines[0].strip()
|
||||||
|
if not (tag_line.startswith("[") and tag_line.endswith("]")):
|
||||||
|
return None
|
||||||
|
fields = [f.strip() for f in tag_line[1:-1].split("|")]
|
||||||
|
if len(fields) < 4:
|
||||||
|
return None
|
||||||
|
entry = {
|
||||||
|
"date": fields[0],
|
||||||
|
"ticker": fields[1],
|
||||||
|
"rating": fields[2],
|
||||||
|
"pending": fields[3] == "pending",
|
||||||
|
"raw": fields[3] if fields[3] != "pending" else None,
|
||||||
|
"alpha": fields[4] if len(fields) > 4 else None,
|
||||||
|
"holding": fields[5] if len(fields) > 5 else None,
|
||||||
|
}
|
||||||
|
body = "\n".join(lines[1:]).strip()
|
||||||
|
decision_match = self._DECISION_RE.search(body)
|
||||||
|
reflection_match = self._REFLECTION_RE.search(body)
|
||||||
|
entry["decision"] = decision_match.group(1).strip() if decision_match else ""
|
||||||
|
entry["reflection"] = reflection_match.group(1).strip() if reflection_match else ""
|
||||||
|
return entry
|
||||||
|
|
||||||
# Example data
|
def _format_full(self, e: dict) -> str:
|
||||||
example_data = [
|
raw = e["raw"] or "n/a"
|
||||||
(
|
alpha = e["alpha"] or "n/a"
|
||||||
"High inflation rate with rising interest rates and declining consumer spending",
|
holding = e["holding"] or "n/a"
|
||||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {raw} | {alpha} | {holding}]"
|
||||||
),
|
parts = [tag, f"DECISION:\n{e['decision']}"]
|
||||||
(
|
if e["reflection"]:
|
||||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
parts.append(f"REFLECTION:\n{e['reflection']}")
|
||||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
return "\n\n".join(parts)
|
||||||
),
|
|
||||||
(
|
|
||||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
|
||||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Market showing signs of sector rotation with rising yields",
|
|
||||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add the example situations and recommendations
|
def _format_reflection_only(self, e: dict) -> str:
|
||||||
matcher.add_situations(example_data)
|
tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {e['raw'] or 'n/a'}]"
|
||||||
|
if e["reflection"]:
|
||||||
# Example query
|
return f"{tag}\n{e['reflection']}"
|
||||||
current_situation = """
|
text = e["decision"][:300]
|
||||||
Market showing increased volatility in tech sector, with institutional investors
|
suffix = "..." if len(e["decision"]) > 300 else ""
|
||||||
reducing positions and rising interest rates affecting growth stock valuations
|
return f"{tag}\n{text}{suffix}"
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
|
||||||
|
|
||||||
for i, rec in enumerate(recommendations, 1):
|
|
||||||
print(f"\nMatch {i}:")
|
|
||||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
|
||||||
print(f"Matched Situation: {rec['matched_situation']}")
|
|
||||||
print(f"Recommendation: {rec['recommendation']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during recommendation: {str(e)}")
|
|
||||||
|
|||||||
57
tradingagents/agents/utils/news_data_tools.py
Normal file
57
tradingagents/agents/utils/news_data_tools.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_news(
|
||||||
|
ticker: Annotated[str, "Ticker symbol"],
|
||||||
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve news data for a given ticker symbol.
|
||||||
|
Uses the configured news_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol
|
||||||
|
start_date (str): Start date in yyyy-mm-dd format
|
||||||
|
end_date (str): End date in yyyy-mm-dd format
|
||||||
|
Returns:
|
||||||
|
str: A formatted string containing news data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_news", ticker, start_date, end_date)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_global_news(
|
||||||
|
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||||
|
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
|
||||||
|
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve global news data.
|
||||||
|
Uses the configured news_data vendor. Defaults for look_back_days and
|
||||||
|
limit come from DEFAULT_CONFIG (global_news_lookback_days,
|
||||||
|
global_news_article_limit); pass explicit values to override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curr_date (str): Current date in yyyy-mm-dd format
|
||||||
|
look_back_days (int): Number of days to look back; omit to inherit config
|
||||||
|
limit (int): Maximum number of articles to return; omit to inherit config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted string containing global news data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_insider_transactions(
|
||||||
|
ticker: Annotated[str, "ticker symbol"],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve insider transaction information about a company.
|
||||||
|
Uses the configured news_data vendor.
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
Returns:
|
||||||
|
str: A report of insider transaction data
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_insider_transactions", ticker)
|
||||||
50
tradingagents/agents/utils/rating.py
Normal file
50
tradingagents/agents/utils/rating.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Shared 5-tier rating vocabulary and a deterministic heuristic parser.
|
||||||
|
|
||||||
|
The same five-tier scale (Buy, Overweight, Hold, Underweight, Sell) is used by:
|
||||||
|
- The Research Manager (investment plan recommendation)
|
||||||
|
- The Portfolio Manager (final position decision)
|
||||||
|
- The signal processor (rating extracted for downstream consumers)
|
||||||
|
- The memory log (rating tag stored alongside each decision entry)
|
||||||
|
|
||||||
|
Centralising it here avoids drift between those call sites.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
||||||
|
RATINGS_5_TIER: Tuple[str, ...] = (
|
||||||
|
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
||||||
|
)
|
||||||
|
|
||||||
|
_RATING_SET = {r.lower() for r in RATINGS_5_TIER}
|
||||||
|
|
||||||
|
# Matches "Rating: X" / "rating - X" / "Rating: **X**" — tolerates markdown
|
||||||
|
# bold wrappers and either a colon or hyphen separator.
|
||||||
|
_RATING_LABEL_RE = re.compile(r"rating.*?[:\-][\s*]*(\w+)", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rating(text: str, default: str = "Hold") -> str:
|
||||||
|
"""Heuristically extract a 5-tier rating from prose text.
|
||||||
|
|
||||||
|
Two-pass strategy:
|
||||||
|
1. Look for an explicit "Rating: X" label (tolerant of markdown bold).
|
||||||
|
2. Fall back to the first 5-tier rating word found anywhere in the text.
|
||||||
|
|
||||||
|
Returns a Title-cased rating string, or ``default`` if no rating word appears.
|
||||||
|
"""
|
||||||
|
for line in text.splitlines():
|
||||||
|
m = _RATING_LABEL_RE.search(line)
|
||||||
|
if m and m.group(1).lower() in _RATING_SET:
|
||||||
|
return m.group(1).capitalize()
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
for word in line.lower().split():
|
||||||
|
clean = word.strip("*:.,")
|
||||||
|
if clean in _RATING_SET:
|
||||||
|
return clean.capitalize()
|
||||||
|
|
||||||
|
return default
|
||||||
73
tradingagents/agents/utils/structured.py
Normal file
73
tradingagents/agents/utils/structured.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Shared helpers for invoking an agent with structured output and a graceful fallback.
|
||||||
|
|
||||||
|
The Portfolio Manager, Trader, and Research Manager all follow the same
|
||||||
|
canonical pattern:
|
||||||
|
|
||||||
|
1. At agent creation, wrap the LLM with ``with_structured_output(Schema)``
|
||||||
|
so the model returns a typed Pydantic instance. If the provider does
|
||||||
|
not support structured output (rare; mostly older Ollama models), the
|
||||||
|
wrap is skipped and the agent uses free-text generation instead.
|
||||||
|
2. At invocation, run the structured call and render the result back to
|
||||||
|
markdown. If the structured call itself fails for any reason
|
||||||
|
(malformed JSON from a weak model, transient provider issue), fall
|
||||||
|
back to a plain ``llm.invoke`` so the pipeline never blocks.
|
||||||
|
|
||||||
|
Centralising the pattern here keeps the agent factories small and ensures
|
||||||
|
all three agents log the same warnings when fallback fires.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
|
||||||
|
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
||||||
|
|
||||||
|
Logs a warning when the binding fails so the user understands the agent
|
||||||
|
will use free-text generation for every call instead of one-shot fallback.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return llm.with_structured_output(schema)
|
||||||
|
except (NotImplementedError, AttributeError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"%s: provider does not support with_structured_output (%s); "
|
||||||
|
"falling back to free-text generation",
|
||||||
|
agent_name, exc,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_structured_or_freetext(
|
||||||
|
structured_llm: Optional[Any],
|
||||||
|
plain_llm: Any,
|
||||||
|
prompt: Any,
|
||||||
|
render: Callable[[T], str],
|
||||||
|
agent_name: str,
|
||||||
|
) -> str:
|
||||||
|
"""Run the structured call and render to markdown; fall back to free-text on any failure.
|
||||||
|
|
||||||
|
``prompt`` is whatever the underlying LLM accepts (a string for chat
|
||||||
|
invocations, a list of message dicts for chat models that take that
|
||||||
|
shape). The same value is forwarded to the free-text path so the
|
||||||
|
fallback sees the same input the structured call did.
|
||||||
|
"""
|
||||||
|
if structured_llm is not None:
|
||||||
|
try:
|
||||||
|
result = structured_llm.invoke(prompt)
|
||||||
|
return render(result)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"%s: structured-output invocation failed (%s); retrying once as free text",
|
||||||
|
agent_name, exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = plain_llm.invoke(prompt)
|
||||||
|
return response.content
|
||||||
32
tradingagents/agents/utils/technical_indicators_tools.py
Normal file
32
tradingagents/agents/utils/technical_indicators_tools.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
from typing import Annotated
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_indicators(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
|
curr_date: Annotated[str, "The current trading date you are trading on, YYYY-mm-dd"],
|
||||||
|
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve a single technical indicator for a given ticker symbol.
|
||||||
|
Uses the configured technical_indicators vendor.
|
||||||
|
Args:
|
||||||
|
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||||
|
indicator (str): A single technical indicator name, e.g. 'rsi', 'macd'. Call this tool once per indicator.
|
||||||
|
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||||
|
look_back_days (int): How many days to look back, default is 30
|
||||||
|
Returns:
|
||||||
|
str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator.
|
||||||
|
"""
|
||||||
|
# LLMs sometimes pass multiple indicators as a comma-separated string;
|
||||||
|
# split and process each individually.
|
||||||
|
indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()]
|
||||||
|
results = []
|
||||||
|
for ind in indicators:
|
||||||
|
try:
|
||||||
|
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
||||||
|
except ValueError as e:
|
||||||
|
results.append(str(e))
|
||||||
|
return "\n\n".join(results)
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from .finnhub_utils import get_data_in_range
|
|
||||||
from .googlenews_utils import getNewsData
|
|
||||||
from .yfin_utils import YFinanceUtils
|
|
||||||
from .reddit_utils import fetch_top_from_category
|
|
||||||
from .stockstats_utils import StockstatsUtils
|
|
||||||
from .yfin_utils import YFinanceUtils
|
|
||||||
|
|
||||||
from .interface import (
|
|
||||||
# News and sentiment functions
|
|
||||||
get_finnhub_news,
|
|
||||||
get_finnhub_company_insider_sentiment,
|
|
||||||
get_finnhub_company_insider_transactions,
|
|
||||||
get_google_news,
|
|
||||||
get_reddit_global_news,
|
|
||||||
get_reddit_company_news,
|
|
||||||
# Financial statements functions
|
|
||||||
get_simfin_balance_sheet,
|
|
||||||
get_simfin_cashflow,
|
|
||||||
get_simfin_income_statements,
|
|
||||||
# Technical analysis functions
|
|
||||||
get_stock_stats_indicators_window,
|
|
||||||
get_stockstats_indicator,
|
|
||||||
# Market data functions
|
|
||||||
get_YFin_data_window,
|
|
||||||
get_YFin_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# News and sentiment functions
|
|
||||||
"get_finnhub_news",
|
|
||||||
"get_finnhub_company_insider_sentiment",
|
|
||||||
"get_finnhub_company_insider_transactions",
|
|
||||||
"get_google_news",
|
|
||||||
"get_reddit_global_news",
|
|
||||||
"get_reddit_company_news",
|
|
||||||
# Financial statements functions
|
|
||||||
"get_simfin_balance_sheet",
|
|
||||||
"get_simfin_cashflow",
|
|
||||||
"get_simfin_income_statements",
|
|
||||||
# Technical analysis functions
|
|
||||||
"get_stock_stats_indicators_window",
|
|
||||||
"get_stockstats_indicator",
|
|
||||||
# Market data functions
|
|
||||||
"get_YFin_data_window",
|
|
||||||
"get_YFin_data",
|
|
||||||
]
|
|
||||||
|
|||||||
5
tradingagents/dataflows/alpha_vantage.py
Normal file
5
tradingagents/dataflows/alpha_vantage.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Import functions from specialized modules
|
||||||
|
from .alpha_vantage_stock import get_stock
|
||||||
|
from .alpha_vantage_indicator import get_indicator
|
||||||
|
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
||||||
|
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
|
||||||
122
tradingagents/dataflows/alpha_vantage_common.py
Normal file
122
tradingagents/dataflows/alpha_vantage_common.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import pandas as pd
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||||
|
|
||||||
|
def get_api_key() -> str:
|
||||||
|
"""Retrieve the API key for Alpha Vantage from environment variables."""
|
||||||
|
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
def format_datetime_for_api(date_input) -> str:
|
||||||
|
"""Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API."""
|
||||||
|
if isinstance(date_input, str):
|
||||||
|
# If already in correct format, return as-is
|
||||||
|
if len(date_input) == 13 and 'T' in date_input:
|
||||||
|
return date_input
|
||||||
|
# Try to parse common date formats
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(date_input, "%Y-%m-%d")
|
||||||
|
return dt.strftime("%Y%m%dT0000")
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(date_input, "%Y-%m-%d %H:%M")
|
||||||
|
return dt.strftime("%Y%m%dT%H%M")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Unsupported date format: {date_input}")
|
||||||
|
elif isinstance(date_input, datetime):
|
||||||
|
return date_input.strftime("%Y%m%dT%H%M")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")
|
||||||
|
|
||||||
|
class AlphaVantageRateLimitError(Exception):
|
||||||
|
"""Exception raised when Alpha Vantage API rate limit is exceeded."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _make_api_request(function_name: str, params: dict) -> dict | str:
|
||||||
|
"""Helper function to make API requests and handle responses.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AlphaVantageRateLimitError: When API rate limit is exceeded
|
||||||
|
"""
|
||||||
|
# Create a copy of params to avoid modifying the original
|
||||||
|
api_params = params.copy()
|
||||||
|
api_params.update({
|
||||||
|
"function": function_name,
|
||||||
|
"apikey": get_api_key(),
|
||||||
|
"source": "trading_agents",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Handle entitlement parameter if present in params or global variable
|
||||||
|
current_entitlement = globals().get('_current_entitlement')
|
||||||
|
entitlement = api_params.get("entitlement") or current_entitlement
|
||||||
|
|
||||||
|
if entitlement:
|
||||||
|
api_params["entitlement"] = entitlement
|
||||||
|
elif "entitlement" in api_params:
|
||||||
|
# Remove entitlement if it's None or empty
|
||||||
|
api_params.pop("entitlement", None)
|
||||||
|
|
||||||
|
response = requests.get(API_BASE_URL, params=api_params)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_text = response.text
|
||||||
|
|
||||||
|
# Check if response is JSON (error responses are typically JSON)
|
||||||
|
try:
|
||||||
|
response_json = json.loads(response_text)
|
||||||
|
# Check for rate limit error
|
||||||
|
if "Information" in response_json:
|
||||||
|
info_message = response_json["Information"]
|
||||||
|
if "rate limit" in info_message.lower() or "api key" in info_message.lower():
|
||||||
|
raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Response is not JSON (likely CSV data), which is normal
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str:
|
||||||
|
"""
|
||||||
|
Filter CSV data to include only rows within the specified date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
csv_data: CSV string from Alpha Vantage API
|
||||||
|
start_date: Start date in yyyy-mm-dd format
|
||||||
|
end_date: End date in yyyy-mm-dd format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered CSV string
|
||||||
|
"""
|
||||||
|
if not csv_data or csv_data.strip() == "":
|
||||||
|
return csv_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse CSV data
|
||||||
|
df = pd.read_csv(StringIO(csv_data))
|
||||||
|
|
||||||
|
# Assume the first column is the date column (timestamp)
|
||||||
|
date_col = df.columns[0]
|
||||||
|
df[date_col] = pd.to_datetime(df[date_col])
|
||||||
|
|
||||||
|
# Filter by date range
|
||||||
|
start_dt = pd.to_datetime(start_date)
|
||||||
|
end_dt = pd.to_datetime(end_date)
|
||||||
|
|
||||||
|
filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)]
|
||||||
|
|
||||||
|
# Convert back to CSV string
|
||||||
|
return filtered_df.to_csv(index=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If filtering fails, return original data with a warning
|
||||||
|
print(f"Warning: Failed to filter CSV data by date range: {e}")
|
||||||
|
return csv_data
|
||||||
55
tradingagents/dataflows/alpha_vantage_fundamentals.py
Normal file
55
tradingagents/dataflows/alpha_vantage_fundamentals.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from .alpha_vantage_common import _make_api_request
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_reports_by_date(result, curr_date: str):
|
||||||
|
"""Filter annualReports/quarterlyReports to exclude entries after curr_date.
|
||||||
|
|
||||||
|
Prevents look-ahead bias by removing fiscal periods that end after
|
||||||
|
the simulation's current date.
|
||||||
|
"""
|
||||||
|
if not curr_date or not isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
for key in ("annualReports", "quarterlyReports"):
|
||||||
|
if key in result:
|
||||||
|
result[key] = [
|
||||||
|
r for r in result[key]
|
||||||
|
if r.get("fiscalDateEnding", "") <= curr_date
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker (str): Ticker symbol of the company
|
||||||
|
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Company overview data including financial ratios and key metrics
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"symbol": ticker,
|
||||||
|
}
|
||||||
|
|
||||||
|
return _make_api_request("OVERVIEW", params)
|
||||||
|
|
||||||
|
|
||||||
|
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
|
"""Retrieve balance sheet data for a given ticker symbol using Alpha Vantage."""
|
||||||
|
result = _make_api_request("BALANCE_SHEET", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
|
"""Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage."""
|
||||||
|
result = _make_api_request("CASH_FLOW", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
|
|
||||||
|
|
||||||
|
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||||
|
"""Retrieve income statement data for a given ticker symbol using Alpha Vantage."""
|
||||||
|
result = _make_api_request("INCOME_STATEMENT", {"symbol": ticker})
|
||||||
|
return _filter_reports_by_date(result, curr_date)
|
||||||
|
|
||||||
222
tradingagents/dataflows/alpha_vantage_indicator.py
Normal file
222
tradingagents/dataflows/alpha_vantage_indicator.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
from .alpha_vantage_common import _make_api_request
|
||||||
|
|
||||||
|
def get_indicator(
|
||||||
|
symbol: str,
|
||||||
|
indicator: str,
|
||||||
|
curr_date: str,
|
||||||
|
look_back_days: int,
|
||||||
|
interval: str = "daily",
|
||||||
|
time_period: int = 14,
|
||||||
|
series_type: str = "close"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Returns Alpha Vantage technical indicator values over a time window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: ticker symbol of the company
|
||||||
|
indicator: technical indicator to get the analysis and report of
|
||||||
|
curr_date: The current trading date you are trading on, YYYY-mm-dd
|
||||||
|
look_back_days: how many days to look back
|
||||||
|
interval: Time interval (daily, weekly, monthly)
|
||||||
|
time_period: Number of data points for calculation
|
||||||
|
series_type: The desired price type (close, open, high, low)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing indicator values and description
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
|
supported_indicators = {
|
||||||
|
"close_50_sma": ("50 SMA", "close"),
|
||||||
|
"close_200_sma": ("200 SMA", "close"),
|
||||||
|
"close_10_ema": ("10 EMA", "close"),
|
||||||
|
"macd": ("MACD", "close"),
|
||||||
|
"macds": ("MACD Signal", "close"),
|
||||||
|
"macdh": ("MACD Histogram", "close"),
|
||||||
|
"rsi": ("RSI", "close"),
|
||||||
|
"boll": ("Bollinger Middle", "close"),
|
||||||
|
"boll_ub": ("Bollinger Upper Band", "close"),
|
||||||
|
"boll_lb": ("Bollinger Lower Band", "close"),
|
||||||
|
"atr": ("ATR", None),
|
||||||
|
"vwma": ("VWMA", "close")
|
||||||
|
}
|
||||||
|
|
||||||
|
indicator_descriptions = {
|
||||||
|
"close_50_sma": "50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.",
|
||||||
|
"close_200_sma": "200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.",
|
||||||
|
"close_10_ema": "10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.",
|
||||||
|
"macd": "MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.",
|
||||||
|
"macds": "MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.",
|
||||||
|
"macdh": "MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.",
|
||||||
|
"rsi": "RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.",
|
||||||
|
"boll": "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.",
|
||||||
|
"boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.",
|
||||||
|
"boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.",
|
||||||
|
"atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.",
|
||||||
|
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
|
||||||
|
}
|
||||||
|
|
||||||
|
if indicator not in supported_indicators:
|
||||||
|
raise ValueError(
|
||||||
|
f"Indicator {indicator} is not supported. Please choose from: {list(supported_indicators.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
before = curr_date_dt - relativedelta(days=look_back_days)
|
||||||
|
|
||||||
|
# Get the full data for the period instead of making individual calls
|
||||||
|
_, required_series_type = supported_indicators[indicator]
|
||||||
|
|
||||||
|
# Use the provided series_type or fall back to the required one
|
||||||
|
if required_series_type:
|
||||||
|
series_type = required_series_type
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get indicator data for the period
|
||||||
|
if indicator == "close_50_sma":
|
||||||
|
data = _make_api_request("SMA", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": "50",
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "close_200_sma":
|
||||||
|
data = _make_api_request("SMA", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": "200",
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "close_10_ema":
|
||||||
|
data = _make_api_request("EMA", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": "10",
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "macd":
|
||||||
|
data = _make_api_request("MACD", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "macds":
|
||||||
|
data = _make_api_request("MACD", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "macdh":
|
||||||
|
data = _make_api_request("MACD", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "rsi":
|
||||||
|
data = _make_api_request("RSI", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": str(time_period),
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator in ["boll", "boll_ub", "boll_lb"]:
|
||||||
|
data = _make_api_request("BBANDS", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": "20",
|
||||||
|
"series_type": series_type,
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "atr":
|
||||||
|
data = _make_api_request("ATR", {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"time_period": str(time_period),
|
||||||
|
"datatype": "csv"
|
||||||
|
})
|
||||||
|
elif indicator == "vwma":
|
||||||
|
# Alpha Vantage doesn't have direct VWMA, so we'll return an informative message
|
||||||
|
# In a real implementation, this would need to be calculated from OHLCV data
|
||||||
|
return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}"
|
||||||
|
else:
|
||||||
|
return f"Error: Indicator {indicator} not implemented yet."
|
||||||
|
|
||||||
|
# Parse CSV data and extract values for the date range
|
||||||
|
lines = data.strip().split('\n')
|
||||||
|
if len(lines) < 2:
|
||||||
|
return f"Error: No data returned for {indicator}"
|
||||||
|
|
||||||
|
# Parse header and data
|
||||||
|
header = [col.strip() for col in lines[0].split(',')]
|
||||||
|
try:
|
||||||
|
date_col_idx = header.index('time')
|
||||||
|
except ValueError:
|
||||||
|
return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}"
|
||||||
|
|
||||||
|
# Map internal indicator names to expected CSV column names from Alpha Vantage
|
||||||
|
col_name_map = {
|
||||||
|
"macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist",
|
||||||
|
"boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band",
|
||||||
|
"rsi": "RSI", "atr": "ATR", "close_10_ema": "EMA",
|
||||||
|
"close_50_sma": "SMA", "close_200_sma": "SMA"
|
||||||
|
}
|
||||||
|
|
||||||
|
target_col_name = col_name_map.get(indicator)
|
||||||
|
|
||||||
|
if not target_col_name:
|
||||||
|
# Default to the second column if no specific mapping exists
|
||||||
|
value_col_idx = 1
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
value_col_idx = header.index(target_col_name)
|
||||||
|
except ValueError:
|
||||||
|
return f"Error: Column '{target_col_name}' not found for indicator '{indicator}'. Available columns: {header}"
|
||||||
|
|
||||||
|
result_data = []
|
||||||
|
for line in lines[1:]:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
values = line.split(',')
|
||||||
|
if len(values) > value_col_idx:
|
||||||
|
try:
|
||||||
|
date_str = values[date_col_idx].strip()
|
||||||
|
# Parse the date
|
||||||
|
date_dt = datetime.strptime(date_str, "%Y-%m-%d")
|
||||||
|
|
||||||
|
# Check if date is in our range
|
||||||
|
if before <= date_dt <= curr_date_dt:
|
||||||
|
value = values[value_col_idx].strip()
|
||||||
|
result_data.append((date_dt, value))
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Sort by date and format output
|
||||||
|
result_data.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
ind_string = ""
|
||||||
|
for date_dt, value in result_data:
|
||||||
|
ind_string += f"{date_dt.strftime('%Y-%m-%d')}: {value}\n"
|
||||||
|
|
||||||
|
if not ind_string:
|
||||||
|
ind_string = "No data available for the specified date range.\n"
|
||||||
|
|
||||||
|
result_str = (
|
||||||
|
f"## {indicator.upper()} values from {before.strftime('%Y-%m-%d')} to {curr_date}:\n\n"
|
||||||
|
+ ind_string
|
||||||
|
+ "\n\n"
|
||||||
|
+ indicator_descriptions.get(indicator, "No description available.")
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_str
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
|
||||||
|
return f"Error retrieving {indicator} data: {str(e)}"
|
||||||
71
tradingagents/dataflows/alpha_vantage_news.py
Normal file
71
tradingagents/dataflows/alpha_vantage_news.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||||
|
|
||||||
|
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||||
|
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||||
|
|
||||||
|
Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock symbol for news articles.
|
||||||
|
start_date: Start date for news search.
|
||||||
|
end_date: End date for news search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing news sentiment data or JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"tickers": ticker,
|
||||||
|
"time_from": format_datetime_for_api(start_date),
|
||||||
|
"time_to": format_datetime_for_api(end_date),
|
||||||
|
}
|
||||||
|
|
||||||
|
return _make_api_request("NEWS_SENTIMENT", params)
|
||||||
|
|
||||||
|
def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict[str, str] | str:
|
||||||
|
"""Returns global market news & sentiment data without ticker-specific filtering.
|
||||||
|
|
||||||
|
Covers broad market topics like financial markets, economy, and more.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curr_date: Current date in yyyy-mm-dd format.
|
||||||
|
look_back_days: Number of days to look back (default 7).
|
||||||
|
limit: Maximum number of articles (default 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing global news sentiment data or JSON string.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# Calculate start date
|
||||||
|
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
start_dt = curr_dt - timedelta(days=look_back_days)
|
||||||
|
start_date = start_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"topics": "financial_markets,economy_macro,economy_monetary",
|
||||||
|
"time_from": format_datetime_for_api(start_date),
|
||||||
|
"time_to": format_datetime_for_api(curr_date),
|
||||||
|
"limit": str(limit),
|
||||||
|
}
|
||||||
|
|
||||||
|
return _make_api_request("NEWS_SENTIMENT", params)
|
||||||
|
|
||||||
|
|
||||||
|
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
||||||
|
"""Returns latest and historical insider transactions by key stakeholders.
|
||||||
|
|
||||||
|
Covers transactions by founders, executives, board members, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Ticker symbol. Example: "IBM".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing insider transaction data or JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"symbol": symbol,
|
||||||
|
}
|
||||||
|
|
||||||
|
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||||
38
tradingagents/dataflows/alpha_vantage_stock.py
Normal file
38
tradingagents/dataflows/alpha_vantage_stock.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
||||||
|
|
||||||
|
def get_stock(
|
||||||
|
symbol: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Returns raw daily OHLCV values, adjusted close values, and historical split/dividend events
|
||||||
|
filtered to the specified date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: The name of the equity. For example: symbol=IBM
|
||||||
|
start_date: Start date in yyyy-mm-dd format
|
||||||
|
end_date: End date in yyyy-mm-dd format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSV string containing the daily adjusted time series data filtered to the date range.
|
||||||
|
"""
|
||||||
|
# Parse dates to determine the range
|
||||||
|
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
today = datetime.now()
|
||||||
|
|
||||||
|
# Choose outputsize based on whether the requested range is within the latest 100 days
|
||||||
|
# Compact returns latest 100 data points, so check if start_date is recent enough
|
||||||
|
days_from_today_to_start = (today - start_dt).days
|
||||||
|
outputsize = "compact" if days_from_today_to_start < 100 else "full"
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"symbol": symbol,
|
||||||
|
"outputsize": outputsize,
|
||||||
|
"datatype": "csv",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||||
|
|
||||||
|
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||||
@@ -1,33 +1,41 @@
|
|||||||
import tradingagents.default_config as default_config
|
from copy import deepcopy
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
DATA_DIR: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
"""Initialize the configuration with default values."""
|
"""Initialize the configuration with default values."""
|
||||||
global _config, DATA_DIR
|
global _config
|
||||||
if _config is None:
|
if _config is None:
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
DATA_DIR = _config["data_dir"]
|
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: Dict):
|
||||||
"""Update the configuration with custom values."""
|
"""Update the configuration with custom values.
|
||||||
global _config, DATA_DIR
|
|
||||||
if _config is None:
|
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
partial update like ``{"data_vendors": {"core_stock_apis": "alpha_vantage"}}``
|
||||||
_config.update(config)
|
keeps the other nested keys from the default; scalar keys are replaced.
|
||||||
DATA_DIR = _config["data_dir"]
|
"""
|
||||||
|
global _config
|
||||||
|
initialize_config()
|
||||||
|
incoming = deepcopy(config)
|
||||||
|
for key, value in incoming.items():
|
||||||
|
if isinstance(value, dict) and isinstance(_config.get(key), dict):
|
||||||
|
_config[key].update(value)
|
||||||
|
else:
|
||||||
|
_config[key] = value
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> Dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
return _config.copy()
|
return deepcopy(_config)
|
||||||
|
|
||||||
|
|
||||||
# Initialize with default config
|
# Initialize with default config
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
|
|
||||||
"""
|
|
||||||
Gets finnhub data saved and processed on disk.
|
|
||||||
Args:
|
|
||||||
start_date (str): Start date in YYYY-MM-DD format.
|
|
||||||
end_date (str): End date in YYYY-MM-DD format.
|
|
||||||
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
|
|
||||||
data_dir (str): Directory where the data is saved.
|
|
||||||
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if period:
|
|
||||||
data_path = os.path.join(
|
|
||||||
data_dir,
|
|
||||||
"finnhub_data",
|
|
||||||
data_type,
|
|
||||||
f"{ticker}_{period}_data_formatted.json",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data_path = os.path.join(
|
|
||||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
data = open(data_path, "r")
|
|
||||||
data = json.load(data)
|
|
||||||
|
|
||||||
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
|
|
||||||
filtered_data = {}
|
|
||||||
for key, value in data.items():
|
|
||||||
if start_date <= key <= end_date and len(value) > 0:
|
|
||||||
filtered_data[key] = value
|
|
||||||
return filtered_data
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
import json
|
|
||||||
import requests
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from datetime import datetime
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
from tenacity import (
|
|
||||||
retry,
|
|
||||||
stop_after_attempt,
|
|
||||||
wait_exponential,
|
|
||||||
retry_if_exception_type,
|
|
||||||
retry_if_result,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_rate_limited(response):
|
|
||||||
"""Check if the response indicates rate limiting (status code 429)"""
|
|
||||||
return response.status_code == 429
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
retry=(retry_if_result(is_rate_limited)),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
||||||
stop=stop_after_attempt(5),
|
|
||||||
)
|
|
||||||
def make_request(url, headers):
|
|
||||||
"""Make a request with retry logic for rate limiting"""
|
|
||||||
# Random delay before each request to avoid detection
|
|
||||||
time.sleep(random.uniform(2, 6))
|
|
||||||
response = requests.get(url, headers=headers)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def getNewsData(query, start_date, end_date):
|
|
||||||
"""
|
|
||||||
Scrape Google News search results for a given query and date range.
|
|
||||||
query: str - search query
|
|
||||||
start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
|
|
||||||
end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
|
|
||||||
"""
|
|
||||||
if "-" in start_date:
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
start_date = start_date.strftime("%m/%d/%Y")
|
|
||||||
if "-" in end_date:
|
|
||||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
|
||||||
end_date = end_date.strftime("%m/%d/%Y")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"User-Agent": (
|
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
|
||||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
|
||||||
"Chrome/101.0.4951.54 Safari/537.36"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
news_results = []
|
|
||||||
page = 0
|
|
||||||
while True:
|
|
||||||
offset = page * 10
|
|
||||||
url = (
|
|
||||||
f"https://www.google.com/search?q={query}"
|
|
||||||
f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}"
|
|
||||||
f"&tbm=nws&start={offset}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = make_request(url, headers)
|
|
||||||
soup = BeautifulSoup(response.content, "html.parser")
|
|
||||||
results_on_page = soup.select("div.SoaBEf")
|
|
||||||
|
|
||||||
if not results_on_page:
|
|
||||||
break # No more results found
|
|
||||||
|
|
||||||
for el in results_on_page:
|
|
||||||
try:
|
|
||||||
link = el.find("a")["href"]
|
|
||||||
title = el.select_one("div.MBeuO").get_text()
|
|
||||||
snippet = el.select_one(".GI74Re").get_text()
|
|
||||||
date = el.select_one(".LfVVr").get_text()
|
|
||||||
source = el.select_one(".NUnG9d span").get_text()
|
|
||||||
news_results.append(
|
|
||||||
{
|
|
||||||
"link": link,
|
|
||||||
"title": title,
|
|
||||||
"snippet": snippet,
|
|
||||||
"date": date,
|
|
||||||
"source": source,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing result: {e}")
|
|
||||||
# If one of the fields is not found, skip this result
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Update the progress bar with the current count of results scraped
|
|
||||||
|
|
||||||
# Check for the "Next" link (pagination)
|
|
||||||
next_link = soup.find("a", id="pnnext")
|
|
||||||
if not next_link:
|
|
||||||
break
|
|
||||||
|
|
||||||
page += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed after multiple retries: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
return news_results
|
|
||||||
@@ -1,804 +1,162 @@
|
|||||||
from typing import Annotated, Dict
|
from typing import Annotated
|
||||||
from .reddit_utils import fetch_top_from_category
|
|
||||||
from .yfin_utils import *
|
# Import from vendor-specific modules
|
||||||
from .stockstats_utils import *
|
from .y_finance import (
|
||||||
from .googlenews_utils import *
|
get_YFin_data_online,
|
||||||
from .finnhub_utils import get_data_in_range
|
get_stock_stats_indicators_window,
|
||||||
from dateutil.relativedelta import relativedelta
|
get_fundamentals as get_yfinance_fundamentals,
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
get_balance_sheet as get_yfinance_balance_sheet,
|
||||||
from datetime import datetime
|
get_cashflow as get_yfinance_cashflow,
|
||||||
import json
|
get_income_statement as get_yfinance_income_statement,
|
||||||
import os
|
get_insider_transactions as get_yfinance_insider_transactions,
|
||||||
import pandas as pd
|
)
|
||||||
from tqdm import tqdm
|
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
||||||
import yfinance as yf
|
from .alpha_vantage import (
|
||||||
from openai import OpenAI
|
get_stock as get_alpha_vantage_stock,
|
||||||
from .config import get_config, set_config, DATA_DIR
|
get_indicator as get_alpha_vantage_indicator,
|
||||||
|
get_fundamentals as get_alpha_vantage_fundamentals,
|
||||||
|
get_balance_sheet as get_alpha_vantage_balance_sheet,
|
||||||
def get_finnhub_news(
|
get_cashflow as get_alpha_vantage_cashflow,
|
||||||
ticker: Annotated[
|
get_income_statement as get_alpha_vantage_income_statement,
|
||||||
str,
|
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
||||||
"Search query of a company's, e.g. 'AAPL, TSM, etc.",
|
get_news as get_alpha_vantage_news,
|
||||||
],
|
get_global_news as get_alpha_vantage_global_news,
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
)
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
from .alpha_vantage_common import AlphaVantageRateLimitError
|
||||||
):
|
|
||||||
"""
|
# Configuration and routing logic
|
||||||
Retrieve news about a company within a time frame
|
from .config import get_config
|
||||||
|
|
||||||
Args
|
# Tools organized by category
|
||||||
ticker (str): ticker for the company you are interested in
|
TOOLS_CATEGORIES = {
|
||||||
start_date (str): Start date in yyyy-mm-dd format
|
"core_stock_apis": {
|
||||||
end_date (str): End date in yyyy-mm-dd format
|
"description": "OHLCV stock price data",
|
||||||
Returns
|
"tools": [
|
||||||
str: dataframe containing the news of the company in the time frame
|
"get_stock_data"
|
||||||
|
]
|
||||||
"""
|
},
|
||||||
|
"technical_indicators": {
|
||||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
"description": "Technical analysis indicators",
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
"tools": [
|
||||||
before = before.strftime("%Y-%m-%d")
|
"get_indicators"
|
||||||
|
]
|
||||||
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
|
},
|
||||||
|
"fundamental_data": {
|
||||||
if len(result) == 0:
|
"description": "Company fundamentals",
|
||||||
return ""
|
"tools": [
|
||||||
|
"get_fundamentals",
|
||||||
combined_result = ""
|
"get_balance_sheet",
|
||||||
for day, data in result.items():
|
"get_cashflow",
|
||||||
if len(data) == 0:
|
"get_income_statement"
|
||||||
continue
|
]
|
||||||
for entry in data:
|
},
|
||||||
current_news = (
|
"news_data": {
|
||||||
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
|
"description": "News and insider data",
|
||||||
)
|
"tools": [
|
||||||
combined_result += current_news + "\n\n"
|
"get_news",
|
||||||
|
"get_global_news",
|
||||||
return f"## {ticker} News, from {before} to {curr_date}:\n" + str(combined_result)
|
"get_insider_transactions",
|
||||||
|
]
|
||||||
|
|
||||||
def get_finnhub_company_insider_sentiment(
|
|
||||||
ticker: Annotated[str, "ticker symbol for the company"],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str,
|
|
||||||
"current date of you are trading at, yyyy-mm-dd",
|
|
||||||
],
|
|
||||||
look_back_days: Annotated[int, "number of days to look back"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
curr_date (str): current date you are trading on, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the sentiment in the past 15 days starting at curr_date
|
|
||||||
"""
|
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
|
|
||||||
|
|
||||||
if len(data) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
result_str = ""
|
|
||||||
seen_dicts = []
|
|
||||||
for date, senti_list in data.items():
|
|
||||||
for entry in senti_list:
|
|
||||||
if entry not in seen_dicts:
|
|
||||||
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
|
|
||||||
seen_dicts.append(entry)
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n"
|
|
||||||
+ result_str
|
|
||||||
+ "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_finnhub_company_insider_transactions(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str,
|
|
||||||
"current date you are trading at, yyyy-mm-dd",
|
|
||||||
],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days
|
|
||||||
Args:
|
|
||||||
ticker (str): ticker symbol of the company
|
|
||||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
|
||||||
Returns:
|
|
||||||
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
|
||||||
"""
|
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
|
|
||||||
|
|
||||||
if len(data) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
result_str = ""
|
|
||||||
|
|
||||||
seen_dicts = []
|
|
||||||
for date, senti_list in data.items():
|
|
||||||
for entry in senti_list:
|
|
||||||
if entry not in seen_dicts:
|
|
||||||
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
|
|
||||||
seen_dicts.append(entry)
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## {ticker} insider transactions from {before} to {curr_date}:\n"
|
|
||||||
+ result_str
|
|
||||||
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_simfin_balance_sheet(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual / quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
data_path = os.path.join(
|
|
||||||
DATA_DIR,
|
|
||||||
"fundamental_data",
|
|
||||||
"simfin_data_all",
|
|
||||||
"balance_sheet",
|
|
||||||
"companies",
|
|
||||||
"us",
|
|
||||||
f"us-balance-{freq}.csv",
|
|
||||||
)
|
|
||||||
df = pd.read_csv(data_path, sep=";")
|
|
||||||
|
|
||||||
# Convert date strings to datetime objects and remove any time components
|
|
||||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
|
||||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
|
||||||
|
|
||||||
# Convert the current date to datetime and normalize
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
|
||||||
|
|
||||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
|
||||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
|
||||||
if filtered_df.empty:
|
|
||||||
print("No balance sheet available before the given current date.")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Get the most recent balance sheet by selecting the row with the latest Publish Date
|
|
||||||
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
|
||||||
|
|
||||||
# drop the SimFinID column
|
|
||||||
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n"
|
|
||||||
+ str(latest_balance_sheet)
|
|
||||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_simfin_cashflow(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual / quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
data_path = os.path.join(
|
|
||||||
DATA_DIR,
|
|
||||||
"fundamental_data",
|
|
||||||
"simfin_data_all",
|
|
||||||
"cash_flow",
|
|
||||||
"companies",
|
|
||||||
"us",
|
|
||||||
f"us-cashflow-{freq}.csv",
|
|
||||||
)
|
|
||||||
df = pd.read_csv(data_path, sep=";")
|
|
||||||
|
|
||||||
# Convert date strings to datetime objects and remove any time components
|
|
||||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
|
||||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
|
||||||
|
|
||||||
# Convert the current date to datetime and normalize
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
|
||||||
|
|
||||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
|
||||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
|
||||||
if filtered_df.empty:
|
|
||||||
print("No cash flow statement available before the given current date.")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
|
|
||||||
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
|
||||||
|
|
||||||
# drop the SimFinID column
|
|
||||||
latest_cash_flow = latest_cash_flow.drop("SimFinId")
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n"
|
|
||||||
+ str(latest_cash_flow)
|
|
||||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_simfin_income_statements(
|
|
||||||
ticker: Annotated[str, "ticker symbol"],
|
|
||||||
freq: Annotated[
|
|
||||||
str,
|
|
||||||
"reporting frequency of the company's financial history: annual / quarterly",
|
|
||||||
],
|
|
||||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
|
||||||
):
|
|
||||||
data_path = os.path.join(
|
|
||||||
DATA_DIR,
|
|
||||||
"fundamental_data",
|
|
||||||
"simfin_data_all",
|
|
||||||
"income_statements",
|
|
||||||
"companies",
|
|
||||||
"us",
|
|
||||||
f"us-income-{freq}.csv",
|
|
||||||
)
|
|
||||||
df = pd.read_csv(data_path, sep=";")
|
|
||||||
|
|
||||||
# Convert date strings to datetime objects and remove any time components
|
|
||||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
|
||||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
|
||||||
|
|
||||||
# Convert the current date to datetime and normalize
|
|
||||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
|
||||||
|
|
||||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
|
||||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
|
||||||
|
|
||||||
# Check if there are any available reports; if not, return a notification
|
|
||||||
if filtered_df.empty:
|
|
||||||
print("No income statement available before the given current date.")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Get the most recent income statement by selecting the row with the latest Publish Date
|
|
||||||
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
|
||||||
|
|
||||||
# drop the SimFinID column
|
|
||||||
latest_income = latest_income.drop("SimFinId")
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n"
|
|
||||||
+ str(latest_income)
|
|
||||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_google_news(
|
|
||||||
query: Annotated[str, "Query to search with"],
|
|
||||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
) -> str:
|
|
||||||
query = query.replace(" ", "+")
|
|
||||||
|
|
||||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
news_results = getNewsData(query, before, curr_date)
|
|
||||||
|
|
||||||
news_str = ""
|
|
||||||
|
|
||||||
for news in news_results:
|
|
||||||
news_str += (
|
|
||||||
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(news_results) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_reddit_global_news(
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the latest top reddit news
|
|
||||||
Args:
|
|
||||||
start_date: Start date in yyyy-mm-dd format
|
|
||||||
end_date: End date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
|
||||||
"""
|
|
||||||
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
posts = []
|
|
||||||
# iterate from start_date to end_date
|
|
||||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
|
||||||
|
|
||||||
total_iterations = (start_date - curr_date).days + 1
|
|
||||||
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
|
|
||||||
|
|
||||||
while curr_date <= start_date:
|
|
||||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
|
||||||
fetch_result = fetch_top_from_category(
|
|
||||||
"global_news",
|
|
||||||
curr_date_str,
|
|
||||||
max_limit_per_day,
|
|
||||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
|
||||||
)
|
|
||||||
posts.extend(fetch_result)
|
|
||||||
curr_date += relativedelta(days=1)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
pbar.close()
|
|
||||||
|
|
||||||
if len(posts) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
news_str = ""
|
|
||||||
for post in posts:
|
|
||||||
if post["content"] == "":
|
|
||||||
news_str += f"### {post['title']}\n\n"
|
|
||||||
else:
|
|
||||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
|
||||||
|
|
||||||
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_reddit_company_news(
|
|
||||||
ticker: Annotated[str, "ticker symbol of the company"],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Retrieve the latest top reddit news
|
|
||||||
Args:
|
|
||||||
ticker: ticker symbol of the company
|
|
||||||
start_date: Start date in yyyy-mm-dd format
|
|
||||||
end_date: End date in yyyy-mm-dd format
|
|
||||||
Returns:
|
|
||||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
|
||||||
"""
|
|
||||||
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
posts = []
|
|
||||||
# iterate from start_date to end_date
|
|
||||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
|
||||||
|
|
||||||
total_iterations = (start_date - curr_date).days + 1
|
|
||||||
pbar = tqdm(
|
|
||||||
desc=f"Getting Company News for {ticker} on {start_date}",
|
|
||||||
total=total_iterations,
|
|
||||||
)
|
|
||||||
|
|
||||||
while curr_date <= start_date:
|
|
||||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
|
||||||
fetch_result = fetch_top_from_category(
|
|
||||||
"company_news",
|
|
||||||
curr_date_str,
|
|
||||||
max_limit_per_day,
|
|
||||||
ticker,
|
|
||||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
|
||||||
)
|
|
||||||
posts.extend(fetch_result)
|
|
||||||
curr_date += relativedelta(days=1)
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
pbar.close()
|
|
||||||
|
|
||||||
if len(posts) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
news_str = ""
|
|
||||||
for post in posts:
|
|
||||||
if post["content"] == "":
|
|
||||||
news_str += f"### {post['title']}\n\n"
|
|
||||||
else:
|
|
||||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
|
||||||
|
|
||||||
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_stock_stats_indicators_window(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
|
||||||
curr_date: Annotated[
|
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
|
||||||
) -> str:
|
|
||||||
|
|
||||||
best_ind_params = {
|
|
||||||
# Moving Averages
|
|
||||||
"close_50_sma": (
|
|
||||||
"50 SMA: A medium-term trend indicator. "
|
|
||||||
"Usage: Identify trend direction and serve as dynamic support/resistance. "
|
|
||||||
"Tips: It lags price; combine with faster indicators for timely signals."
|
|
||||||
),
|
|
||||||
"close_200_sma": (
|
|
||||||
"200 SMA: A long-term trend benchmark. "
|
|
||||||
"Usage: Confirm overall market trend and identify golden/death cross setups. "
|
|
||||||
"Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries."
|
|
||||||
),
|
|
||||||
"close_10_ema": (
|
|
||||||
"10 EMA: A responsive short-term average. "
|
|
||||||
"Usage: Capture quick shifts in momentum and potential entry points. "
|
|
||||||
"Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
|
|
||||||
),
|
|
||||||
# MACD Related
|
|
||||||
"macd": (
|
|
||||||
"MACD: Computes momentum via differences of EMAs. "
|
|
||||||
"Usage: Look for crossovers and divergence as signals of trend changes. "
|
|
||||||
"Tips: Confirm with other indicators in low-volatility or sideways markets."
|
|
||||||
),
|
|
||||||
"macds": (
|
|
||||||
"MACD Signal: An EMA smoothing of the MACD line. "
|
|
||||||
"Usage: Use crossovers with the MACD line to trigger trades. "
|
|
||||||
"Tips: Should be part of a broader strategy to avoid false positives."
|
|
||||||
),
|
|
||||||
"macdh": (
|
|
||||||
"MACD Histogram: Shows the gap between the MACD line and its signal. "
|
|
||||||
"Usage: Visualize momentum strength and spot divergence early. "
|
|
||||||
"Tips: Can be volatile; complement with additional filters in fast-moving markets."
|
|
||||||
),
|
|
||||||
# Momentum Indicators
|
|
||||||
"rsi": (
|
|
||||||
"RSI: Measures momentum to flag overbought/oversold conditions. "
|
|
||||||
"Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
|
|
||||||
"Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
|
|
||||||
),
|
|
||||||
# Volatility Indicators
|
|
||||||
"boll": (
|
|
||||||
"Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
|
|
||||||
"Usage: Acts as a dynamic benchmark for price movement. "
|
|
||||||
"Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals."
|
|
||||||
),
|
|
||||||
"boll_ub": (
|
|
||||||
"Bollinger Upper Band: Typically 2 standard deviations above the middle line. "
|
|
||||||
"Usage: Signals potential overbought conditions and breakout zones. "
|
|
||||||
"Tips: Confirm signals with other tools; prices may ride the band in strong trends."
|
|
||||||
),
|
|
||||||
"boll_lb": (
|
|
||||||
"Bollinger Lower Band: Typically 2 standard deviations below the middle line. "
|
|
||||||
"Usage: Indicates potential oversold conditions. "
|
|
||||||
"Tips: Use additional analysis to avoid false reversal signals."
|
|
||||||
),
|
|
||||||
"atr": (
|
|
||||||
"ATR: Averages true range to measure volatility. "
|
|
||||||
"Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
|
|
||||||
"Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
|
|
||||||
),
|
|
||||||
# Volume-Based Indicators
|
|
||||||
"vwma": (
|
|
||||||
"VWMA: A moving average weighted by volume. "
|
|
||||||
"Usage: Confirm trends by integrating price action with volume data. "
|
|
||||||
"Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
|
|
||||||
),
|
|
||||||
"mfi": (
|
|
||||||
"MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. "
|
|
||||||
"Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. "
|
|
||||||
"Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals."
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if indicator not in best_ind_params:
|
VENDOR_LIST = [
|
||||||
raise ValueError(
|
"yfinance",
|
||||||
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
|
"alpha_vantage",
|
||||||
)
|
]
|
||||||
|
|
||||||
end_date = curr_date
|
# Mapping of methods to their vendor-specific implementations
|
||||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
VENDOR_METHODS = {
|
||||||
before = curr_date - relativedelta(days=look_back_days)
|
# core_stock_apis
|
||||||
|
"get_stock_data": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_stock,
|
||||||
|
"yfinance": get_YFin_data_online,
|
||||||
|
},
|
||||||
|
# technical_indicators
|
||||||
|
"get_indicators": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_indicator,
|
||||||
|
"yfinance": get_stock_stats_indicators_window,
|
||||||
|
},
|
||||||
|
# fundamental_data
|
||||||
|
"get_fundamentals": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_fundamentals,
|
||||||
|
"yfinance": get_yfinance_fundamentals,
|
||||||
|
},
|
||||||
|
"get_balance_sheet": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_balance_sheet,
|
||||||
|
"yfinance": get_yfinance_balance_sheet,
|
||||||
|
},
|
||||||
|
"get_cashflow": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_cashflow,
|
||||||
|
"yfinance": get_yfinance_cashflow,
|
||||||
|
},
|
||||||
|
"get_income_statement": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_income_statement,
|
||||||
|
"yfinance": get_yfinance_income_statement,
|
||||||
|
},
|
||||||
|
# news_data
|
||||||
|
"get_news": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_news,
|
||||||
|
"yfinance": get_news_yfinance,
|
||||||
|
},
|
||||||
|
"get_global_news": {
|
||||||
|
"yfinance": get_global_news_yfinance,
|
||||||
|
"alpha_vantage": get_alpha_vantage_global_news,
|
||||||
|
},
|
||||||
|
"get_insider_transactions": {
|
||||||
|
"alpha_vantage": get_alpha_vantage_insider_transactions,
|
||||||
|
"yfinance": get_yfinance_insider_transactions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if not online:
|
def get_category_for_method(method: str) -> str:
|
||||||
# read from YFin data
|
"""Get the category that contains the specified method."""
|
||||||
data = pd.read_csv(
|
for category, info in TOOLS_CATEGORIES.items():
|
||||||
os.path.join(
|
if method in info["tools"]:
|
||||||
DATA_DIR,
|
return category
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
raise ValueError(f"Method '{method}' not found in any category")
|
||||||
)
|
|
||||||
)
|
|
||||||
data["Date"] = pd.to_datetime(data["Date"], utc=True)
|
|
||||||
dates_in_df = data["Date"].astype(str).str[:10]
|
|
||||||
|
|
||||||
ind_string = ""
|
def get_vendor(category: str, method: str = None) -> str:
|
||||||
while curr_date >= before:
|
"""Get the configured vendor for a data category or specific tool method.
|
||||||
# only do the trading dates
|
Tool-level configuration takes precedence over category-level.
|
||||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
"""
|
||||||
indicator_value = get_stockstats_indicator(
|
config = get_config()
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
|
||||||
)
|
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
# Check tool-level configuration first (if method provided)
|
||||||
|
if method:
|
||||||
|
tool_vendors = config.get("tool_vendors", {})
|
||||||
|
if method in tool_vendors:
|
||||||
|
return tool_vendors[method]
|
||||||
|
|
||||||
curr_date = curr_date - relativedelta(days=1)
|
# Fall back to category-level configuration
|
||||||
else:
|
return config.get("data_vendors", {}).get(category, "default")
|
||||||
# online gathering
|
|
||||||
ind_string = ""
|
|
||||||
while curr_date >= before:
|
|
||||||
indicator_value = get_stockstats_indicator(
|
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
|
||||||
)
|
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
def route_to_vendor(method: str, *args, **kwargs):
|
||||||
|
"""Route method calls to appropriate vendor implementation with fallback support."""
|
||||||
|
category = get_category_for_method(method)
|
||||||
|
vendor_config = get_vendor(category, method)
|
||||||
|
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
||||||
|
|
||||||
curr_date = curr_date - relativedelta(days=1)
|
if method not in VENDOR_METHODS:
|
||||||
|
raise ValueError(f"Method '{method}' not supported")
|
||||||
|
|
||||||
result_str = (
|
# Build fallback chain: primary vendors first, then remaining available vendors
|
||||||
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
all_available_vendors = list(VENDOR_METHODS[method].keys())
|
||||||
+ ind_string
|
fallback_vendors = primary_vendors.copy()
|
||||||
+ "\n\n"
|
for vendor in all_available_vendors:
|
||||||
+ best_ind_params.get(indicator, "No description available.")
|
if vendor not in fallback_vendors:
|
||||||
)
|
fallback_vendors.append(vendor)
|
||||||
|
|
||||||
return result_str
|
for vendor in fallback_vendors:
|
||||||
|
if vendor not in VENDOR_METHODS[method]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vendor_impl = VENDOR_METHODS[method][vendor]
|
||||||
|
impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
|
||||||
|
|
||||||
def get_stockstats_indicator(
|
try:
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
return impl_func(*args, **kwargs)
|
||||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
except AlphaVantageRateLimitError:
|
||||||
curr_date: Annotated[
|
continue # Only rate limits trigger fallback
|
||||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
|
||||||
) -> str:
|
|
||||||
|
|
||||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
raise RuntimeError(f"No available vendor for '{method}'")
|
||||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
try:
|
|
||||||
indicator_value = StockstatsUtils.get_stock_stats(
|
|
||||||
symbol,
|
|
||||||
indicator,
|
|
||||||
curr_date,
|
|
||||||
os.path.join(DATA_DIR, "market_data", "price_data"),
|
|
||||||
online=online,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(
|
|
||||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return str(indicator_value)
|
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data_window(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
|
||||||
) -> str:
|
|
||||||
# calculate past days
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
start_date = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# read in data
|
|
||||||
data = pd.read_csv(
|
|
||||||
os.path.join(
|
|
||||||
DATA_DIR,
|
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract just the date part for comparison
|
|
||||||
data["DateOnly"] = data["Date"].str[:10]
|
|
||||||
|
|
||||||
# Filter data between the start and end dates (inclusive)
|
|
||||||
filtered_data = data[
|
|
||||||
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Drop the temporary column we created
|
|
||||||
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
|
||||||
|
|
||||||
# Set pandas display options to show the full DataFrame
|
|
||||||
with pd.option_context(
|
|
||||||
"display.max_rows", None, "display.max_columns", None, "display.width", None
|
|
||||||
):
|
|
||||||
df_string = filtered_data.to_string()
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
|
|
||||||
+ df_string
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data_online(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
):
|
|
||||||
|
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
datetime.strptime(end_date, "%Y-%m-%d")
|
|
||||||
|
|
||||||
# Create ticker object
|
|
||||||
ticker = yf.Ticker(symbol.upper())
|
|
||||||
|
|
||||||
# Fetch historical data for the specified date range
|
|
||||||
data = ticker.history(start=start_date, end=end_date)
|
|
||||||
|
|
||||||
# Check if data is empty
|
|
||||||
if data.empty:
|
|
||||||
return (
|
|
||||||
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove timezone info from index for cleaner output
|
|
||||||
if data.index.tz is not None:
|
|
||||||
data.index = data.index.tz_localize(None)
|
|
||||||
|
|
||||||
# Round numerical values to 2 decimal places for cleaner display
|
|
||||||
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
|
||||||
for col in numeric_columns:
|
|
||||||
if col in data.columns:
|
|
||||||
data[col] = data[col].round(2)
|
|
||||||
|
|
||||||
# Convert DataFrame to CSV string
|
|
||||||
csv_string = data.to_csv()
|
|
||||||
|
|
||||||
# Add header information
|
|
||||||
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
|
||||||
header += f"# Total records: {len(data)}\n"
|
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
|
||||||
|
|
||||||
return header + csv_string
|
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data(
|
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
end_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
|
||||||
) -> str:
|
|
||||||
# read in data
|
|
||||||
data = pd.read_csv(
|
|
||||||
os.path.join(
|
|
||||||
DATA_DIR,
|
|
||||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if end_date > "2025-03-25":
|
|
||||||
raise Exception(
|
|
||||||
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract just the date part for comparison
|
|
||||||
data["DateOnly"] = data["Date"].str[:10]
|
|
||||||
|
|
||||||
# Filter data between the start and end dates (inclusive)
|
|
||||||
filtered_data = data[
|
|
||||||
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Drop the temporary column we created
|
|
||||||
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
|
||||||
|
|
||||||
# remove the index from the dataframe
|
|
||||||
filtered_data = filtered_data.reset_index(drop=True)
|
|
||||||
|
|
||||||
return filtered_data
|
|
||||||
|
|
||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
|
||||||
client = OpenAI()
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model="gpt-4.1-mini",
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search Social Media for {ticker} on TSLA from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_news_openai(curr_date):
|
|
||||||
client = OpenAI()
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model="gpt-4.1-mini",
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
|
|
||||||
|
|
||||||
def get_fundamentals_openai(ticker, curr_date):
|
|
||||||
client = OpenAI()
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model="gpt-4.1-mini",
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
106
tradingagents/dataflows/reddit.py
Normal file
106
tradingagents/dataflows/reddit.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""Reddit search fetcher for ticker-specific discussion posts.
|
||||||
|
|
||||||
|
Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``)
|
||||||
|
which do not require an API key. Public throughput is ~10 requests per
|
||||||
|
minute per IP, well within budget for a single agent run that queries
|
||||||
|
a handful of finance subreddits per ticker.
|
||||||
|
|
||||||
|
Returns formatted plaintext blocks ready for prompt injection. Degrades
|
||||||
|
gracefully — returns a placeholder string rather than raising, so callers
|
||||||
|
never have to special-case missing data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Iterable
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
|
||||||
|
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
||||||
|
|
||||||
|
# Default subreddits ordered roughly by signal density for ticker-specific
|
||||||
|
# discussion. wallstreetbets has the most volume but most noise; stocks /
|
||||||
|
# investing trend more measured. Caller can override.
|
||||||
|
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_subreddit(
|
||||||
|
ticker: str,
|
||||||
|
sub: str,
|
||||||
|
limit: int,
|
||||||
|
timeout: float,
|
||||||
|
) -> list[dict]:
|
||||||
|
qs = urlencode({
|
||||||
|
"q": ticker,
|
||||||
|
"restrict_sr": "on",
|
||||||
|
"sort": "new",
|
||||||
|
"t": "week", # last 7 days
|
||||||
|
"limit": limit,
|
||||||
|
})
|
||||||
|
url = _API.format(sub=sub, qs=qs)
|
||||||
|
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
payload = json.loads(resp.read())
|
||||||
|
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
||||||
|
logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc)
|
||||||
|
return []
|
||||||
|
children = (payload.get("data") or {}).get("children") or []
|
||||||
|
return [c.get("data", {}) for c in children if isinstance(c, dict)]
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_reddit_posts(
|
||||||
|
ticker: str,
|
||||||
|
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
|
||||||
|
limit_per_sub: int = 5,
|
||||||
|
timeout: float = 10.0,
|
||||||
|
inter_request_delay: float = 0.4,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
|
||||||
|
subreddits and return them as a formatted plaintext block.
|
||||||
|
|
||||||
|
``inter_request_delay`` keeps us under Reddit's public rate limit
|
||||||
|
(~10 req/min per IP) even if the caller queries many subreddits.
|
||||||
|
"""
|
||||||
|
blocks = []
|
||||||
|
total_posts = 0
|
||||||
|
for i, sub in enumerate(subreddits):
|
||||||
|
if i > 0:
|
||||||
|
time.sleep(inter_request_delay)
|
||||||
|
posts = _fetch_subreddit(ticker, sub, limit_per_sub, timeout)
|
||||||
|
total_posts += len(posts)
|
||||||
|
if not posts:
|
||||||
|
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = [f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}:"]
|
||||||
|
for p in posts:
|
||||||
|
title = (p.get("title") or "").replace("\n", " ").strip()
|
||||||
|
score = p.get("score", 0)
|
||||||
|
comments = p.get("num_comments", 0)
|
||||||
|
created = p.get("created_utc")
|
||||||
|
created_str = (
|
||||||
|
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
|
||||||
|
)
|
||||||
|
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
|
||||||
|
if len(selftext) > 240:
|
||||||
|
selftext = selftext[:240] + "…"
|
||||||
|
lines.append(
|
||||||
|
f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}"
|
||||||
|
+ (f"\n body excerpt: {selftext}" if selftext else "")
|
||||||
|
)
|
||||||
|
blocks.append("\n".join(lines))
|
||||||
|
|
||||||
|
if total_posts == 0:
|
||||||
|
return (
|
||||||
|
f"<no Reddit posts found mentioning {ticker.upper()} across "
|
||||||
|
f"{', '.join(f'r/{s}' for s in subreddits)} in the past 7 days>"
|
||||||
|
)
|
||||||
|
return "\n\n".join(blocks)
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
import requests
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Annotated
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
ticker_to_company = {
|
|
||||||
"AAPL": "Apple",
|
|
||||||
"MSFT": "Microsoft",
|
|
||||||
"GOOGL": "Google",
|
|
||||||
"AMZN": "Amazon",
|
|
||||||
"TSLA": "Tesla",
|
|
||||||
"NVDA": "Nvidia",
|
|
||||||
"TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
|
|
||||||
"JPM": "JPMorgan Chase OR JP Morgan",
|
|
||||||
"JNJ": "Johnson & Johnson OR JNJ",
|
|
||||||
"V": "Visa",
|
|
||||||
"WMT": "Walmart",
|
|
||||||
"META": "Meta OR Facebook",
|
|
||||||
"AMD": "AMD",
|
|
||||||
"INTC": "Intel",
|
|
||||||
"QCOM": "Qualcomm",
|
|
||||||
"BABA": "Alibaba",
|
|
||||||
"ADBE": "Adobe",
|
|
||||||
"NFLX": "Netflix",
|
|
||||||
"CRM": "Salesforce",
|
|
||||||
"PYPL": "PayPal",
|
|
||||||
"PLTR": "Palantir",
|
|
||||||
"MU": "Micron",
|
|
||||||
"SQ": "Block OR Square",
|
|
||||||
"ZM": "Zoom",
|
|
||||||
"CSCO": "Cisco",
|
|
||||||
"SHOP": "Shopify",
|
|
||||||
"ORCL": "Oracle",
|
|
||||||
"X": "Twitter OR X",
|
|
||||||
"SPOT": "Spotify",
|
|
||||||
"AVGO": "Broadcom",
|
|
||||||
"ASML": "ASML ",
|
|
||||||
"TWLO": "Twilio",
|
|
||||||
"SNAP": "Snap Inc.",
|
|
||||||
"TEAM": "Atlassian",
|
|
||||||
"SQSP": "Squarespace",
|
|
||||||
"UBER": "Uber",
|
|
||||||
"ROKU": "Roku",
|
|
||||||
"PINS": "Pinterest",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_top_from_category(
|
|
||||||
category: Annotated[
|
|
||||||
str, "Category to fetch top post from. Collection of subreddits."
|
|
||||||
],
|
|
||||||
date: Annotated[str, "Date to fetch top posts from."],
|
|
||||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
|
||||||
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
|
|
||||||
data_path: Annotated[
|
|
||||||
str,
|
|
||||||
"Path to the data folder. Default is 'reddit_data'.",
|
|
||||||
] = "reddit_data",
|
|
||||||
):
|
|
||||||
base_path = data_path
|
|
||||||
|
|
||||||
all_content = []
|
|
||||||
|
|
||||||
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
|
||||||
raise ValueError(
|
|
||||||
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
|
||||||
)
|
|
||||||
|
|
||||||
limit_per_subreddit = max_limit // len(
|
|
||||||
os.listdir(os.path.join(base_path, category))
|
|
||||||
)
|
|
||||||
|
|
||||||
for data_file in os.listdir(os.path.join(base_path, category)):
|
|
||||||
# check if data_file is a .jsonl file
|
|
||||||
if not data_file.endswith(".jsonl"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
all_content_curr_subreddit = []
|
|
||||||
|
|
||||||
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
# skip empty lines
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
parsed_line = json.loads(line)
|
|
||||||
|
|
||||||
# select only lines that are from the date
|
|
||||||
post_date = datetime.utcfromtimestamp(
|
|
||||||
parsed_line["created_utc"]
|
|
||||||
).strftime("%Y-%m-%d")
|
|
||||||
if post_date != date:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if is company_news, check that the title or the content has the company's name (query) mentioned
|
|
||||||
if "company" in category and query:
|
|
||||||
search_terms = []
|
|
||||||
if "OR" in ticker_to_company[query]:
|
|
||||||
search_terms = ticker_to_company[query].split(" OR ")
|
|
||||||
else:
|
|
||||||
search_terms = [ticker_to_company[query]]
|
|
||||||
|
|
||||||
search_terms.append(query)
|
|
||||||
|
|
||||||
found = False
|
|
||||||
for term in search_terms:
|
|
||||||
if re.search(
|
|
||||||
term, parsed_line["title"], re.IGNORECASE
|
|
||||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
continue
|
|
||||||
|
|
||||||
post = {
|
|
||||||
"title": parsed_line["title"],
|
|
||||||
"content": parsed_line["selftext"],
|
|
||||||
"url": parsed_line["url"],
|
|
||||||
"upvotes": parsed_line["ups"],
|
|
||||||
"posted_date": post_date,
|
|
||||||
}
|
|
||||||
|
|
||||||
all_content_curr_subreddit.append(post)
|
|
||||||
|
|
||||||
# sort all_content_curr_subreddit by upvote_ratio in descending order
|
|
||||||
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
|
|
||||||
|
|
||||||
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
|
|
||||||
|
|
||||||
return all_content
|
|
||||||
@@ -1,9 +1,110 @@
|
|||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
|
from yfinance.exceptions import YFRateLimitError
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
import os
|
import os
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
from .utils import safe_ticker_component
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def yf_retry(func, max_retries=3, base_delay=2.0):
|
||||||
|
"""Execute a yfinance call with exponential backoff on rate limits.
|
||||||
|
|
||||||
|
yfinance raises YFRateLimitError on HTTP 429 responses but does not
|
||||||
|
retry them internally. This wrapper adds retry logic specifically
|
||||||
|
for rate limits. Other exceptions propagate immediately.
|
||||||
|
"""
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except YFRateLimitError:
|
||||||
|
if attempt < max_retries:
|
||||||
|
delay = base_delay * (2 ** attempt)
|
||||||
|
logger.warning(f"Yahoo Finance rate limited, retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})")
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps."""
|
||||||
|
data["Date"] = pd.to_datetime(data["Date"], errors="coerce")
|
||||||
|
data = data.dropna(subset=["Date"])
|
||||||
|
|
||||||
|
price_cols = [c for c in ["Open", "High", "Low", "Close", "Volume"] if c in data.columns]
|
||||||
|
data[price_cols] = data[price_cols].apply(pd.to_numeric, errors="coerce")
|
||||||
|
data = data.dropna(subset=["Close"])
|
||||||
|
data[price_cols] = data[price_cols].ffill().bfill()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||||
|
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
|
||||||
|
|
||||||
|
Downloads 15 years of data up to today and caches per symbol. On
|
||||||
|
subsequent calls the cache is reused. Rows after curr_date are
|
||||||
|
filtered out so backtests never see future prices.
|
||||||
|
"""
|
||||||
|
# Reject ticker values that would escape the cache directory when
|
||||||
|
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
||||||
|
safe_symbol = safe_ticker_component(symbol)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
curr_date_dt = pd.to_datetime(curr_date)
|
||||||
|
|
||||||
|
# Cache uses a fixed window (15y to today) so one file per symbol
|
||||||
|
today_date = pd.Timestamp.today()
|
||||||
|
start_date = today_date - pd.DateOffset(years=5)
|
||||||
|
start_str = start_date.strftime("%Y-%m-%d")
|
||||||
|
end_str = today_date.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||||
|
data_file = os.path.join(
|
||||||
|
config["data_cache_dir"],
|
||||||
|
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(data_file):
|
||||||
|
data = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8")
|
||||||
|
else:
|
||||||
|
data = yf_retry(lambda: yf.download(
|
||||||
|
symbol,
|
||||||
|
start=start_str,
|
||||||
|
end=end_str,
|
||||||
|
multi_level_index=False,
|
||||||
|
progress=False,
|
||||||
|
auto_adjust=True,
|
||||||
|
))
|
||||||
|
data = data.reset_index()
|
||||||
|
data.to_csv(data_file, index=False, encoding="utf-8")
|
||||||
|
|
||||||
|
data = _clean_dataframe(data)
|
||||||
|
|
||||||
|
# Filter to curr_date to prevent look-ahead bias in backtesting
|
||||||
|
data = data[data["Date"] <= curr_date_dt]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def filter_financials_by_date(data: pd.DataFrame, curr_date: str) -> pd.DataFrame:
|
||||||
|
"""Drop financial statement columns (fiscal period timestamps) after curr_date.
|
||||||
|
|
||||||
|
yfinance financial statements use fiscal period end dates as columns.
|
||||||
|
Columns after curr_date represent future data and are removed to
|
||||||
|
prevent look-ahead bias.
|
||||||
|
"""
|
||||||
|
if not curr_date or data.empty:
|
||||||
|
return data
|
||||||
|
cutoff = pd.Timestamp(curr_date)
|
||||||
|
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
|
||||||
|
return data.loc[:, mask]
|
||||||
|
|
||||||
|
|
||||||
class StockstatsUtils:
|
class StockstatsUtils:
|
||||||
@@ -16,69 +117,14 @@ class StockstatsUtils:
|
|||||||
curr_date: Annotated[
|
curr_date: Annotated[
|
||||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||||
],
|
],
|
||||||
data_dir: Annotated[
|
|
||||||
str,
|
|
||||||
"directory where the stock data is stored.",
|
|
||||||
],
|
|
||||||
online: Annotated[
|
|
||||||
bool,
|
|
||||||
"whether to use online tools to fetch data or offline tools. If True, will use online tools.",
|
|
||||||
] = False,
|
|
||||||
):
|
):
|
||||||
df = None
|
data = load_ohlcv(symbol, curr_date)
|
||||||
data = None
|
df = wrap(data)
|
||||||
|
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||||
if not online:
|
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
|
||||||
try:
|
|
||||||
data = pd.read_csv(
|
|
||||||
os.path.join(
|
|
||||||
data_dir,
|
|
||||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
df = wrap(data)
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
|
||||||
else:
|
|
||||||
# Get today's date as YYYY-mm-dd to add to cache
|
|
||||||
today_date = pd.Timestamp.today()
|
|
||||||
curr_date = pd.to_datetime(curr_date)
|
|
||||||
|
|
||||||
end_date = today_date
|
|
||||||
start_date = today_date - pd.DateOffset(years=15)
|
|
||||||
start_date = start_date.strftime("%Y-%m-%d")
|
|
||||||
end_date = end_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Get config and ensure cache directory exists
|
|
||||||
config = get_config()
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
|
||||||
|
|
||||||
data_file = os.path.join(
|
|
||||||
config["data_cache_dir"],
|
|
||||||
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
|
||||||
data = pd.read_csv(data_file)
|
|
||||||
data["Date"] = pd.to_datetime(data["Date"])
|
|
||||||
else:
|
|
||||||
data = yf.download(
|
|
||||||
symbol,
|
|
||||||
start=start_date,
|
|
||||||
end=end_date,
|
|
||||||
multi_level_index=False,
|
|
||||||
progress=False,
|
|
||||||
auto_adjust=True,
|
|
||||||
)
|
|
||||||
data = data.reset_index()
|
|
||||||
data.to_csv(data_file, index=False)
|
|
||||||
|
|
||||||
df = wrap(data)
|
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
|
||||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
df[indicator] # trigger stockstats to calculate the indicator
|
df[indicator] # trigger stockstats to calculate the indicator
|
||||||
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
||||||
|
|
||||||
if not matching_rows.empty:
|
if not matching_rows.empty:
|
||||||
indicator_value = matching_rows[indicator].values[0]
|
indicator_value = matching_rows[indicator].values[0]
|
||||||
|
|||||||
83
tradingagents/dataflows/stocktwits.py
Normal file
83
tradingagents/dataflows/stocktwits.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""StockTwits public symbol-stream fetcher.
|
||||||
|
|
||||||
|
StockTwits exposes a per-symbol message stream at
|
||||||
|
``api.stocktwits.com/api/2/streams/symbol/{ticker}.json`` that requires no
|
||||||
|
API key, no OAuth, and no registration. Each message includes a
|
||||||
|
user-labeled sentiment field (``Bullish``/``Bearish``/null), the message
|
||||||
|
body, timestamp, and posting user.
|
||||||
|
|
||||||
|
The function is deliberately self-contained: short timeout, graceful
|
||||||
|
degradation on any HTTP or parse failure, and a string return type so
|
||||||
|
the calling agent gets a uniform interface regardless of whether the
|
||||||
|
network call succeeded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_API = "https://api.stocktwits.com/api/2/streams/symbol/{ticker}.json"
|
||||||
|
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.0) -> str:
|
||||||
|
"""Fetch recent StockTwits messages for ``ticker`` and return them as a
|
||||||
|
formatted plaintext block ready for prompt injection.
|
||||||
|
|
||||||
|
Returns a placeholder string when the endpoint is unreachable, the
|
||||||
|
symbol has no messages, or the response shape is unexpected — the
|
||||||
|
caller never has to special-case None or exceptions.
|
||||||
|
"""
|
||||||
|
url = _API.format(ticker=ticker.upper())
|
||||||
|
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
||||||
|
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
|
||||||
|
return f"<stocktwits unavailable: {type(exc).__name__}>"
|
||||||
|
|
||||||
|
messages = data.get("messages", []) if isinstance(data, dict) else []
|
||||||
|
if not messages:
|
||||||
|
return f"<no StockTwits messages found for ${ticker.upper()}>"
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
bullish = bearish = unlabeled = 0
|
||||||
|
for m in messages[:limit]:
|
||||||
|
created = m.get("created_at", "")
|
||||||
|
user = (m.get("user") or {}).get("username", "?")
|
||||||
|
entities = m.get("entities") or {}
|
||||||
|
sentiment_obj = entities.get("sentiment") or {}
|
||||||
|
sentiment = sentiment_obj.get("basic") if isinstance(sentiment_obj, dict) else None
|
||||||
|
body = (m.get("body") or "").replace("\n", " ").strip()
|
||||||
|
if len(body) > 280:
|
||||||
|
body = body[:280] + "…"
|
||||||
|
|
||||||
|
if sentiment == "Bullish":
|
||||||
|
bullish += 1
|
||||||
|
tag = "Bullish"
|
||||||
|
elif sentiment == "Bearish":
|
||||||
|
bearish += 1
|
||||||
|
tag = "Bearish"
|
||||||
|
else:
|
||||||
|
unlabeled += 1
|
||||||
|
tag = "no-label"
|
||||||
|
lines.append(f"[{created} · @{user} · {tag}] {body}")
|
||||||
|
|
||||||
|
total = bullish + bearish + unlabeled
|
||||||
|
bull_pct = round(100 * bullish / total) if total else 0
|
||||||
|
bear_pct = round(100 * bearish / total) if total else 0
|
||||||
|
summary = (
|
||||||
|
f"Bullish: {bullish} ({bull_pct}%) · "
|
||||||
|
f"Bearish: {bearish} ({bear_pct}%) · "
|
||||||
|
f"Unlabeled: {unlabeled} · "
|
||||||
|
f"Total: {total} most-recent messages"
|
||||||
|
)
|
||||||
|
return summary + "\n\n" + "\n".join(lines)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import date, timedelta, datetime
|
from datetime import date, timedelta, datetime
|
||||||
@@ -6,9 +7,43 @@ from typing import Annotated
|
|||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
|
# Tickers can contain letters, digits, dot, dash, underscore, and caret
|
||||||
|
# (for index symbols like ^GSPC). Anything else is rejected so the value
|
||||||
|
# never escapes a containing directory when interpolated into a path.
|
||||||
|
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
|
||||||
|
"""Validate ``value`` is safe to interpolate into a filesystem path.
|
||||||
|
|
||||||
|
Tickers come from user CLI input or from LLM tool calls, both of which
|
||||||
|
can be influenced by attacker-controlled content (e.g. prompt injection
|
||||||
|
embedded in fetched news). Without validation, a value like
|
||||||
|
``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
|
||||||
|
escapes the configured cache, checkpoint, or results directory.
|
||||||
|
|
||||||
|
Returns ``value`` unchanged when it matches the allowed pattern; raises
|
||||||
|
``ValueError`` otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str) or not value:
|
||||||
|
raise ValueError(f"ticker must be a non-empty string, got {value!r}")
|
||||||
|
if len(value) > max_len:
|
||||||
|
raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")
|
||||||
|
if not _TICKER_PATH_RE.fullmatch(value):
|
||||||
|
raise ValueError(
|
||||||
|
f"ticker contains characters not allowed in a filesystem path: {value!r}"
|
||||||
|
)
|
||||||
|
# The regex above allows '.', so values like '.', '..', '...' would pass,
|
||||||
|
# and as a path component they traverse the parent directory. Reject any
|
||||||
|
# value that's only dots.
|
||||||
|
if set(value) == {"."}:
|
||||||
|
raise ValueError(f"ticker cannot consist solely of dots: {value!r}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||||
if save_path:
|
if save_path:
|
||||||
data.to_csv(save_path)
|
data.to_csv(save_path, encoding="utf-8")
|
||||||
print(f"{tag} saved to {save_path}")
|
print(f"{tag} saved to {save_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
422
tradingagents/dataflows/y_finance.py
Normal file
422
tradingagents/dataflows/y_finance.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from datetime import datetime
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
|
import pandas as pd
|
||||||
|
import yfinance as yf
|
||||||
|
import os
|
||||||
|
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
|
||||||
|
|
||||||
|
def get_YFin_data_online(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
|
):
|
||||||
|
|
||||||
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
|
||||||
|
# Create ticker object
|
||||||
|
ticker = yf.Ticker(symbol.upper())
|
||||||
|
|
||||||
|
# Fetch historical data for the specified date range
|
||||||
|
data = yf_retry(lambda: ticker.history(start=start_date, end=end_date))
|
||||||
|
|
||||||
|
# Check if data is empty
|
||||||
|
if data.empty:
|
||||||
|
return (
|
||||||
|
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove timezone info from index for cleaner output
|
||||||
|
if data.index.tz is not None:
|
||||||
|
data.index = data.index.tz_localize(None)
|
||||||
|
|
||||||
|
# Round numerical values to 2 decimal places for cleaner display
|
||||||
|
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
||||||
|
for col in numeric_columns:
|
||||||
|
if col in data.columns:
|
||||||
|
data[col] = data[col].round(2)
|
||||||
|
|
||||||
|
# Convert DataFrame to CSV string
|
||||||
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
|
# Add header information
|
||||||
|
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
||||||
|
header += f"# Total records: {len(data)}\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + csv_string
|
||||||
|
|
||||||
|
def get_stock_stats_indicators_window(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
|
curr_date: Annotated[
|
||||||
|
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||||
|
],
|
||||||
|
look_back_days: Annotated[int, "how many days to look back"],
|
||||||
|
) -> str:
|
||||||
|
|
||||||
|
best_ind_params = {
|
||||||
|
# Moving Averages
|
||||||
|
"close_50_sma": (
|
||||||
|
"50 SMA: A medium-term trend indicator. "
|
||||||
|
"Usage: Identify trend direction and serve as dynamic support/resistance. "
|
||||||
|
"Tips: It lags price; combine with faster indicators for timely signals."
|
||||||
|
),
|
||||||
|
"close_200_sma": (
|
||||||
|
"200 SMA: A long-term trend benchmark. "
|
||||||
|
"Usage: Confirm overall market trend and identify golden/death cross setups. "
|
||||||
|
"Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries."
|
||||||
|
),
|
||||||
|
"close_10_ema": (
|
||||||
|
"10 EMA: A responsive short-term average. "
|
||||||
|
"Usage: Capture quick shifts in momentum and potential entry points. "
|
||||||
|
"Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
|
||||||
|
),
|
||||||
|
# MACD Related
|
||||||
|
"macd": (
|
||||||
|
"MACD: Computes momentum via differences of EMAs. "
|
||||||
|
"Usage: Look for crossovers and divergence as signals of trend changes. "
|
||||||
|
"Tips: Confirm with other indicators in low-volatility or sideways markets."
|
||||||
|
),
|
||||||
|
"macds": (
|
||||||
|
"MACD Signal: An EMA smoothing of the MACD line. "
|
||||||
|
"Usage: Use crossovers with the MACD line to trigger trades. "
|
||||||
|
"Tips: Should be part of a broader strategy to avoid false positives."
|
||||||
|
),
|
||||||
|
"macdh": (
|
||||||
|
"MACD Histogram: Shows the gap between the MACD line and its signal. "
|
||||||
|
"Usage: Visualize momentum strength and spot divergence early. "
|
||||||
|
"Tips: Can be volatile; complement with additional filters in fast-moving markets."
|
||||||
|
),
|
||||||
|
# Momentum Indicators
|
||||||
|
"rsi": (
|
||||||
|
"RSI: Measures momentum to flag overbought/oversold conditions. "
|
||||||
|
"Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
|
||||||
|
"Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
|
||||||
|
),
|
||||||
|
# Volatility Indicators
|
||||||
|
"boll": (
|
||||||
|
"Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
|
||||||
|
"Usage: Acts as a dynamic benchmark for price movement. "
|
||||||
|
"Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals."
|
||||||
|
),
|
||||||
|
"boll_ub": (
|
||||||
|
"Bollinger Upper Band: Typically 2 standard deviations above the middle line. "
|
||||||
|
"Usage: Signals potential overbought conditions and breakout zones. "
|
||||||
|
"Tips: Confirm signals with other tools; prices may ride the band in strong trends."
|
||||||
|
),
|
||||||
|
"boll_lb": (
|
||||||
|
"Bollinger Lower Band: Typically 2 standard deviations below the middle line. "
|
||||||
|
"Usage: Indicates potential oversold conditions. "
|
||||||
|
"Tips: Use additional analysis to avoid false reversal signals."
|
||||||
|
),
|
||||||
|
"atr": (
|
||||||
|
"ATR: Averages true range to measure volatility. "
|
||||||
|
"Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
|
||||||
|
"Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
|
||||||
|
),
|
||||||
|
# Volume-Based Indicators
|
||||||
|
"vwma": (
|
||||||
|
"VWMA: A moving average weighted by volume. "
|
||||||
|
"Usage: Confirm trends by integrating price action with volume data. "
|
||||||
|
"Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
|
||||||
|
),
|
||||||
|
"mfi": (
|
||||||
|
"MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. "
|
||||||
|
"Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. "
|
||||||
|
"Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if indicator not in best_ind_params:
|
||||||
|
raise ValueError(
|
||||||
|
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
end_date = curr_date
|
||||||
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
before = curr_date_dt - relativedelta(days=look_back_days)
|
||||||
|
|
||||||
|
# Optimized: Get stock data once and calculate indicators for all dates
|
||||||
|
try:
|
||||||
|
indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
|
||||||
|
|
||||||
|
# Generate the date range we need
|
||||||
|
current_dt = curr_date_dt
|
||||||
|
date_values = []
|
||||||
|
|
||||||
|
while current_dt >= before:
|
||||||
|
date_str = current_dt.strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
# Look up the indicator value for this date
|
||||||
|
if date_str in indicator_data:
|
||||||
|
indicator_value = indicator_data[date_str]
|
||||||
|
else:
|
||||||
|
indicator_value = "N/A: Not a trading day (weekend or holiday)"
|
||||||
|
|
||||||
|
date_values.append((date_str, indicator_value))
|
||||||
|
current_dt = current_dt - relativedelta(days=1)
|
||||||
|
|
||||||
|
# Build the result string
|
||||||
|
ind_string = ""
|
||||||
|
for date_str, value in date_values:
|
||||||
|
ind_string += f"{date_str}: {value}\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting bulk stockstats data: {e}")
|
||||||
|
# Fallback to original implementation if bulk method fails
|
||||||
|
ind_string = ""
|
||||||
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
while curr_date_dt >= before:
|
||||||
|
indicator_value = get_stockstats_indicator(
|
||||||
|
symbol, indicator, curr_date_dt.strftime("%Y-%m-%d")
|
||||||
|
)
|
||||||
|
ind_string += f"{curr_date_dt.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
curr_date_dt = curr_date_dt - relativedelta(days=1)
|
||||||
|
|
||||||
|
result_str = (
|
||||||
|
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
||||||
|
+ ind_string
|
||||||
|
+ "\n\n"
|
||||||
|
+ best_ind_params.get(indicator, "No description available.")
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_str
|
||||||
|
|
||||||
|
|
||||||
|
def _get_stock_stats_bulk(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
indicator: Annotated[str, "technical indicator to calculate"],
|
||||||
|
curr_date: Annotated[str, "current date for reference"]
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Optimized bulk calculation of stock stats indicators.
|
||||||
|
Fetches data once and calculates indicator for all available dates.
|
||||||
|
Returns dict mapping date strings to indicator values.
|
||||||
|
"""
|
||||||
|
from stockstats import wrap
|
||||||
|
|
||||||
|
data = load_ohlcv(symbol, curr_date)
|
||||||
|
df = wrap(data)
|
||||||
|
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Calculate the indicator for all rows at once
|
||||||
|
df[indicator] # This triggers stockstats to calculate the indicator
|
||||||
|
|
||||||
|
# Create a dictionary mapping date strings to indicator values
|
||||||
|
result_dict = {}
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
date_str = row["Date"]
|
||||||
|
indicator_value = row[indicator]
|
||||||
|
|
||||||
|
# Handle NaN/None values
|
||||||
|
if pd.isna(indicator_value):
|
||||||
|
result_dict[date_str] = "N/A"
|
||||||
|
else:
|
||||||
|
result_dict[date_str] = str(indicator_value)
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_stockstats_indicator(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||||
|
curr_date: Annotated[
|
||||||
|
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||||
|
],
|
||||||
|
) -> str:
|
||||||
|
|
||||||
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
curr_date = curr_date_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
try:
|
||||||
|
indicator_value = StockstatsUtils.get_stock_stats(
|
||||||
|
symbol,
|
||||||
|
indicator,
|
||||||
|
curr_date,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return str(indicator_value)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fundamentals(
|
||||||
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
|
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||||
|
):
|
||||||
|
"""Get company fundamentals overview from yfinance."""
|
||||||
|
try:
|
||||||
|
ticker_obj = yf.Ticker(ticker.upper())
|
||||||
|
info = yf_retry(lambda: ticker_obj.info)
|
||||||
|
|
||||||
|
if not info:
|
||||||
|
return f"No fundamentals data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
("Name", info.get("longName")),
|
||||||
|
("Sector", info.get("sector")),
|
||||||
|
("Industry", info.get("industry")),
|
||||||
|
("Market Cap", info.get("marketCap")),
|
||||||
|
("PE Ratio (TTM)", info.get("trailingPE")),
|
||||||
|
("Forward PE", info.get("forwardPE")),
|
||||||
|
("PEG Ratio", info.get("pegRatio")),
|
||||||
|
("Price to Book", info.get("priceToBook")),
|
||||||
|
("EPS (TTM)", info.get("trailingEps")),
|
||||||
|
("Forward EPS", info.get("forwardEps")),
|
||||||
|
("Dividend Yield", info.get("dividendYield")),
|
||||||
|
("Beta", info.get("beta")),
|
||||||
|
("52 Week High", info.get("fiftyTwoWeekHigh")),
|
||||||
|
("52 Week Low", info.get("fiftyTwoWeekLow")),
|
||||||
|
("50 Day Average", info.get("fiftyDayAverage")),
|
||||||
|
("200 Day Average", info.get("twoHundredDayAverage")),
|
||||||
|
("Revenue (TTM)", info.get("totalRevenue")),
|
||||||
|
("Gross Profit", info.get("grossProfits")),
|
||||||
|
("EBITDA", info.get("ebitda")),
|
||||||
|
("Net Income", info.get("netIncomeToCommon")),
|
||||||
|
("Profit Margin", info.get("profitMargins")),
|
||||||
|
("Operating Margin", info.get("operatingMargins")),
|
||||||
|
("Return on Equity", info.get("returnOnEquity")),
|
||||||
|
("Return on Assets", info.get("returnOnAssets")),
|
||||||
|
("Debt to Equity", info.get("debtToEquity")),
|
||||||
|
("Current Ratio", info.get("currentRatio")),
|
||||||
|
("Book Value", info.get("bookValue")),
|
||||||
|
("Free Cash Flow", info.get("freeCashflow")),
|
||||||
|
]
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for label, value in fields:
|
||||||
|
if value is not None:
|
||||||
|
lines.append(f"{label}: {value}")
|
||||||
|
|
||||||
|
header = f"# Company Fundamentals for {ticker.upper()}\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + "\n".join(lines)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_balance_sheet(
|
||||||
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
|
):
|
||||||
|
"""Get balance sheet data from yfinance."""
|
||||||
|
try:
|
||||||
|
ticker_obj = yf.Ticker(ticker.upper())
|
||||||
|
|
||||||
|
if freq.lower() == "quarterly":
|
||||||
|
data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet)
|
||||||
|
else:
|
||||||
|
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
|
if data.empty:
|
||||||
|
return f"No balance sheet data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
# Convert to CSV string for consistency with other functions
|
||||||
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
|
# Add header information
|
||||||
|
header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + csv_string
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cashflow(
|
||||||
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
|
):
|
||||||
|
"""Get cash flow data from yfinance."""
|
||||||
|
try:
|
||||||
|
ticker_obj = yf.Ticker(ticker.upper())
|
||||||
|
|
||||||
|
if freq.lower() == "quarterly":
|
||||||
|
data = yf_retry(lambda: ticker_obj.quarterly_cashflow)
|
||||||
|
else:
|
||||||
|
data = yf_retry(lambda: ticker_obj.cashflow)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
|
if data.empty:
|
||||||
|
return f"No cash flow data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
# Convert to CSV string for consistency with other functions
|
||||||
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
|
# Add header information
|
||||||
|
header = f"# Cash Flow data for {ticker.upper()} ({freq})\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + csv_string
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving cash flow for {ticker}: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_income_statement(
|
||||||
|
ticker: Annotated[str, "ticker symbol of the company"],
|
||||||
|
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||||
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
|
):
|
||||||
|
"""Get income statement data from yfinance."""
|
||||||
|
try:
|
||||||
|
ticker_obj = yf.Ticker(ticker.upper())
|
||||||
|
|
||||||
|
if freq.lower() == "quarterly":
|
||||||
|
data = yf_retry(lambda: ticker_obj.quarterly_income_stmt)
|
||||||
|
else:
|
||||||
|
data = yf_retry(lambda: ticker_obj.income_stmt)
|
||||||
|
|
||||||
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
|
if data.empty:
|
||||||
|
return f"No income statement data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
# Convert to CSV string for consistency with other functions
|
||||||
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
|
# Add header information
|
||||||
|
header = f"# Income Statement data for {ticker.upper()} ({freq})\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + csv_string
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving income statement for {ticker}: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_insider_transactions(
|
||||||
|
ticker: Annotated[str, "ticker symbol of the company"]
|
||||||
|
):
|
||||||
|
"""Get insider transactions data from yfinance."""
|
||||||
|
try:
|
||||||
|
ticker_obj = yf.Ticker(ticker.upper())
|
||||||
|
data = yf_retry(lambda: ticker_obj.insider_transactions)
|
||||||
|
|
||||||
|
if data is None or data.empty:
|
||||||
|
return f"No insider transactions data found for symbol '{ticker}'"
|
||||||
|
|
||||||
|
# Convert to CSV string for consistency with other functions
|
||||||
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
|
# Add header information
|
||||||
|
header = f"# Insider Transactions data for {ticker.upper()}\n"
|
||||||
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
|
return header + csv_string
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
# gets data/stats
|
|
||||||
|
|
||||||
import yfinance as yf
|
|
||||||
from typing import Annotated, Callable, Any, Optional
|
|
||||||
from pandas import DataFrame
|
|
||||||
import pandas as pd
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
from .utils import save_output, SavePathType, decorate_all_methods
|
|
||||||
|
|
||||||
|
|
||||||
def init_ticker(func: Callable) -> Callable:
|
|
||||||
"""Decorator to initialize yf.Ticker and pass it to the function."""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
|
|
||||||
ticker = yf.Ticker(symbol)
|
|
||||||
return func(ticker, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@decorate_all_methods(init_ticker)
|
|
||||||
class YFinanceUtils:
|
|
||||||
|
|
||||||
def get_stock_data(
|
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
|
||||||
start_date: Annotated[
|
|
||||||
str, "start date for retrieving stock price data, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
end_date: Annotated[
|
|
||||||
str, "end date for retrieving stock price data, YYYY-mm-dd"
|
|
||||||
],
|
|
||||||
save_path: SavePathType = None,
|
|
||||||
) -> DataFrame:
|
|
||||||
"""retrieve stock price data for designated ticker symbol"""
|
|
||||||
ticker = symbol
|
|
||||||
# add one day to the end_date so that the data range is inclusive
|
|
||||||
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
|
||||||
end_date = end_date.strftime("%Y-%m-%d")
|
|
||||||
stock_data = ticker.history(start=start_date, end=end_date)
|
|
||||||
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
|
|
||||||
return stock_data
|
|
||||||
|
|
||||||
def get_stock_info(
|
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
|
||||||
) -> dict:
|
|
||||||
"""Fetches and returns latest stock information."""
|
|
||||||
ticker = symbol
|
|
||||||
stock_info = ticker.info
|
|
||||||
return stock_info
|
|
||||||
|
|
||||||
def get_company_info(
|
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
|
||||||
save_path: Optional[str] = None,
|
|
||||||
) -> DataFrame:
|
|
||||||
"""Fetches and returns company information as a DataFrame."""
|
|
||||||
ticker = symbol
|
|
||||||
info = ticker.info
|
|
||||||
company_info = {
|
|
||||||
"Company Name": info.get("shortName", "N/A"),
|
|
||||||
"Industry": info.get("industry", "N/A"),
|
|
||||||
"Sector": info.get("sector", "N/A"),
|
|
||||||
"Country": info.get("country", "N/A"),
|
|
||||||
"Website": info.get("website", "N/A"),
|
|
||||||
}
|
|
||||||
company_info_df = DataFrame([company_info])
|
|
||||||
if save_path:
|
|
||||||
company_info_df.to_csv(save_path)
|
|
||||||
print(f"Company info for {ticker.ticker} saved to {save_path}")
|
|
||||||
return company_info_df
|
|
||||||
|
|
||||||
def get_stock_dividends(
|
|
||||||
symbol: Annotated[str, "ticker symbol"],
|
|
||||||
save_path: Optional[str] = None,
|
|
||||||
) -> DataFrame:
|
|
||||||
"""Fetches and returns the latest dividends data as a DataFrame."""
|
|
||||||
ticker = symbol
|
|
||||||
dividends = ticker.dividends
|
|
||||||
if save_path:
|
|
||||||
dividends.to_csv(save_path)
|
|
||||||
print(f"Dividends for {ticker.ticker} saved to {save_path}")
|
|
||||||
return dividends
|
|
||||||
|
|
||||||
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
|
||||||
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
|
||||||
ticker = symbol
|
|
||||||
income_stmt = ticker.financials
|
|
||||||
return income_stmt
|
|
||||||
|
|
||||||
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
|
||||||
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
|
||||||
ticker = symbol
|
|
||||||
balance_sheet = ticker.balance_sheet
|
|
||||||
return balance_sheet
|
|
||||||
|
|
||||||
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
|
||||||
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
|
||||||
ticker = symbol
|
|
||||||
cash_flow = ticker.cashflow
|
|
||||||
return cash_flow
|
|
||||||
|
|
||||||
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
|
|
||||||
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
|
||||||
ticker = symbol
|
|
||||||
recommendations = ticker.recommendations
|
|
||||||
if recommendations.empty:
|
|
||||||
return None, 0 # No recommendations available
|
|
||||||
|
|
||||||
# Assuming 'period' column exists and needs to be excluded
|
|
||||||
row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary
|
|
||||||
|
|
||||||
# Find the maximum voting result
|
|
||||||
max_votes = row_0.max()
|
|
||||||
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
|
|
||||||
|
|
||||||
return majority_voting_result[0], max_votes
|
|
||||||
202
tradingagents/dataflows/yfinance_news.py
Normal file
202
tradingagents/dataflows/yfinance_news.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""yfinance-based news data fetching functions."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yfinance as yf
|
||||||
|
from datetime import datetime
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
from .stockstats_utils import yf_retry
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_article_data(article: dict) -> dict:
|
||||||
|
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||||
|
# Handle nested content structure
|
||||||
|
if "content" in article:
|
||||||
|
content = article["content"]
|
||||||
|
title = content.get("title", "No title")
|
||||||
|
summary = content.get("summary", "")
|
||||||
|
provider = content.get("provider", {})
|
||||||
|
publisher = provider.get("displayName", "Unknown")
|
||||||
|
|
||||||
|
# Get URL from canonicalUrl or clickThroughUrl
|
||||||
|
url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {}
|
||||||
|
link = url_obj.get("url", "")
|
||||||
|
|
||||||
|
# Get publish date
|
||||||
|
pub_date_str = content.get("pubDate", "")
|
||||||
|
pub_date = None
|
||||||
|
if pub_date_str:
|
||||||
|
try:
|
||||||
|
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"title": title,
|
||||||
|
"summary": summary,
|
||||||
|
"publisher": publisher,
|
||||||
|
"link": link,
|
||||||
|
"pub_date": pub_date,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Fallback for flat structure
|
||||||
|
return {
|
||||||
|
"title": article.get("title", "No title"),
|
||||||
|
"summary": article.get("summary", ""),
|
||||||
|
"publisher": article.get("publisher", "Unknown"),
|
||||||
|
"link": article.get("link", ""),
|
||||||
|
"pub_date": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_news_yfinance(
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve news for a specific stock ticker using yfinance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol (e.g., "AAPL")
|
||||||
|
start_date: Start date in yyyy-mm-dd format
|
||||||
|
end_date: End date in yyyy-mm-dd format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string containing news articles
|
||||||
|
"""
|
||||||
|
article_limit = get_config()["news_article_limit"]
|
||||||
|
try:
|
||||||
|
stock = yf.Ticker(ticker)
|
||||||
|
news = yf_retry(lambda: stock.get_news(count=article_limit))
|
||||||
|
|
||||||
|
if not news:
|
||||||
|
return f"No news found for {ticker}"
|
||||||
|
|
||||||
|
# Parse date range for filtering
|
||||||
|
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
|
||||||
|
news_str = ""
|
||||||
|
filtered_count = 0
|
||||||
|
|
||||||
|
for article in news:
|
||||||
|
data = _extract_article_data(article)
|
||||||
|
|
||||||
|
# Filter by date if publish time is available
|
||||||
|
if data["pub_date"]:
|
||||||
|
pub_date_naive = data["pub_date"].replace(tzinfo=None)
|
||||||
|
if not (start_dt <= pub_date_naive <= end_dt + relativedelta(days=1)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
news_str += f"### {data['title']} (source: {data['publisher']})\n"
|
||||||
|
if data["summary"]:
|
||||||
|
news_str += f"{data['summary']}\n"
|
||||||
|
if data["link"]:
|
||||||
|
news_str += f"Link: {data['link']}\n"
|
||||||
|
news_str += "\n"
|
||||||
|
filtered_count += 1
|
||||||
|
|
||||||
|
if filtered_count == 0:
|
||||||
|
return f"No news found for {ticker} between {start_date} and {end_date}"
|
||||||
|
|
||||||
|
return f"## {ticker} News, from {start_date} to {end_date}:\n\n{news_str}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error fetching news for {ticker}: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_news_yfinance(
|
||||||
|
curr_date: str,
|
||||||
|
look_back_days: Optional[int] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve global/macro economic news using yfinance Search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curr_date: Current date in yyyy-mm-dd format
|
||||||
|
look_back_days: Number of days to look back. ``None`` falls back to
|
||||||
|
``global_news_lookback_days`` from the active config.
|
||||||
|
limit: Maximum number of articles to return. ``None`` falls back to
|
||||||
|
``global_news_article_limit`` from the active config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string containing global news articles
|
||||||
|
"""
|
||||||
|
config = get_config()
|
||||||
|
if look_back_days is None:
|
||||||
|
look_back_days = config["global_news_lookback_days"]
|
||||||
|
if limit is None:
|
||||||
|
limit = config["global_news_article_limit"]
|
||||||
|
search_queries = config["global_news_queries"]
|
||||||
|
|
||||||
|
all_news = []
|
||||||
|
seen_titles = set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for query in search_queries:
|
||||||
|
search = yf_retry(lambda q=query: yf.Search(
|
||||||
|
query=q,
|
||||||
|
news_count=limit,
|
||||||
|
enable_fuzzy_query=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
if search.news:
|
||||||
|
for article in search.news:
|
||||||
|
# Handle both flat and nested structures
|
||||||
|
if "content" in article:
|
||||||
|
data = _extract_article_data(article)
|
||||||
|
title = data["title"]
|
||||||
|
else:
|
||||||
|
title = article.get("title", "")
|
||||||
|
|
||||||
|
# Deduplicate by title
|
||||||
|
if title and title not in seen_titles:
|
||||||
|
seen_titles.add(title)
|
||||||
|
all_news.append(article)
|
||||||
|
|
||||||
|
if len(all_news) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not all_news:
|
||||||
|
return f"No global news found for {curr_date}"
|
||||||
|
|
||||||
|
# Calculate date range
|
||||||
|
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
start_dt = curr_dt - relativedelta(days=look_back_days)
|
||||||
|
start_date = start_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
news_str = ""
|
||||||
|
for article in all_news[:limit]:
|
||||||
|
# Handle both flat and nested structures
|
||||||
|
if "content" in article:
|
||||||
|
data = _extract_article_data(article)
|
||||||
|
# Skip articles published after curr_date (look-ahead guard)
|
||||||
|
if data.get("pub_date"):
|
||||||
|
pub_naive = data["pub_date"].replace(tzinfo=None) if hasattr(data["pub_date"], "replace") else data["pub_date"]
|
||||||
|
if pub_naive > curr_dt + relativedelta(days=1):
|
||||||
|
continue
|
||||||
|
title = data["title"]
|
||||||
|
publisher = data["publisher"]
|
||||||
|
link = data["link"]
|
||||||
|
summary = data["summary"]
|
||||||
|
else:
|
||||||
|
title = article.get("title", "No title")
|
||||||
|
publisher = article.get("publisher", "Unknown")
|
||||||
|
link = article.get("link", "")
|
||||||
|
summary = ""
|
||||||
|
|
||||||
|
news_str += f"### {title} (source: {publisher})\n"
|
||||||
|
if summary:
|
||||||
|
news_str += f"{summary}\n"
|
||||||
|
if link:
|
||||||
|
news_str += f"Link: {link}\n"
|
||||||
|
news_str += "\n"
|
||||||
|
|
||||||
|
return f"## Global Market News, from {start_date} to {curr_date}:\n\n{news_str}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error fetching global news: {str(e)}"
|
||||||
@@ -1,19 +1,121 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
DEFAULT_CONFIG = {
|
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||||
|
|
||||||
|
# Single source of truth for env-var → config-key overrides. To expose
|
||||||
|
# a new config key for environment-based override, add a row here — no
|
||||||
|
# entry-point script changes required. Coercion is driven by the type
|
||||||
|
# of the existing default, so users can keep writing plain strings in
|
||||||
|
# their .env file.
|
||||||
|
_ENV_OVERRIDES = {
|
||||||
|
"TRADINGAGENTS_LLM_PROVIDER": "llm_provider",
|
||||||
|
"TRADINGAGENTS_DEEP_THINK_LLM": "deep_think_llm",
|
||||||
|
"TRADINGAGENTS_QUICK_THINK_LLM": "quick_think_llm",
|
||||||
|
"TRADINGAGENTS_LLM_BACKEND_URL": "backend_url",
|
||||||
|
"TRADINGAGENTS_OUTPUT_LANGUAGE": "output_language",
|
||||||
|
"TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds",
|
||||||
|
"TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
|
||||||
|
"TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
|
||||||
|
"TRADINGAGENTS_BENCHMARK_TICKER": "benchmark_ticker",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce(value: str, reference):
|
||||||
|
"""Coerce env-var string to the type of the existing default value."""
|
||||||
|
if isinstance(reference, bool):
|
||||||
|
return value.strip().lower() in ("true", "1", "yes", "on")
|
||||||
|
if isinstance(reference, int) and not isinstance(reference, bool):
|
||||||
|
return int(value)
|
||||||
|
if isinstance(reference, float):
|
||||||
|
return float(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_env_overrides(config: dict) -> dict:
|
||||||
|
"""Apply TRADINGAGENTS_* env vars to the config dict in-place."""
|
||||||
|
for env_var, key in _ENV_OVERRIDES.items():
|
||||||
|
raw = os.environ.get(env_var)
|
||||||
|
if raw is None or raw == "":
|
||||||
|
continue
|
||||||
|
config[key] = _coerce(raw, config.get(key))
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CONFIG = _apply_env_overrides({
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"memory_log_path": os.getenv("TRADINGAGENTS_MEMORY_LOG_PATH", os.path.join(_TRADINGAGENTS_HOME, "memory", "trading_memory.md")),
|
||||||
"dataflows/data_cache",
|
# Optional cap on the number of resolved memory log entries. When set,
|
||||||
),
|
# the oldest resolved entries are pruned once this limit is exceeded.
|
||||||
|
# Pending entries are never pruned. None disables rotation entirely.
|
||||||
|
"memory_log_max_entries": None,
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"deep_think_llm": "o4-mini",
|
"llm_provider": "openai",
|
||||||
"quick_think_llm": "gpt-4o-mini",
|
"deep_think_llm": "gpt-5.4",
|
||||||
|
"quick_think_llm": "gpt-5.4-mini",
|
||||||
|
# When None, each provider's client falls back to its own default endpoint
|
||||||
|
# (api.openai.com for OpenAI, generativelanguage.googleapis.com for Gemini, ...).
|
||||||
|
# The CLI overrides this per provider when the user picks one. Keeping a
|
||||||
|
# provider-specific URL here would leak (e.g. OpenAI's /v1 was previously
|
||||||
|
# being forwarded to Gemini, producing malformed request URLs).
|
||||||
|
"backend_url": None,
|
||||||
|
# Provider-specific thinking configuration
|
||||||
|
"google_thinking_level": None, # "high", "minimal", etc.
|
||||||
|
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||||
|
"anthropic_effort": None, # "high", "medium", "low"
|
||||||
|
# Checkpoint/resume: when True, LangGraph saves state after each node
|
||||||
|
# so a crashed run can resume from the last successful step.
|
||||||
|
"checkpoint_enabled": False,
|
||||||
|
# Output language for analyst reports and final decision
|
||||||
|
# Internal agent debate stays in English for reasoning quality
|
||||||
|
"output_language": "English",
|
||||||
# Debate and discussion settings
|
# Debate and discussion settings
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
"max_recur_limit": 100,
|
"max_recur_limit": 100,
|
||||||
# Tool settings
|
# News / data fetching parameters
|
||||||
"online_tools": True,
|
# Increase for longer lookback strategies or to broaden macro coverage;
|
||||||
}
|
# decrease to reduce token usage in agent prompts.
|
||||||
|
"news_article_limit": 20, # max articles per ticker (ticker-news)
|
||||||
|
"global_news_article_limit": 10, # max articles for global/macro news
|
||||||
|
"global_news_lookback_days": 7, # macro news lookback window
|
||||||
|
# Search queries used by get_global_news for macro headlines. Extend or
|
||||||
|
# replace to broaden geographic / sector coverage.
|
||||||
|
"global_news_queries": [
|
||||||
|
"Federal Reserve interest rates inflation",
|
||||||
|
"S&P 500 earnings GDP economic outlook",
|
||||||
|
"geopolitical risk trade war sanctions",
|
||||||
|
"ECB Bank of England BOJ central bank policy",
|
||||||
|
"oil commodities supply chain energy",
|
||||||
|
],
|
||||||
|
# Data vendor configuration
|
||||||
|
# Category-level configuration (default for all tools in category)
|
||||||
|
"data_vendors": {
|
||||||
|
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
|
||||||
|
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
|
||||||
|
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||||
|
"news_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||||
|
},
|
||||||
|
# Tool-level configuration (takes precedence over category-level)
|
||||||
|
"tool_vendors": {
|
||||||
|
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||||
|
},
|
||||||
|
# Benchmark for alpha calculation in the reflection layer.
|
||||||
|
# ``benchmark_ticker`` (when set) overrides the suffix map for all
|
||||||
|
# tickers; leave it None to use ``benchmark_map`` for auto-detection
|
||||||
|
# based on the ticker's exchange suffix. SPY remains the US default
|
||||||
|
# so the reflection label keeps reading "Alpha vs SPY" for US tickers
|
||||||
|
# while non-US tickers get their regional index automatically.
|
||||||
|
"benchmark_ticker": None,
|
||||||
|
"benchmark_map": {
|
||||||
|
".NS": "^NSEI", # NSE India (Nifty 50)
|
||||||
|
".BO": "^BSESN", # BSE India (Sensex)
|
||||||
|
".T": "^N225", # Tokyo (Nikkei 225)
|
||||||
|
".HK": "^HSI", # Hong Kong (Hang Seng)
|
||||||
|
".L": "^FTSE", # London (FTSE 100)
|
||||||
|
".TO": "^GSPTSE", # Toronto (TSX Composite)
|
||||||
|
".AX": "^AXJO", # Australia (ASX 200)
|
||||||
|
"": "SPY", # default for US-listed tickers (no suffix)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|||||||
90
tradingagents/graph/checkpointer.py
Normal file
90
tradingagents/graph/checkpointer.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""LangGraph checkpoint support for resumable analysis runs.
|
||||||
|
|
||||||
|
Per-ticker SQLite databases so concurrent tickers don't contend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
|
def _db_path(data_dir: str | Path, ticker: str) -> Path:
|
||||||
|
"""Return the SQLite checkpoint DB path for a ticker."""
|
||||||
|
# Reject ticker values that would escape the checkpoints directory.
|
||||||
|
safe = safe_ticker_component(ticker).upper()
|
||||||
|
p = Path(data_dir) / "checkpoints"
|
||||||
|
p.mkdir(parents=True, exist_ok=True)
|
||||||
|
return p / f"{safe}.db"
|
||||||
|
|
||||||
|
|
||||||
|
def thread_id(ticker: str, date: str) -> str:
|
||||||
|
"""Deterministic thread ID for a ticker+date pair."""
|
||||||
|
return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
|
||||||
|
"""Context manager yielding a SqliteSaver backed by a per-ticker DB."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
conn = sqlite3.connect(str(db), check_same_thread=False)
|
||||||
|
try:
|
||||||
|
saver = SqliteSaver(conn)
|
||||||
|
saver.setup()
|
||||||
|
yield saver
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
|
||||||
|
"""Check whether a resumable checkpoint exists for ticker+date."""
|
||||||
|
return checkpoint_step(data_dir, ticker, date) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
|
||||||
|
"""Return the step number of the latest checkpoint, or None if none exists."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
if not db.exists():
|
||||||
|
return None
|
||||||
|
tid = thread_id(ticker, date)
|
||||||
|
with get_checkpointer(data_dir, ticker) as saver:
|
||||||
|
config = {"configurable": {"thread_id": tid}}
|
||||||
|
cp = saver.get_tuple(config)
|
||||||
|
if cp is None:
|
||||||
|
return None
|
||||||
|
return cp.metadata.get("step")
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_checkpoints(data_dir: str | Path) -> int:
|
||||||
|
"""Remove all checkpoint DBs. Returns number of files deleted."""
|
||||||
|
cp_dir = Path(data_dir) / "checkpoints"
|
||||||
|
if not cp_dir.exists():
|
||||||
|
return 0
|
||||||
|
dbs = list(cp_dir.glob("*.db"))
|
||||||
|
for db in dbs:
|
||||||
|
db.unlink()
|
||||||
|
return len(dbs)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
|
||||||
|
"""Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
|
||||||
|
db = _db_path(data_dir, ticker)
|
||||||
|
if not db.exists():
|
||||||
|
return
|
||||||
|
tid = thread_id(ticker, date)
|
||||||
|
conn = sqlite3.connect(str(db))
|
||||||
|
try:
|
||||||
|
for table in ("writes", "checkpoints"):
|
||||||
|
conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
@@ -59,9 +59,9 @@ class ConditionalLogic:
|
|||||||
if (
|
if (
|
||||||
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
||||||
): # 3 rounds of back-and-forth between 3 agents
|
): # 3 rounds of back-and-forth between 3 agents
|
||||||
return "Risk Judge"
|
return "Portfolio Manager"
|
||||||
if state["risk_debate_state"]["latest_speaker"].startswith("Risky"):
|
if state["risk_debate_state"]["latest_speaker"].startswith("Aggressive"):
|
||||||
return "Safe Analyst"
|
return "Conservative Analyst"
|
||||||
if state["risk_debate_state"]["latest_speaker"].startswith("Safe"):
|
if state["risk_debate_state"]["latest_speaker"].startswith("Conservative"):
|
||||||
return "Neutral Analyst"
|
return "Neutral Analyst"
|
||||||
return "Risky Analyst"
|
return "Aggressive Analyst"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# TradingAgents/graph/propagation.py
|
# TradingAgents/graph/propagation.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, List, Optional
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
@@ -16,22 +16,35 @@ class Propagator:
|
|||||||
self.max_recur_limit = max_recur_limit
|
self.max_recur_limit = max_recur_limit
|
||||||
|
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
self, company_name: str, trade_date: str
|
self, company_name: str, trade_date: str, past_context: str = ""
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph."""
|
"""Create the initial state for the agent graph."""
|
||||||
return {
|
return {
|
||||||
"messages": [("human", company_name)],
|
"messages": [("human", company_name)],
|
||||||
"company_of_interest": company_name,
|
"company_of_interest": company_name,
|
||||||
"trade_date": str(trade_date),
|
"trade_date": str(trade_date),
|
||||||
|
"past_context": past_context,
|
||||||
"investment_debate_state": InvestDebateState(
|
"investment_debate_state": InvestDebateState(
|
||||||
{"history": "", "current_response": "", "count": 0}
|
{
|
||||||
|
"bull_history": "",
|
||||||
|
"bear_history": "",
|
||||||
|
"history": "",
|
||||||
|
"current_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
|
"count": 0,
|
||||||
|
}
|
||||||
),
|
),
|
||||||
"risk_debate_state": RiskDebateState(
|
"risk_debate_state": RiskDebateState(
|
||||||
{
|
{
|
||||||
|
"aggressive_history": "",
|
||||||
|
"conservative_history": "",
|
||||||
|
"neutral_history": "",
|
||||||
"history": "",
|
"history": "",
|
||||||
"current_risky_response": "",
|
"latest_speaker": "",
|
||||||
"current_safe_response": "",
|
"current_aggressive_response": "",
|
||||||
|
"current_conservative_response": "",
|
||||||
"current_neutral_response": "",
|
"current_neutral_response": "",
|
||||||
|
"judge_decision": "",
|
||||||
"count": 0,
|
"count": 0,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
@@ -41,9 +54,17 @@ class Propagator:
|
|||||||
"news_report": "",
|
"news_report": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_graph_args(self) -> Dict[str, Any]:
|
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||||
"""Get arguments for the graph invocation."""
|
"""Get arguments for the graph invocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||||
|
Note: LLM callbacks are handled separately via LLM constructor.
|
||||||
|
"""
|
||||||
|
config = {"recursion_limit": self.max_recur_limit}
|
||||||
|
if callbacks:
|
||||||
|
config["callbacks"] = callbacks
|
||||||
return {
|
return {
|
||||||
"stream_mode": "values",
|
"stream_mode": "values",
|
||||||
"config": {"recursion_limit": self.max_recur_limit},
|
"config": config,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,121 +1,57 @@
|
|||||||
# TradingAgents/graph/reflection.py
|
# TradingAgents/graph/reflection.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class Reflector:
|
class Reflector:
|
||||||
"""Handles reflection on decisions and updating memory."""
|
"""Handles reflection on trading decisions."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: Any):
|
||||||
"""Initialize the reflector with an LLM."""
|
"""Initialize the reflector with an LLM."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
self.reflection_system_prompt = self._get_reflection_prompt()
|
self.log_reflection_prompt = self._get_log_reflection_prompt()
|
||||||
|
|
||||||
def _get_reflection_prompt(self) -> str:
|
def _get_log_reflection_prompt(self) -> str:
|
||||||
"""Get the system prompt for reflection."""
|
"""Concise prompt for reflect_on_final_decision (Phase B log entries).
|
||||||
return """
|
|
||||||
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
|
||||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
|
||||||
|
|
||||||
1. Reasoning:
|
Produces 2-4 sentences of plain prose — compact enough to be re-injected
|
||||||
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite.
|
into future agent prompts without bloating the context window.
|
||||||
- Analyze the contributing factors to each success or mistake. Consider:
|
"""
|
||||||
- Market intelligence.
|
return (
|
||||||
- Technical indicators.
|
"You are a trading analyst reviewing your own past decision now that the outcome is known.\n"
|
||||||
- Technical signals.
|
"Write exactly 2-4 sentences of plain prose (no bullets, no headers, no markdown).\n\n"
|
||||||
- Price movement analysis.
|
"Cover in order:\n"
|
||||||
- Overall market data analysis
|
"1. Was the directional call correct? (cite the alpha figure)\n"
|
||||||
- News analysis.
|
"2. Which part of the investment thesis held or failed?\n"
|
||||||
- Social media and sentiment analysis.
|
"3. One concrete lesson to apply to the next similar analysis.\n\n"
|
||||||
- Fundamental data analysis.
|
"Be specific and terse. Your output will be stored verbatim in a decision log "
|
||||||
- Weight the importance of each factor in the decision-making process.
|
"and re-read by future analysts, so every word must earn its place."
|
||||||
|
)
|
||||||
|
|
||||||
2. Improvement:
|
def reflect_on_final_decision(
|
||||||
- For any incorrect decisions, propose revisions to maximize returns.
|
self,
|
||||||
- Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date).
|
final_decision: str,
|
||||||
|
raw_return: float,
|
||||||
3. Summary:
|
alpha_return: float,
|
||||||
- Summarize the lessons learned from the successes and mistakes.
|
benchmark_name: str = "SPY",
|
||||||
- Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained.
|
|
||||||
|
|
||||||
4. Query:
|
|
||||||
- Extract key insights from the summary into a concise sentence of no more than 1000 tokens.
|
|
||||||
- Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference.
|
|
||||||
|
|
||||||
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
|
||||||
"""Extract the current market situation from the state."""
|
|
||||||
curr_market_report = current_state["market_report"]
|
|
||||||
curr_sentiment_report = current_state["sentiment_report"]
|
|
||||||
curr_news_report = current_state["news_report"]
|
|
||||||
curr_fundamentals_report = current_state["fundamentals_report"]
|
|
||||||
|
|
||||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
|
||||||
|
|
||||||
def _reflect_on_component(
|
|
||||||
self, component_type: str, report: str, situation: str, returns_losses
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate reflection for a component."""
|
"""Single reflection call on the final trade decision with outcome context.
|
||||||
|
|
||||||
|
Used by Phase B deferred reflection. The final_trade_decision already
|
||||||
|
synthesises all analyst insights, so no separate market context is needed.
|
||||||
|
``benchmark_name`` is the label used for the alpha line (e.g. ``"SPY"``
|
||||||
|
for US tickers, ``"^N225"`` for ``.T`` listings); defaults to SPY for
|
||||||
|
callers that haven't been updated to thread the benchmark through.
|
||||||
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
("system", self.reflection_system_prompt),
|
("system", self.log_reflection_prompt),
|
||||||
(
|
(
|
||||||
"human",
|
"human",
|
||||||
f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}",
|
(
|
||||||
|
f"Raw return: {raw_return:+.1%}\n"
|
||||||
|
f"Alpha vs {benchmark_name}: {alpha_return:+.1%}\n\n"
|
||||||
|
f"Final Decision:\n{final_decision}"
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
return self.quick_thinking_llm.invoke(messages).content
|
||||||
result = self.quick_thinking_llm.invoke(messages).content
|
|
||||||
return result
|
|
||||||
|
|
||||||
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
|
||||||
"""Reflect on bull researcher's analysis and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"BULL", bull_debate_history, situation, returns_losses
|
|
||||||
)
|
|
||||||
bull_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
|
|
||||||
"""Reflect on bear researcher's analysis and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"BEAR", bear_debate_history, situation, returns_losses
|
|
||||||
)
|
|
||||||
bear_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_trader(self, current_state, returns_losses, trader_memory):
|
|
||||||
"""Reflect on trader's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
trader_decision = current_state["trader_investment_plan"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"TRADER", trader_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
trader_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
|
|
||||||
"""Reflect on investment judge's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"INVEST JUDGE", judge_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
invest_judge_memory.add_situations([(situation, result)])
|
|
||||||
|
|
||||||
def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory):
|
|
||||||
"""Reflect on risk manager's decision and update memory."""
|
|
||||||
situation = self._extract_current_situation(current_state)
|
|
||||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
|
||||||
|
|
||||||
result = self._reflect_on_component(
|
|
||||||
"RISK JUDGE", judge_decision, situation, returns_losses
|
|
||||||
)
|
|
||||||
risk_manager_memory.add_situations([(situation, result)])
|
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Any, Dict
|
||||||
from langchain_openai import ChatOpenAI
|
from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.graph import END, StateGraph, START
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
from tradingagents.agents.utils.agent_utils import Toolkit
|
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
|
|
||||||
@@ -17,27 +15,15 @@ class GraphSetup:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quick_thinking_llm: ChatOpenAI,
|
quick_thinking_llm: Any,
|
||||||
deep_thinking_llm: ChatOpenAI,
|
deep_thinking_llm: Any,
|
||||||
toolkit: Toolkit,
|
|
||||||
tool_nodes: Dict[str, ToolNode],
|
tool_nodes: Dict[str, ToolNode],
|
||||||
bull_memory,
|
|
||||||
bear_memory,
|
|
||||||
trader_memory,
|
|
||||||
invest_judge_memory,
|
|
||||||
risk_manager_memory,
|
|
||||||
conditional_logic: ConditionalLogic,
|
conditional_logic: ConditionalLogic,
|
||||||
):
|
):
|
||||||
"""Initialize with required components."""
|
"""Initialize with required components."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
self.deep_thinking_llm = deep_thinking_llm
|
self.deep_thinking_llm = deep_thinking_llm
|
||||||
self.toolkit = toolkit
|
|
||||||
self.tool_nodes = tool_nodes
|
self.tool_nodes = tool_nodes
|
||||||
self.bull_memory = bull_memory
|
|
||||||
self.bear_memory = bear_memory
|
|
||||||
self.trader_memory = trader_memory
|
|
||||||
self.invest_judge_memory = invest_judge_memory
|
|
||||||
self.risk_manager_memory = risk_manager_memory
|
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
@@ -62,51 +48,47 @@ class GraphSetup:
|
|||||||
|
|
||||||
if "market" in selected_analysts:
|
if "market" in selected_analysts:
|
||||||
analyst_nodes["market"] = create_market_analyst(
|
analyst_nodes["market"] = create_market_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm
|
||||||
)
|
)
|
||||||
delete_nodes["market"] = create_msg_delete()
|
delete_nodes["market"] = create_msg_delete()
|
||||||
tool_nodes["market"] = self.tool_nodes["market"]
|
tool_nodes["market"] = self.tool_nodes["market"]
|
||||||
|
|
||||||
if "social" in selected_analysts:
|
if "social" in selected_analysts:
|
||||||
analyst_nodes["social"] = create_social_media_analyst(
|
# "social" selector key preserved for back-compat with existing
|
||||||
self.quick_thinking_llm, self.toolkit
|
# user configs; the underlying agent has been renamed to
|
||||||
|
# sentiment_analyst (the old name advertised social-media data
|
||||||
|
# the agent never had access to — see issue #557).
|
||||||
|
analyst_nodes["social"] = create_sentiment_analyst(
|
||||||
|
self.quick_thinking_llm
|
||||||
)
|
)
|
||||||
delete_nodes["social"] = create_msg_delete()
|
delete_nodes["social"] = create_msg_delete()
|
||||||
tool_nodes["social"] = self.tool_nodes["social"]
|
tool_nodes["social"] = self.tool_nodes["social"]
|
||||||
|
|
||||||
if "news" in selected_analysts:
|
if "news" in selected_analysts:
|
||||||
analyst_nodes["news"] = create_news_analyst(
|
analyst_nodes["news"] = create_news_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm
|
||||||
)
|
)
|
||||||
delete_nodes["news"] = create_msg_delete()
|
delete_nodes["news"] = create_msg_delete()
|
||||||
tool_nodes["news"] = self.tool_nodes["news"]
|
tool_nodes["news"] = self.tool_nodes["news"]
|
||||||
|
|
||||||
if "fundamentals" in selected_analysts:
|
if "fundamentals" in selected_analysts:
|
||||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||||
self.quick_thinking_llm, self.toolkit
|
self.quick_thinking_llm
|
||||||
)
|
)
|
||||||
delete_nodes["fundamentals"] = create_msg_delete()
|
delete_nodes["fundamentals"] = create_msg_delete()
|
||||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||||
|
|
||||||
# Create researcher and manager nodes
|
# Create researcher and manager nodes
|
||||||
bull_researcher_node = create_bull_researcher(
|
bull_researcher_node = create_bull_researcher(self.quick_thinking_llm)
|
||||||
self.quick_thinking_llm, self.bull_memory
|
bear_researcher_node = create_bear_researcher(self.quick_thinking_llm)
|
||||||
)
|
research_manager_node = create_research_manager(self.deep_thinking_llm)
|
||||||
bear_researcher_node = create_bear_researcher(
|
trader_node = create_trader(self.quick_thinking_llm)
|
||||||
self.quick_thinking_llm, self.bear_memory
|
|
||||||
)
|
|
||||||
research_manager_node = create_research_manager(
|
|
||||||
self.deep_thinking_llm, self.invest_judge_memory
|
|
||||||
)
|
|
||||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
|
||||||
|
|
||||||
# Create risk analysis nodes
|
# Create risk analysis nodes
|
||||||
risky_analyst = create_risky_debator(self.quick_thinking_llm)
|
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
|
||||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
|
||||||
risk_manager_node = create_risk_manager(
|
portfolio_manager_node = create_portfolio_manager(self.deep_thinking_llm)
|
||||||
self.deep_thinking_llm, self.risk_manager_memory
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create workflow
|
# Create workflow
|
||||||
workflow = StateGraph(AgentState)
|
workflow = StateGraph(AgentState)
|
||||||
@@ -124,10 +106,10 @@ class GraphSetup:
|
|||||||
workflow.add_node("Bear Researcher", bear_researcher_node)
|
workflow.add_node("Bear Researcher", bear_researcher_node)
|
||||||
workflow.add_node("Research Manager", research_manager_node)
|
workflow.add_node("Research Manager", research_manager_node)
|
||||||
workflow.add_node("Trader", trader_node)
|
workflow.add_node("Trader", trader_node)
|
||||||
workflow.add_node("Risky Analyst", risky_analyst)
|
workflow.add_node("Aggressive Analyst", aggressive_analyst)
|
||||||
workflow.add_node("Neutral Analyst", neutral_analyst)
|
workflow.add_node("Neutral Analyst", neutral_analyst)
|
||||||
workflow.add_node("Safe Analyst", safe_analyst)
|
workflow.add_node("Conservative Analyst", conservative_analyst)
|
||||||
workflow.add_node("Risk Judge", risk_manager_node)
|
workflow.add_node("Portfolio Manager", portfolio_manager_node)
|
||||||
|
|
||||||
# Define edges
|
# Define edges
|
||||||
# Start with the first analyst
|
# Start with the first analyst
|
||||||
@@ -173,33 +155,32 @@ class GraphSetup:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
workflow.add_edge("Research Manager", "Trader")
|
workflow.add_edge("Research Manager", "Trader")
|
||||||
workflow.add_edge("Trader", "Risky Analyst")
|
workflow.add_edge("Trader", "Aggressive Analyst")
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"Risky Analyst",
|
"Aggressive Analyst",
|
||||||
self.conditional_logic.should_continue_risk_analysis,
|
self.conditional_logic.should_continue_risk_analysis,
|
||||||
{
|
{
|
||||||
"Safe Analyst": "Safe Analyst",
|
"Conservative Analyst": "Conservative Analyst",
|
||||||
"Risk Judge": "Risk Judge",
|
"Portfolio Manager": "Portfolio Manager",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"Safe Analyst",
|
"Conservative Analyst",
|
||||||
self.conditional_logic.should_continue_risk_analysis,
|
self.conditional_logic.should_continue_risk_analysis,
|
||||||
{
|
{
|
||||||
"Neutral Analyst": "Neutral Analyst",
|
"Neutral Analyst": "Neutral Analyst",
|
||||||
"Risk Judge": "Risk Judge",
|
"Portfolio Manager": "Portfolio Manager",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"Neutral Analyst",
|
"Neutral Analyst",
|
||||||
self.conditional_logic.should_continue_risk_analysis,
|
self.conditional_logic.should_continue_risk_analysis,
|
||||||
{
|
{
|
||||||
"Risky Analyst": "Risky Analyst",
|
"Aggressive Analyst": "Aggressive Analyst",
|
||||||
"Risk Judge": "Risk Judge",
|
"Portfolio Manager": "Portfolio Manager",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow.add_edge("Risk Judge", END)
|
workflow.add_edge("Portfolio Manager", END)
|
||||||
|
|
||||||
# Compile and return
|
return workflow
|
||||||
return workflow.compile()
|
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
# TradingAgents/graph/signal_processing.py
|
"""Extract the 5-tier portfolio rating from the Portfolio Manager's decision.
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
The Portfolio Manager produces a typed ``PortfolioDecision`` via structured
|
||||||
|
output and renders it to markdown that always carries a ``**Rating**: X``
|
||||||
|
header (see :func:`tradingagents.agents.schemas.render_pm_decision`). The
|
||||||
|
deterministic heuristic in :mod:`tradingagents.agents.utils.rating` is more
|
||||||
|
than sufficient to extract that rating; no extra LLM call is needed.
|
||||||
|
|
||||||
|
This module exists for backwards compatibility with callers that expect a
|
||||||
|
``SignalProcessor.process_signal(text)`` interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
|
|
||||||
class SignalProcessor:
|
class SignalProcessor:
|
||||||
"""Processes trading signals to extract actionable decisions."""
|
"""Read the 5-tier rating out of a Portfolio Manager decision."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: Any = None):
|
||||||
"""Initialize with an LLM for processing."""
|
# The LLM argument is accepted for backwards compatibility but no
|
||||||
|
# longer used: the PM's structured output guarantees the rating is
|
||||||
|
# parseable from the rendered markdown without a second LLM call.
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
||||||
def process_signal(self, full_signal: str) -> str:
|
def process_signal(self, full_signal: str) -> str:
|
||||||
"""
|
"""Return one of Buy / Overweight / Hold / Underweight / Sell."""
|
||||||
Process a full trading signal to extract the core decision.
|
return parse_rating(full_signal)
|
||||||
|
|
||||||
Args:
|
|
||||||
full_signal: Complete trading signal text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Extracted decision (BUY, SELL, or HOLD)
|
|
||||||
"""
|
|
||||||
messages = [
|
|
||||||
(
|
|
||||||
"system",
|
|
||||||
"You are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
|
||||||
),
|
|
||||||
("human", full_signal),
|
|
||||||
]
|
|
||||||
|
|
||||||
return self.quick_thinking_llm.invoke(messages).content
|
|
||||||
|
|||||||
@@ -1,24 +1,45 @@
|
|||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from datetime import date
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
import yfinance as yf
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
RiskDebateState,
|
RiskDebateState,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.interface import set_config
|
from tradingagents.dataflows.config import set_config
|
||||||
|
|
||||||
|
# Import the new abstract tool methods from agent_utils
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_stock_data,
|
||||||
|
get_indicators,
|
||||||
|
get_fundamentals,
|
||||||
|
get_balance_sheet,
|
||||||
|
get_cashflow,
|
||||||
|
get_income_statement,
|
||||||
|
get_news,
|
||||||
|
get_insider_transactions,
|
||||||
|
get_global_news
|
||||||
|
)
|
||||||
|
|
||||||
|
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
from .setup import GraphSetup
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
@@ -34,6 +55,7 @@ class TradingAgentsGraph:
|
|||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
|
callbacks: Optional[List] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
@@ -41,52 +63,62 @@ class TradingAgentsGraph:
|
|||||||
selected_analysts: List of analyst types to include
|
selected_analysts: List of analyst types to include
|
||||||
debug: Whether to run in debug mode
|
debug: Whether to run in debug mode
|
||||||
config: Configuration dictionary. If None, uses default config
|
config: Configuration dictionary. If None, uses default config
|
||||||
|
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
|
self.callbacks = callbacks or []
|
||||||
|
|
||||||
# Update the interface's config
|
# Update the interface's config
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
||||||
# Create necessary directories
|
# Create necessary directories
|
||||||
os.makedirs(
|
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
os.makedirs(self.config["results_dir"], exist_ok=True)
|
||||||
exist_ok=True,
|
|
||||||
|
# Initialize LLMs with provider-specific thinking configuration
|
||||||
|
llm_kwargs = self._get_provider_kwargs()
|
||||||
|
|
||||||
|
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||||
|
if self.callbacks:
|
||||||
|
llm_kwargs["callbacks"] = self.callbacks
|
||||||
|
|
||||||
|
deep_client = create_llm_client(
|
||||||
|
provider=self.config["llm_provider"],
|
||||||
|
model=self.config["deep_think_llm"],
|
||||||
|
base_url=self.config.get("backend_url"),
|
||||||
|
**llm_kwargs,
|
||||||
|
)
|
||||||
|
quick_client = create_llm_client(
|
||||||
|
provider=self.config["llm_provider"],
|
||||||
|
model=self.config["quick_think_llm"],
|
||||||
|
base_url=self.config.get("backend_url"),
|
||||||
|
**llm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize LLMs
|
self.deep_thinking_llm = deep_client.get_llm()
|
||||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"])
|
self.quick_thinking_llm = quick_client.get_llm()
|
||||||
self.quick_thinking_llm = ChatOpenAI(
|
|
||||||
model=self.config["quick_think_llm"], temperature=0.1
|
|
||||||
)
|
|
||||||
self.toolkit = Toolkit(config=self.config)
|
|
||||||
|
|
||||||
# Initialize memories
|
self.memory_log = TradingMemoryLog(self.config)
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory")
|
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory")
|
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory")
|
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory")
|
|
||||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory")
|
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
|
|
||||||
# Initialize components
|
# Initialize components
|
||||||
self.conditional_logic = ConditionalLogic()
|
self.conditional_logic = ConditionalLogic(
|
||||||
|
max_debate_rounds=self.config["max_debate_rounds"],
|
||||||
|
max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
|
||||||
|
)
|
||||||
self.graph_setup = GraphSetup(
|
self.graph_setup = GraphSetup(
|
||||||
self.quick_thinking_llm,
|
self.quick_thinking_llm,
|
||||||
self.deep_thinking_llm,
|
self.deep_thinking_llm,
|
||||||
self.toolkit,
|
|
||||||
self.tool_nodes,
|
self.tool_nodes,
|
||||||
self.bull_memory,
|
|
||||||
self.bear_memory,
|
|
||||||
self.trader_memory,
|
|
||||||
self.invest_judge_memory,
|
|
||||||
self.risk_manager_memory,
|
|
||||||
self.conditional_logic,
|
self.conditional_logic,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.propagator = Propagator()
|
self.propagator = Propagator(
|
||||||
|
max_recur_limit=self.config.get("max_recur_limit", 100),
|
||||||
|
)
|
||||||
self.reflector = Reflector(self.quick_thinking_llm)
|
self.reflector = Reflector(self.quick_thinking_llm)
|
||||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||||
|
|
||||||
@@ -95,67 +127,223 @@ class TradingAgentsGraph:
|
|||||||
self.ticker = None
|
self.ticker = None
|
||||||
self.log_states_dict = {} # date to full state dict
|
self.log_states_dict = {} # date to full state dict
|
||||||
|
|
||||||
# Set up the graph
|
# Set up the graph: keep the workflow for recompilation with a checkpointer.
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
|
||||||
|
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||||
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
|
kwargs = {}
|
||||||
|
provider = self.config.get("llm_provider", "").lower()
|
||||||
|
|
||||||
|
if provider == "google":
|
||||||
|
thinking_level = self.config.get("google_thinking_level")
|
||||||
|
if thinking_level:
|
||||||
|
kwargs["thinking_level"] = thinking_level
|
||||||
|
|
||||||
|
elif provider == "openai":
|
||||||
|
reasoning_effort = self.config.get("openai_reasoning_effort")
|
||||||
|
if reasoning_effort:
|
||||||
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
|
elif provider == "anthropic":
|
||||||
|
effort = self.config.get("anthropic_effort")
|
||||||
|
if effort:
|
||||||
|
kwargs["effort"] = effort
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||||
"""Create tool nodes for different data sources."""
|
"""Create tool nodes for different data sources using abstract methods."""
|
||||||
return {
|
return {
|
||||||
"market": ToolNode(
|
"market": ToolNode(
|
||||||
[
|
[
|
||||||
# online tools
|
# Core stock data tools
|
||||||
self.toolkit.get_YFin_data_online,
|
get_stock_data,
|
||||||
self.toolkit.get_stockstats_indicators_report_online,
|
# Technical indicators
|
||||||
# offline tools
|
get_indicators,
|
||||||
self.toolkit.get_YFin_data,
|
|
||||||
self.toolkit.get_stockstats_indicators_report,
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"social": ToolNode(
|
"social": ToolNode(
|
||||||
[
|
[
|
||||||
# online tools
|
# News tools for social media analysis
|
||||||
self.toolkit.get_stock_news_openai,
|
get_news,
|
||||||
# offline tools
|
|
||||||
self.toolkit.get_reddit_stock_info,
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"news": ToolNode(
|
"news": ToolNode(
|
||||||
[
|
[
|
||||||
# online tools
|
# News and insider information
|
||||||
self.toolkit.get_global_news_openai,
|
get_news,
|
||||||
self.toolkit.get_google_news,
|
get_global_news,
|
||||||
# offline tools
|
get_insider_transactions,
|
||||||
self.toolkit.get_finnhub_news,
|
|
||||||
self.toolkit.get_reddit_news,
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"fundamentals": ToolNode(
|
"fundamentals": ToolNode(
|
||||||
[
|
[
|
||||||
# online tools
|
# Fundamental analysis tools
|
||||||
self.toolkit.get_fundamentals_openai,
|
get_fundamentals,
|
||||||
# offline tools
|
get_balance_sheet,
|
||||||
self.toolkit.get_finnhub_company_insider_sentiment,
|
get_cashflow,
|
||||||
self.toolkit.get_finnhub_company_insider_transactions,
|
get_income_statement,
|
||||||
self.toolkit.get_simfin_balance_sheet,
|
|
||||||
self.toolkit.get_simfin_cashflow,
|
|
||||||
self.toolkit.get_simfin_income_stmt,
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def propagate(self, company_name, trade_date):
|
def _resolve_benchmark(self, ticker: str) -> str:
|
||||||
"""Run the trading agents graph for a company on a specific date."""
|
"""Pick the benchmark ticker for alpha calculation against ``ticker``.
|
||||||
|
|
||||||
|
``config["benchmark_ticker"]`` overrides everything when set; otherwise
|
||||||
|
the suffix map matches the ticker's exchange suffix (e.g. ``.T`` for
|
||||||
|
Tokyo). US-listed tickers without a dotted suffix fall through to the
|
||||||
|
empty-suffix entry (SPY by default). Unrecognised suffixes (including
|
||||||
|
US tickers with dots like ``BRK.B``) also fall back to the empty-suffix
|
||||||
|
entry, which is the right default because the alpha calculation works
|
||||||
|
in USD.
|
||||||
|
"""
|
||||||
|
explicit = self.config.get("benchmark_ticker")
|
||||||
|
if explicit:
|
||||||
|
return explicit
|
||||||
|
benchmark_map = self.config.get("benchmark_map", {})
|
||||||
|
ticker_upper = ticker.upper()
|
||||||
|
for suffix, benchmark in benchmark_map.items():
|
||||||
|
if suffix and ticker_upper.endswith(suffix.upper()):
|
||||||
|
return benchmark
|
||||||
|
return benchmark_map.get("", "SPY")
|
||||||
|
|
||||||
|
def _fetch_returns(
|
||||||
|
self, ticker: str, trade_date: str, holding_days: int = 5,
|
||||||
|
benchmark: str = "SPY",
|
||||||
|
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
|
||||||
|
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
||||||
|
|
||||||
|
``benchmark`` is the index used as the alpha baseline (resolved by the
|
||||||
|
caller via ``_resolve_benchmark``). Returns ``(raw_return, alpha_return,
|
||||||
|
actual_holding_days)`` or ``(None, None, None)`` if price data is
|
||||||
|
unavailable (too recent, delisted, or network error).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start = datetime.strptime(trade_date, "%Y-%m-%d")
|
||||||
|
end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays
|
||||||
|
end_str = end.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
|
||||||
|
bench = yf.Ticker(benchmark).history(start=trade_date, end=end_str)
|
||||||
|
|
||||||
|
if len(stock) < 2 or len(bench) < 2:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
actual_days = min(holding_days, len(stock) - 1, len(bench) - 1)
|
||||||
|
raw = float(
|
||||||
|
(stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
|
||||||
|
/ stock["Close"].iloc[0]
|
||||||
|
)
|
||||||
|
bench_ret = float(
|
||||||
|
(bench["Close"].iloc[actual_days] - bench["Close"].iloc[0])
|
||||||
|
/ bench["Close"].iloc[0]
|
||||||
|
)
|
||||||
|
alpha = raw - bench_ret
|
||||||
|
return raw, alpha, actual_days
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Could not resolve outcome for %s on %s vs %s (will retry next run): %s",
|
||||||
|
ticker, trade_date, benchmark, e,
|
||||||
|
)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
def _resolve_pending_entries(self, ticker: str) -> None:
|
||||||
|
"""Resolve pending log entries for ticker at the start of a new run.
|
||||||
|
|
||||||
|
Fetches returns for each same-ticker pending entry, generates reflections,
|
||||||
|
then writes all updates in a single atomic batch write to avoid redundant I/O.
|
||||||
|
Skips entries whose price data is not yet available (too recent or delisted).
|
||||||
|
|
||||||
|
Trade-off: only same-ticker entries are resolved per run. Entries for
|
||||||
|
other tickers accumulate until that ticker is run again.
|
||||||
|
"""
|
||||||
|
pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker]
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
|
||||||
|
benchmark = self._resolve_benchmark(ticker)
|
||||||
|
updates = []
|
||||||
|
for entry in pending:
|
||||||
|
raw, alpha, days = self._fetch_returns(
|
||||||
|
ticker, entry["date"], benchmark=benchmark,
|
||||||
|
)
|
||||||
|
if raw is None:
|
||||||
|
continue # price not available yet — try again next run
|
||||||
|
reflection = self.reflector.reflect_on_final_decision(
|
||||||
|
final_decision=entry.get("decision", ""),
|
||||||
|
raw_return=raw,
|
||||||
|
alpha_return=alpha,
|
||||||
|
benchmark_name=benchmark,
|
||||||
|
)
|
||||||
|
updates.append({
|
||||||
|
"ticker": ticker,
|
||||||
|
"trade_date": entry["date"],
|
||||||
|
"raw_return": raw,
|
||||||
|
"alpha_return": alpha,
|
||||||
|
"holding_days": days,
|
||||||
|
"reflection": reflection,
|
||||||
|
})
|
||||||
|
|
||||||
|
if updates:
|
||||||
|
self.memory_log.batch_update_with_outcomes(updates)
|
||||||
|
|
||||||
|
def propagate(self, company_name, trade_date):
|
||||||
|
"""Run the trading agents graph for a company on a specific date.
|
||||||
|
|
||||||
|
When ``checkpoint_enabled`` is set in config, the graph is recompiled
|
||||||
|
with a per-ticker SqliteSaver so a crashed run can resume from the last
|
||||||
|
successful node on a subsequent invocation with the same ticker+date.
|
||||||
|
"""
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
|
|
||||||
# Initialize state
|
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
|
||||||
|
self._resolve_pending_entries(company_name)
|
||||||
|
|
||||||
|
# Recompile with a checkpointer if the user opted in.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
self._checkpointer_ctx = get_checkpointer(
|
||||||
|
self.config["data_cache_dir"], company_name
|
||||||
|
)
|
||||||
|
saver = self._checkpointer_ctx.__enter__()
|
||||||
|
self.graph = self.workflow.compile(checkpointer=saver)
|
||||||
|
|
||||||
|
step = checkpoint_step(
|
||||||
|
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||||
|
)
|
||||||
|
if step is not None:
|
||||||
|
logger.info(
|
||||||
|
"Resuming from step %d for %s on %s", step, company_name, trade_date
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._run_graph(company_name, trade_date)
|
||||||
|
finally:
|
||||||
|
if self._checkpointer_ctx is not None:
|
||||||
|
self._checkpointer_ctx.__exit__(None, None, None)
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
|
||||||
|
def _run_graph(self, company_name, trade_date):
|
||||||
|
"""Execute the graph and write the resulting state to disk and memory log."""
|
||||||
|
# Initialize state — inject memory log context for PM.
|
||||||
|
past_context = self.memory_log.get_past_context(company_name)
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date, past_context=past_context
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
# Inject thread_id so same ticker+date resumes, different date starts fresh.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
tid = thread_id(company_name, str(trade_date))
|
||||||
|
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
# Debug mode with tracing
|
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in self.graph.stream(init_agent_state, **args):
|
for chunk in self.graph.stream(init_agent_state, **args):
|
||||||
if len(chunk["messages"]) == 0:
|
if len(chunk["messages"]) == 0:
|
||||||
@@ -163,19 +351,33 @@ class TradingAgentsGraph:
|
|||||||
else:
|
else:
|
||||||
chunk["messages"][-1].pretty_print()
|
chunk["messages"][-1].pretty_print()
|
||||||
trace.append(chunk)
|
trace.append(chunk)
|
||||||
|
# Streamed chunks are per-node deltas. Merge them so the returned
|
||||||
final_state = trace[-1]
|
# state matches what graph.invoke() yields in the non-debug path.
|
||||||
|
final_state = {}
|
||||||
|
for chunk in trace:
|
||||||
|
final_state.update(chunk)
|
||||||
else:
|
else:
|
||||||
# Standard mode without tracing
|
|
||||||
final_state = self.graph.invoke(init_agent_state, **args)
|
final_state = self.graph.invoke(init_agent_state, **args)
|
||||||
|
|
||||||
# Store current state for reflection
|
# Store current state for reflection.
|
||||||
self.curr_state = final_state
|
self.curr_state = final_state
|
||||||
|
|
||||||
# Log state
|
# Log state to disk.
|
||||||
self._log_state(trade_date, final_state)
|
self._log_state(trade_date, final_state)
|
||||||
|
|
||||||
# Return decision and processed signal
|
# Store decision for deferred reflection on the next same-ticker run.
|
||||||
|
self.memory_log.store_decision(
|
||||||
|
ticker=company_name,
|
||||||
|
trade_date=trade_date,
|
||||||
|
final_trade_decision=final_state["final_trade_decision"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear checkpoint on successful completion to avoid stale state.
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
clear_checkpoint(
|
||||||
|
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||||
|
)
|
||||||
|
|
||||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||||
|
|
||||||
def _log_state(self, trade_date, final_state):
|
def _log_state(self, trade_date, final_state):
|
||||||
@@ -200,8 +402,8 @@ class TradingAgentsGraph:
|
|||||||
},
|
},
|
||||||
"trader_investment_decision": final_state["trader_investment_plan"],
|
"trader_investment_decision": final_state["trader_investment_plan"],
|
||||||
"risk_debate_state": {
|
"risk_debate_state": {
|
||||||
"risky_history": final_state["risk_debate_state"]["risky_history"],
|
"aggressive_history": final_state["risk_debate_state"]["aggressive_history"],
|
||||||
"safe_history": final_state["risk_debate_state"]["safe_history"],
|
"conservative_history": final_state["risk_debate_state"]["conservative_history"],
|
||||||
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
|
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
|
||||||
"history": final_state["risk_debate_state"]["history"],
|
"history": final_state["risk_debate_state"]["history"],
|
||||||
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
|
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
|
||||||
@@ -210,33 +412,15 @@ class TradingAgentsGraph:
|
|||||||
"final_trade_decision": final_state["final_trade_decision"],
|
"final_trade_decision": final_state["final_trade_decision"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save to file
|
# Save to file. Reject ticker values that would escape the
|
||||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
# results directory when joined as a path component.
|
||||||
|
safe_ticker = safe_ticker_component(self.ticker)
|
||||||
|
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(
|
log_path = directory / f"full_states_log_{trade_date}.json"
|
||||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log.json",
|
with open(log_path, "w", encoding="utf-8") as f:
|
||||||
"w",
|
json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
|
||||||
) as f:
|
|
||||||
json.dump(self.log_states_dict, f, indent=4)
|
|
||||||
|
|
||||||
def reflect_and_remember(self, returns_losses):
|
|
||||||
"""Reflect on decisions and update memory based on returns."""
|
|
||||||
self.reflector.reflect_bull_researcher(
|
|
||||||
self.curr_state, returns_losses, self.bull_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_bear_researcher(
|
|
||||||
self.curr_state, returns_losses, self.bear_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_trader(
|
|
||||||
self.curr_state, returns_losses, self.trader_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_invest_judge(
|
|
||||||
self.curr_state, returns_losses, self.invest_judge_memory
|
|
||||||
)
|
|
||||||
self.reflector.reflect_risk_manager(
|
|
||||||
self.curr_state, returns_losses, self.risk_manager_memory
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_signal(self, full_signal):
|
def process_signal(self, full_signal):
|
||||||
"""Process a signal to extract the core decision."""
|
"""Process a signal to extract the core decision."""
|
||||||
|
|||||||
15
tradingagents/llm_clients/TODO.md
Normal file
15
tradingagents/llm_clients/TODO.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# LLM Clients - Consistency Improvements
|
||||||
|
|
||||||
|
## Issues to Fix
|
||||||
|
|
||||||
|
### 1. `validate_model()` is never called
|
||||||
|
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
||||||
|
|
||||||
|
### 2. ~~Inconsistent parameter handling~~ (Fixed)
|
||||||
|
- GoogleClient now accepts unified `api_key` and maps it to `google_api_key`
|
||||||
|
|
||||||
|
### 3. ~~`base_url` accepted but ignored~~ (Fixed)
|
||||||
|
- All clients now pass `base_url` to their respective LLM constructors
|
||||||
|
|
||||||
|
### 4. ~~Update validators.py with models from CLI~~ (Fixed)
|
||||||
|
- Synced in v0.2.2
|
||||||
4
tradingagents/llm_clients/__init__.py
Normal file
4
tradingagents/llm_clients/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .base_client import BaseLLMClient
|
||||||
|
from .factory import create_llm_client
|
||||||
|
|
||||||
|
__all__ = ["BaseLLMClient", "create_llm_client"]
|
||||||
48
tradingagents/llm_clients/anthropic_client.py
Normal file
48
tradingagents/llm_clients/anthropic_client.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"timeout", "max_retries", "api_key", "max_tokens",
|
||||||
|
"callbacks", "http_client", "http_async_client", "effort",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedChatAnthropic(ChatAnthropic):
|
||||||
|
"""ChatAnthropic with normalized content output.
|
||||||
|
|
||||||
|
Claude models with extended thinking or tool use return content as a
|
||||||
|
list of typed blocks. This normalizes to string for consistent
|
||||||
|
downstream handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicClient(BaseLLMClient):
|
||||||
|
"""Client for Anthropic Claude models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured ChatAnthropic instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
|
if self.base_url:
|
||||||
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
return NormalizedChatAnthropic(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate model for Anthropic."""
|
||||||
|
return validate_model("anthropic", self.model)
|
||||||
44
tradingagents/llm_clients/api_key_env.py
Normal file
44
tradingagents/llm_clients/api_key_env.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Canonical provider -> API-key env-var mapping.
|
||||||
|
|
||||||
|
A single source of truth for which environment variable holds the API
|
||||||
|
key for each supported LLM provider. Used by the CLI's interactive key
|
||||||
|
prompt (cli/utils.ensure_api_key) and by anything else that needs to
|
||||||
|
ask "does this provider require a key, and which env var is it?".
|
||||||
|
|
||||||
|
When adding a new provider, register its env var here so the CLI flow
|
||||||
|
prompts for it automatically instead of failing on first API call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
|
"google": "GOOGLE_API_KEY",
|
||||||
|
"azure": "AZURE_OPENAI_API_KEY",
|
||||||
|
"xai": "XAI_API_KEY",
|
||||||
|
"deepseek": "DEEPSEEK_API_KEY",
|
||||||
|
# Dual-region providers each carry their own account; keys are not
|
||||||
|
# interchangeable between the international and China endpoints.
|
||||||
|
"qwen": "DASHSCOPE_API_KEY",
|
||||||
|
"qwen-cn": "DASHSCOPE_CN_API_KEY",
|
||||||
|
"glm": "ZHIPU_API_KEY",
|
||||||
|
"glm-cn": "ZHIPU_CN_API_KEY",
|
||||||
|
"minimax": "MINIMAX_API_KEY",
|
||||||
|
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||||
|
"openrouter": "OPENROUTER_API_KEY",
|
||||||
|
# Local runtimes do not authenticate.
|
||||||
|
"ollama": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_key_env(provider: str) -> Optional[str]:
|
||||||
|
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
||||||
|
|
||||||
|
Unknown providers also return None — callers should treat that as
|
||||||
|
"no key check possible" rather than as "no key required".
|
||||||
|
"""
|
||||||
|
return PROVIDER_API_KEY_ENV.get(provider.lower())
|
||||||
52
tradingagents/llm_clients/azure_client.py
Normal file
52
tradingagents/llm_clients/azure_client.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"timeout", "max_retries", "api_key", "reasoning_effort",
|
||||||
|
"callbacks", "http_client", "http_async_client",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
|
||||||
|
"""AzureChatOpenAI with normalized content output."""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIClient(BaseLLMClient):
|
||||||
|
"""Client for Azure OpenAI deployments.
|
||||||
|
|
||||||
|
Requires environment variables:
|
||||||
|
AZURE_OPENAI_API_KEY: API key
|
||||||
|
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
|
||||||
|
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
|
||||||
|
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured AzureChatOpenAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
|
||||||
|
llm_kwargs = {
|
||||||
|
"model": self.model,
|
||||||
|
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
return NormalizedAzureChatOpenAI(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Azure accepts any deployed model name."""
|
||||||
|
return True
|
||||||
62
tradingagents/llm_clients/base_client.py
Normal file
62
tradingagents/llm_clients/base_client.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Optional
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_content(response):
|
||||||
|
"""Normalize LLM response content to a plain string.
|
||||||
|
|
||||||
|
Multiple providers (OpenAI Responses API, Google Gemini 3) return content
|
||||||
|
as a list of typed blocks, e.g. [{'type': 'reasoning', ...}, {'type': 'text', 'text': '...'}].
|
||||||
|
Downstream agents expect response.content to be a string. This extracts
|
||||||
|
and joins the text blocks, discarding reasoning/metadata blocks.
|
||||||
|
"""
|
||||||
|
content = response.content
|
||||||
|
if isinstance(content, list):
|
||||||
|
texts = [
|
||||||
|
item.get("text", "") if isinstance(item, dict) and item.get("type") == "text"
|
||||||
|
else item if isinstance(item, str) else ""
|
||||||
|
for item in content
|
||||||
|
]
|
||||||
|
response.content = "\n".join(t for t in texts if t)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMClient(ABC):
|
||||||
|
"""Abstract base class for LLM clients."""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
self.model = model
|
||||||
|
self.base_url = base_url
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def get_provider_name(self) -> str:
|
||||||
|
"""Return the provider name used in warning messages."""
|
||||||
|
provider = getattr(self, "provider", None)
|
||||||
|
if provider:
|
||||||
|
return str(provider)
|
||||||
|
return self.__class__.__name__.removesuffix("Client").lower()
|
||||||
|
|
||||||
|
def warn_if_unknown_model(self) -> None:
|
||||||
|
"""Warn when the model is outside the known list for the provider."""
|
||||||
|
if self.validate_model():
|
||||||
|
return
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
f"Model '{self.model}' is not in the known model list for "
|
||||||
|
f"provider '{self.get_provider_name()}'. Continuing anyway."
|
||||||
|
),
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return the configured LLM instance."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate that the model is supported by this client."""
|
||||||
|
pass
|
||||||
120
tradingagents/llm_clients/capabilities.py
Normal file
120
tradingagents/llm_clients/capabilities.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Declarative per-model capability table for OpenAI-compatible providers.
|
||||||
|
|
||||||
|
This is the single place that knows which model IDs reject which API
|
||||||
|
parameters or require which structured-output method. The LLM client
|
||||||
|
subclasses consult ``get_capabilities(model_name)`` instead of hardcoding
|
||||||
|
model-name ``if`` ladders, so adding a new model (or a new provider quirk)
|
||||||
|
means editing this table — not the client code.
|
||||||
|
|
||||||
|
Pattern adapted from the per-model ``compat:`` flags DeepSeek themselves
|
||||||
|
publish in their integration guides (e.g. the Oh My Pi config schema
|
||||||
|
documents ``supportsToolChoice``, ``requiresReasoningContentForToolCalls``
|
||||||
|
as declarative per-model fields).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
StructuredMethod = Literal[
|
||||||
|
"function_calling", # uses tools; respects supports_tool_choice
|
||||||
|
"json_mode", # uses response_format={"type":"json_object"}
|
||||||
|
"json_schema", # uses response_format={"type":"json_schema",...}
|
||||||
|
"none", # no structured output available; caller falls back to free-text
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ModelCapabilities:
|
||||||
|
"""What an OpenAI-compatible model accepts at the API level."""
|
||||||
|
|
||||||
|
supports_tool_choice: bool
|
||||||
|
supports_json_mode: bool
|
||||||
|
supports_json_schema: bool
|
||||||
|
preferred_structured_method: StructuredMethod
|
||||||
|
# DeepSeek thinking-mode models 400 if reasoning_content from prior
|
||||||
|
# assistant turns is not echoed back on the next request.
|
||||||
|
requires_reasoning_content_roundtrip: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# DeepSeek's thinking models accept the ``tools`` array but reject the
|
||||||
|
# ``tool_choice`` parameter (official Oh My Pi integration guide and the
|
||||||
|
# 400 response in issue #678). Their official tool-calling examples
|
||||||
|
# (api-docs.deepseek.com/guides/tool_calls) pass ``tools=[...]`` without
|
||||||
|
# ``tool_choice`` — we mirror that pattern by setting supports_tool_choice
|
||||||
|
# to False and letting the client suppress the kwarg.
|
||||||
|
_DEEPSEEK_THINKING = ModelCapabilities(
|
||||||
|
supports_tool_choice=False,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_json_schema=False,
|
||||||
|
preferred_structured_method="function_calling",
|
||||||
|
requires_reasoning_content_roundtrip=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_DEEPSEEK_CHAT = ModelCapabilities(
|
||||||
|
supports_tool_choice=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_json_schema=False,
|
||||||
|
preferred_structured_method="function_calling",
|
||||||
|
)
|
||||||
|
|
||||||
|
# MiniMax M2.x reasoning models accept the tools array, but their
|
||||||
|
# tool_choice parameter is restricted to the enum {"none", "auto"}
|
||||||
|
# (platform.minimax.io/docs/api-reference/text-post). Langchain's
|
||||||
|
# function_calling path sends tool_choice as a function-spec dict, which
|
||||||
|
# MiniMax 400s — same shape as the DeepSeek bug. supports_tool_choice=False
|
||||||
|
# makes the dispatch in NormalizedChatOpenAI suppress the kwarg; the schema
|
||||||
|
# still ships as a tool. json_mode response_format is only for
|
||||||
|
# MiniMax-Text-01, not M2.x.
|
||||||
|
_MINIMAX_THINKING = ModelCapabilities(
|
||||||
|
supports_tool_choice=False,
|
||||||
|
supports_json_mode=False,
|
||||||
|
supports_json_schema=False,
|
||||||
|
preferred_structured_method="function_calling",
|
||||||
|
)
|
||||||
|
|
||||||
|
_DEFAULT = ModelCapabilities(
|
||||||
|
supports_tool_choice=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_json_schema=True,
|
||||||
|
preferred_structured_method="function_calling",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Exact-ID matches take precedence over pattern matches.
|
||||||
|
_BY_ID: dict[str, ModelCapabilities] = {
|
||||||
|
"deepseek-chat": _DEEPSEEK_CHAT,
|
||||||
|
"deepseek-reasoner": _DEEPSEEK_THINKING,
|
||||||
|
"deepseek-v4-flash": _DEEPSEEK_THINKING,
|
||||||
|
"deepseek-v4-pro": _DEEPSEEK_THINKING,
|
||||||
|
# MiniMax — full official model lineup per
|
||||||
|
# platform.minimax.io/docs/api-reference/text-openai-api
|
||||||
|
"MiniMax-M2.7": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2.7-highspeed": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2.5": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2.5-highspeed": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2.1": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2.1-highspeed": _MINIMAX_THINKING,
|
||||||
|
"MiniMax-M2": _MINIMAX_THINKING,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Forward-compat patterns. New ``deepseek-v5-*`` / ``deepseek-reasoner-*``
|
||||||
|
# or ``MiniMax-M3*`` variants inherit the thinking-mode quirks automatically.
|
||||||
|
_BY_PATTERN: list[tuple[re.Pattern[str], ModelCapabilities]] = [
|
||||||
|
(re.compile(r"^deepseek-v\d"), _DEEPSEEK_THINKING),
|
||||||
|
(re.compile(r"^deepseek-reasoner"), _DEEPSEEK_THINKING),
|
||||||
|
(re.compile(r"^MiniMax-M\d"), _MINIMAX_THINKING),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_capabilities(model_name: str) -> ModelCapabilities:
|
||||||
|
"""Resolve capabilities by exact ID, then pattern, then default."""
|
||||||
|
if model_name in _BY_ID:
|
||||||
|
return _BY_ID[model_name]
|
||||||
|
for pattern, caps in _BY_PATTERN:
|
||||||
|
if pattern.match(model_name):
|
||||||
|
return caps
|
||||||
|
return _DEFAULT
|
||||||
57
tradingagents/llm_clients/factory.py
Normal file
57
tradingagents/llm_clients/factory.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient
|
||||||
|
|
||||||
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
|
_OPENAI_COMPATIBLE = (
|
||||||
|
"openai", "xai", "deepseek",
|
||||||
|
"qwen", "qwen-cn",
|
||||||
|
"glm", "glm-cn",
|
||||||
|
"minimax", "minimax-cn",
|
||||||
|
"ollama", "openrouter",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_llm_client(
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseLLMClient:
|
||||||
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
|
Provider modules are imported lazily so that simply importing this
|
||||||
|
factory (e.g. during test collection) does not pull in heavy LLM SDKs
|
||||||
|
or fail when their API keys are absent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: LLM provider name
|
||||||
|
model: Model name/identifier
|
||||||
|
base_url: Optional base URL for API endpoint
|
||||||
|
**kwargs: Additional provider-specific arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured BaseLLMClient instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is not supported
|
||||||
|
"""
|
||||||
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
|
if provider_lower in _OPENAI_COMPATIBLE:
|
||||||
|
from .openai_client import OpenAIClient
|
||||||
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "anthropic":
|
||||||
|
from .anthropic_client import AnthropicClient
|
||||||
|
return AnthropicClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "google":
|
||||||
|
from .google_client import GoogleClient
|
||||||
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "azure":
|
||||||
|
from .azure_client import AzureOpenAIClient
|
||||||
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
63
tradingagents/llm_clients/google_client.py
Normal file
63
tradingagents/llm_clients/google_client.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
|
||||||
|
"""ChatGoogleGenerativeAI with normalized content output.
|
||||||
|
|
||||||
|
Gemini 3 models return content as list of typed blocks.
|
||||||
|
This normalizes to string for consistent downstream handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleClient(BaseLLMClient):
|
||||||
|
"""Client for Google Gemini models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
|
if self.base_url:
|
||||||
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
# Unified api_key maps to provider-specific google_api_key
|
||||||
|
google_api_key = self.kwargs.get("api_key") or self.kwargs.get("google_api_key")
|
||||||
|
if google_api_key:
|
||||||
|
llm_kwargs["google_api_key"] = google_api_key
|
||||||
|
|
||||||
|
# Map thinking_level to appropriate API param based on model
|
||||||
|
# Gemini 3 Pro: low, high
|
||||||
|
# Gemini 3 Flash: minimal, low, medium, high
|
||||||
|
# Gemini 2.5: thinking_budget (0=disable, -1=dynamic)
|
||||||
|
thinking_level = self.kwargs.get("thinking_level")
|
||||||
|
if thinking_level:
|
||||||
|
model_lower = self.model.lower()
|
||||||
|
if "gemini-3" in model_lower:
|
||||||
|
# Gemini 3 Pro doesn't support "minimal", use "low" instead
|
||||||
|
if "pro" in model_lower and thinking_level == "minimal":
|
||||||
|
thinking_level = "low"
|
||||||
|
llm_kwargs["thinking_level"] = thinking_level
|
||||||
|
else:
|
||||||
|
# Gemini 2.5: map to thinking_budget
|
||||||
|
llm_kwargs["thinking_budget"] = -1 if thinking_level == "high" else 0
|
||||||
|
|
||||||
|
return NormalizedChatGoogleGenerativeAI(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate model for Google."""
|
||||||
|
return validate_model("google", self.model)
|
||||||
197
tradingagents/llm_clients/model_catalog.py
Normal file
197
tradingagents/llm_clients/model_catalog.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Shared model catalog for CLI selections and validation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
ModelOption = Tuple[str, str]
|
||||||
|
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
||||||
|
|
||||||
|
|
||||||
|
# Shared model list for GLM via Z.AI (international) and BigModel (China).
|
||||||
|
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
|
||||||
|
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
|
||||||
|
_GLM_MODELS: Dict[str, List[ModelOption]] = {
|
||||||
|
"quick": [
|
||||||
|
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
|
||||||
|
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||||
|
("GLM-4.5-Air - Lightweight, cost-efficient", "glm-4.5-air"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GLM-5.1 - Latest flagship, 204K ctx", "glm-5.1"),
|
||||||
|
("GLM-5 - Flagship, 204K ctx", "glm-5"),
|
||||||
|
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Shared model list for Qwen's global (dashscope-intl) and CN (dashscope) endpoints.
|
||||||
|
# Source: modelstudio.console.alibabacloud.com (Featured Models — Flagship + Cost-optimized).
|
||||||
|
#
|
||||||
|
# Only versioned IDs are exposed in the dropdown. The version-less aliases
|
||||||
|
# (qwen-plus, qwen-flash) are documented by Alibaba as auto-upgrading
|
||||||
|
# pointers ("backbone, latest, and snapshot ... have been upgraded to the
|
||||||
|
# Qwen3 series"), which means their behavior shifts when Alibaba rotates
|
||||||
|
# the backing model. Users who want a specific generation pick it
|
||||||
|
# explicitly; users who really want auto-latest can enter the alias via
|
||||||
|
# "Custom model ID".
|
||||||
|
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
||||||
|
"quick": [
|
||||||
|
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
|
||||||
|
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Qwen 3.6 Plus - Flagship vision-language, agentic coding SOTA", "qwen3.6-plus"),
|
||||||
|
("Qwen 3.5 Plus - Previous-gen flagship", "qwen3.5-plus"),
|
||||||
|
("Qwen 3 Max - Specialized for agent programming + tool use", "qwen3-max"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Shared model list for MiniMax's global and CN endpoints (same IDs).
|
||||||
|
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
|
||||||
|
# All M2.x models share a 204,800-token context window.
|
||||||
|
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
|
||||||
|
"quick": [
|
||||||
|
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||||
|
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
|
||||||
|
("MiniMax-M2.1-highspeed - M2.1 highspeed, 204K ctx", "MiniMax-M2.1-highspeed"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("MiniMax-M2.7 - Flagship, SOTA on coding/agent benchmarks, 204K ctx", "MiniMax-M2.7"),
|
||||||
|
("MiniMax-M2.7-highspeed - Same quality as M2.7, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||||
|
("MiniMax-M2.5 - Previous-gen flagship, 204K ctx", "MiniMax-M2.5"),
|
||||||
|
("MiniMax-M2.1 - Earlier M2 line, 204K ctx", "MiniMax-M2.1"),
|
||||||
|
("MiniMax-M2 - Base M2, 204K ctx", "MiniMax-M2"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_OPTIONS: ProviderModeOptions = {
|
||||||
|
"openai": {
|
||||||
|
"quick": [
|
||||||
|
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
|
||||||
|
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
|
||||||
|
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
|
||||||
|
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GPT-5.5 - Latest frontier, 1M context", "gpt-5.5"),
|
||||||
|
("GPT-5.4 - Previous-gen frontier, 1M context, cost-effective", "gpt-5.4"),
|
||||||
|
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||||
|
("GPT-5.5 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.5-pro"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"anthropic": {
|
||||||
|
"quick": [
|
||||||
|
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||||
|
("Claude Haiku 4.5 - Fastest with near-frontier intelligence", "claude-haiku-4-5"),
|
||||||
|
("Claude Sonnet 4.5 - High-performance for agents and coding", "claude-sonnet-4-5"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Claude Opus 4.7 - Latest frontier, long-running agents and coding", "claude-opus-4-7"),
|
||||||
|
("Claude Opus 4.6 - Frontier intelligence, agents and coding", "claude-opus-4-6"),
|
||||||
|
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||||
|
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"google": {
|
||||||
|
"quick": [
|
||||||
|
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||||
|
("Gemini 3.1 Flash Lite - Most cost-efficient (GA)", "gemini-3.1-flash-lite"),
|
||||||
|
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Gemini 3.1 Pro - Reasoning-first, complex workflows (preview)", "gemini-3.1-pro-preview"),
|
||||||
|
("Gemini 3 Flash - Next-gen fast (preview)", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"xai": {
|
||||||
|
"quick": [
|
||||||
|
("Grok 4.20 (Non-Reasoning) - Latest, speed-optimized", "grok-4.20-non-reasoning"),
|
||||||
|
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||||
|
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Grok 4.20 (Reasoning) - Latest frontier reasoning model", "grok-4.20-reasoning"),
|
||||||
|
("Grok 4 - Flagship (dated build)", "grok-4-0709"),
|
||||||
|
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||||
|
("Grok 4.20 - Auto-select reasoning behavior", "grok-4.20"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"quick": [
|
||||||
|
("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
|
||||||
|
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# Qwen: same model IDs across global (dashscope-intl) and China
|
||||||
|
# (dashscope) endpoints, so the two provider keys share one model list.
|
||||||
|
"qwen": _QWEN_MODELS,
|
||||||
|
"qwen-cn": _QWEN_MODELS,
|
||||||
|
# GLM: Z.AI (international) and BigModel (China) host the same model
|
||||||
|
# IDs; the two provider keys share one model list.
|
||||||
|
"glm": _GLM_MODELS,
|
||||||
|
"glm-cn": _GLM_MODELS,
|
||||||
|
# MiniMax: same model IDs across global (.io) and China (.com) regions,
|
||||||
|
# so the two provider keys share one model list.
|
||||||
|
"minimax": _MINIMAX_MODELS,
|
||||||
|
"minimax-cn": _MINIMAX_MODELS,
|
||||||
|
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||||
|
# Ollama display labels intentionally omit a "local" marker — the
|
||||||
|
# endpoint is now configurable via OLLAMA_BASE_URL, so the same labels
|
||||||
|
# apply whether the user runs ollama-serve on localhost or against a
|
||||||
|
# remote host. The actual resolved endpoint is surfaced separately by
|
||||||
|
# cli.utils.confirm_ollama_endpoint() right after provider selection.
|
||||||
|
# "Custom model ID" lets users pick any model they have pulled via
|
||||||
|
# `ollama pull` beyond the three suggested defaults.
|
||||||
|
"ollama": {
|
||||||
|
"quick": [
|
||||||
|
("Qwen3:latest (8B)", "qwen3:latest"),
|
||||||
|
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
|
||||||
|
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
|
||||||
|
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
|
||||||
|
("Qwen3:latest (8B)", "qwen3:latest"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
||||||
|
"""Return shared model options for a provider and selection mode."""
|
||||||
|
return MODEL_OPTIONS[provider.lower()][mode]
|
||||||
|
|
||||||
|
|
||||||
|
def get_known_models() -> Dict[str, List[str]]:
|
||||||
|
"""Build known model names from the shared CLI catalog."""
|
||||||
|
return {
|
||||||
|
provider: sorted(
|
||||||
|
{
|
||||||
|
value
|
||||||
|
for options in mode_options.values()
|
||||||
|
for _, value in options
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for provider, mode_options in MODEL_OPTIONS.items()
|
||||||
|
}
|
||||||
241
tradingagents/llm_clients/openai_client.py
Normal file
241
tradingagents/llm_clients/openai_client.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from .api_key_env import get_api_key_env
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .capabilities import get_capabilities
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedChatOpenAI(ChatOpenAI):
|
||||||
|
"""ChatOpenAI with normalized content output and capability-aware binding.
|
||||||
|
|
||||||
|
The Responses API returns content as a list of typed blocks
|
||||||
|
(reasoning, text, etc.). ``invoke`` normalizes to string for
|
||||||
|
consistent downstream handling.
|
||||||
|
|
||||||
|
``with_structured_output`` consults the per-model capability table
|
||||||
|
(``capabilities.get_capabilities``) to pick the method and to decide
|
||||||
|
whether ``tool_choice`` may be sent. Models that reject ``tool_choice``
|
||||||
|
(e.g. DeepSeek V4 and reasoner — per their official tool-calling
|
||||||
|
guide) still bind the schema as a tool, but no ``tool_choice``
|
||||||
|
parameter is sent.
|
||||||
|
|
||||||
|
Provider-specific quirks beyond structured-output (e.g. DeepSeek's
|
||||||
|
reasoning_content roundtrip) live in subclasses so this base class
|
||||||
|
stays small.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
def with_structured_output(self, schema, *, method=None, **kwargs):
|
||||||
|
caps = get_capabilities(self.model_name)
|
||||||
|
if caps.preferred_structured_method == "none":
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.model_name} has no structured-output method available; "
|
||||||
|
f"agent factories will fall back to free-text generation."
|
||||||
|
)
|
||||||
|
method = method or caps.preferred_structured_method
|
||||||
|
# When the model rejects tool_choice, suppress langchain's hardcoded
|
||||||
|
# value. The schema is still bound as a tool — exactly what
|
||||||
|
# DeepSeek's official tool-calling examples do.
|
||||||
|
if method == "function_calling" and not caps.supports_tool_choice:
|
||||||
|
kwargs.setdefault("tool_choice", None)
|
||||||
|
return super().with_structured_output(schema, method=method, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _input_to_messages(input_: Any) -> list:
|
||||||
|
"""Normalise a langchain LLM input to a list of message objects.
|
||||||
|
|
||||||
|
Accepts a list of messages, a ``ChatPromptValue`` (from a
|
||||||
|
ChatPromptTemplate), or anything else (treated as no messages).
|
||||||
|
Used by providers that need to walk the outgoing message history;
|
||||||
|
in particular DeepSeek thinking-mode propagation must work for
|
||||||
|
both bare-list invocations and ChatPromptTemplate-driven ones, so
|
||||||
|
treating only ``list`` here would silently skip half the call sites.
|
||||||
|
"""
|
||||||
|
if isinstance(input_, list):
|
||||||
|
return input_
|
||||||
|
if hasattr(input_, "to_messages"):
|
||||||
|
return input_.to_messages()
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
|
||||||
|
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
|
||||||
|
|
||||||
|
Thinking-mode round-trip is the only DeepSeek-specific behavior that
|
||||||
|
stays here. When DeepSeek's thinking models return a response with
|
||||||
|
``reasoning_content``, that field must be echoed back as part of the
|
||||||
|
assistant message on the next turn or the API fails with HTTP 400.
|
||||||
|
``_create_chat_result`` captures it on receive and
|
||||||
|
``_get_request_payload`` re-attaches it on send.
|
||||||
|
|
||||||
|
Tool-choice handling for V4 and reasoner — those models reject the
|
||||||
|
``tool_choice`` parameter — is handled by the capability dispatch in
|
||||||
|
``NormalizedChatOpenAI.with_structured_output``, not here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||||
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||||
|
outgoing = payload.get("messages", [])
|
||||||
|
for message_dict, message in zip(outgoing, _input_to_messages(input_)):
|
||||||
|
if not isinstance(message, AIMessage):
|
||||||
|
continue
|
||||||
|
reasoning = message.additional_kwargs.get("reasoning_content")
|
||||||
|
if reasoning is not None:
|
||||||
|
message_dict["reasoning_content"] = reasoning
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _create_chat_result(self, response, generation_info=None):
|
||||||
|
chat_result = super()._create_chat_result(response, generation_info)
|
||||||
|
response_dict = (
|
||||||
|
response
|
||||||
|
if isinstance(response, dict)
|
||||||
|
else response.model_dump(
|
||||||
|
exclude={"choices": {"__all__": {"message": {"parsed"}}}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for generation, choice in zip(
|
||||||
|
chat_result.generations, response_dict.get("choices", [])
|
||||||
|
):
|
||||||
|
reasoning = choice.get("message", {}).get("reasoning_content")
|
||||||
|
if reasoning is not None:
|
||||||
|
generation.message.additional_kwargs["reasoning_content"] = reasoning
|
||||||
|
return chat_result
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxChatOpenAI(NormalizedChatOpenAI):
|
||||||
|
"""MiniMax-specific overrides on top of the OpenAI-compatible client.
|
||||||
|
|
||||||
|
M2.x reasoning models embed ``<think>...</think>`` blocks directly in
|
||||||
|
``message.content`` by default, which would pollute saved reports.
|
||||||
|
Per platform.minimax.io/docs/api-reference/text-openai-api, setting
|
||||||
|
``reasoning_split=True`` in the request body redirects the thinking
|
||||||
|
block into ``reasoning_details`` so ``content`` stays clean.
|
||||||
|
|
||||||
|
Tool-choice handling for M2.x — those models accept only the string
|
||||||
|
enum ``{"none", "auto"}`` and reject langchain's function-spec dict —
|
||||||
|
is handled by the capability dispatch in
|
||||||
|
``NormalizedChatOpenAI.with_structured_output``, not here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||||
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||||
|
payload.setdefault("reasoning_split", True)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# Kwargs forwarded from user config to ChatOpenAI
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"timeout", "max_retries", "reasoning_effort",
|
||||||
|
"api_key", "callbacks", "http_client", "http_async_client",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
|
||||||
|
# (one canonical mapping consulted by both this client and the CLI's
|
||||||
|
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
|
||||||
|
# separate endpoints because international and China accounts cannot share
|
||||||
|
# credentials (#758).
|
||||||
|
_PROVIDER_BASE_URL = {
|
||||||
|
"xai": "https://api.x.ai/v1",
|
||||||
|
"deepseek": "https://api.deepseek.com",
|
||||||
|
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"glm": "https://api.z.ai/api/paas/v4/",
|
||||||
|
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
|
||||||
|
"minimax": "https://api.minimax.io/v1",
|
||||||
|
"minimax-cn": "https://api.minimaxi.com/v1",
|
||||||
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
|
"ollama": "http://localhost:11434/v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_base_url(provider: str) -> Optional[str]:
|
||||||
|
"""Default base URL for ``provider``, with env-var overrides where defined.
|
||||||
|
|
||||||
|
Currently only Ollama supports an env-var override (``OLLAMA_BASE_URL``),
|
||||||
|
matching the convention in the broader Ollama tooling ecosystem so users
|
||||||
|
can point at a remote ollama-serve without editing code. The check is
|
||||||
|
call-time, not import-time, so tests that monkeypatch the env after
|
||||||
|
import behave correctly.
|
||||||
|
"""
|
||||||
|
if provider == "ollama":
|
||||||
|
env_url = os.environ.get("OLLAMA_BASE_URL")
|
||||||
|
if env_url:
|
||||||
|
return env_url
|
||||||
|
return _PROVIDER_BASE_URL.get(provider)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIClient(BaseLLMClient):
|
||||||
|
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
||||||
|
|
||||||
|
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||||
|
supports reasoning_effort with function tools across all model families
|
||||||
|
(GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
|
||||||
|
Ollama) use standard Chat Completions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
provider: str = "openai",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
self.provider = provider.lower()
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured ChatOpenAI instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
|
# Provider-specific base URL and auth. An explicit base_url on the
|
||||||
|
# client (e.g. a corporate proxy) takes precedence over the
|
||||||
|
# provider default so users can route through their own gateway.
|
||||||
|
if self.provider in _PROVIDER_BASE_URL:
|
||||||
|
llm_kwargs["base_url"] = self.base_url or _resolve_provider_base_url(self.provider)
|
||||||
|
api_key_env = get_api_key_env(self.provider)
|
||||||
|
if api_key_env:
|
||||||
|
api_key = os.environ.get(api_key_env)
|
||||||
|
if api_key:
|
||||||
|
llm_kwargs["api_key"] = api_key
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"API key for provider '{self.provider}' is not set. "
|
||||||
|
f"Please set the {api_key_env} environment variable "
|
||||||
|
f"(e.g. add {api_key_env}=your_key to your .env file)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
llm_kwargs["api_key"] = "ollama"
|
||||||
|
elif self.base_url:
|
||||||
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
# Forward user-provided kwargs
|
||||||
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
# Native OpenAI: use Responses API for consistent behavior across
|
||||||
|
# all model families. Third-party providers use Chat Completions.
|
||||||
|
if self.provider == "openai":
|
||||||
|
llm_kwargs["use_responses_api"] = True
|
||||||
|
|
||||||
|
# Provider-specific quirks live in their own subclasses so the
|
||||||
|
# base NormalizedChatOpenAI stays free of provider branches.
|
||||||
|
if self.provider == "deepseek":
|
||||||
|
chat_cls = DeepSeekChatOpenAI
|
||||||
|
elif self.provider in ("minimax", "minimax-cn"):
|
||||||
|
chat_cls = MinimaxChatOpenAI
|
||||||
|
else:
|
||||||
|
chat_cls = NormalizedChatOpenAI
|
||||||
|
return chat_cls(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate model for the provider."""
|
||||||
|
return validate_model(self.provider, self.model)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user