mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-17 05:16:14 +03:00
Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c15200dc28 | ||
|
|
7aef10acbd | ||
|
|
03600f3121 | ||
|
|
6b6177ebf7 | ||
|
|
e3bc872982 | ||
|
|
cbc5f67d42 | ||
|
|
3cddf1e331 | ||
|
|
308757c999 | ||
|
|
eeb84aa63b | ||
|
|
9fd54f8368 | ||
|
|
7df18fc912 | ||
|
|
db059034a2 | ||
|
|
ddfb840ecf | ||
|
|
895ed130f9 | ||
|
|
295e84cd54 | ||
|
|
20d3b0782f | ||
|
|
4e7821d574 | ||
|
|
0c1231a405 | ||
|
|
e4be7cc5a3 | ||
|
|
a597063747 | ||
|
|
dab07688fb | ||
|
|
65608831f8 | ||
|
|
76add9048f | ||
|
|
7c8fe2fe9f | ||
|
|
2a58c2208f | ||
|
|
04f434e86d | ||
|
|
2e67782f20 | ||
|
|
1ff3f07a73 | ||
|
|
2f85be624e | ||
|
|
c93b92c7a4 | ||
|
|
d6762d6095 | ||
|
|
8694bd070d | ||
|
|
2c9f1bfe65 | ||
|
|
8a22594607 | ||
|
|
47cbb321fe | ||
|
|
e80636fc0e | ||
|
|
a66aa8fb94 | ||
|
|
3543e5397e | ||
|
|
d7b40a2a5c | ||
|
|
61522e103e | ||
|
|
e848b5e812 | ||
|
|
3e5e99b368 | ||
|
|
a2e7ac1599 | ||
|
|
b16fe53efe | ||
|
|
249caba06f | ||
|
|
a2f343bb54 | ||
|
|
5bae826749 | ||
|
|
99ec63f966 | ||
|
|
e7ec980021 | ||
|
|
f4519bcb84 | ||
|
|
4300b68f19 | ||
|
|
2d2c9e6d66 |
25
.env.example
25
.env.example
@@ -11,6 +11,24 @@ ZHIPU_CN_API_KEY=
|
|||||||
MINIMAX_API_KEY=
|
MINIMAX_API_KEY=
|
||||||
MINIMAX_CN_API_KEY=
|
MINIMAX_CN_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
MISTRAL_API_KEY=
|
||||||
|
MOONSHOT_API_KEY=
|
||||||
|
GROQ_API_KEY=
|
||||||
|
NVIDIA_API_KEY=
|
||||||
|
|
||||||
|
# FRED (Federal Reserve macro data: rates, inflation, labor, growth). Free key: https://fred.stlouisfed.org/docs/api/api_key.html
|
||||||
|
#FRED_API_KEY=
|
||||||
|
|
||||||
|
# Optional: a custom OpenAI-compatible endpoint (vLLM, LM Studio, llama.cpp,
|
||||||
|
# relay). Select provider "openai_compatible" and set the base URL; the key is
|
||||||
|
# optional (local servers need none).
|
||||||
|
#OPENAI_COMPATIBLE_API_KEY=
|
||||||
|
|
||||||
|
# AWS Bedrock (provider "bedrock", install with: pip install ".[bedrock]").
|
||||||
|
# Auth uses the standard AWS credential chain; set the region (and optionally a
|
||||||
|
# named profile). No single API key.
|
||||||
|
#AWS_DEFAULT_REGION=us-west-2
|
||||||
|
#AWS_PROFILE=
|
||||||
|
|
||||||
# Optional: point at a remote Ollama server. When unset, defaults to
|
# Optional: point at a remote Ollama server. When unset, defaults to
|
||||||
# the local instance at http://localhost:11434/v1. Convention follows
|
# the local instance at http://localhost:11434/v1. Convention follows
|
||||||
@@ -22,6 +40,9 @@ OPENROUTER_API_KEY=
|
|||||||
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
|
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
|
||||||
# in tradingagents/default_config.py. Values are coerced to the type of
|
# in tradingagents/default_config.py. Values are coerced to the type of
|
||||||
# the existing default (bool / int / str), so "true"/"3" work as expected.
|
# the existing default (bool / int / str), so "true"/"3" work as expected.
|
||||||
|
# In the CLI, setting the LLM provider / models / backend URL / language
|
||||||
|
# also skips the matching interactive selection step (useful for
|
||||||
|
# OpenAI-compatible endpoints like opencode or LM Studio, and unattended runs).
|
||||||
#TRADINGAGENTS_LLM_PROVIDER=openai
|
#TRADINGAGENTS_LLM_PROVIDER=openai
|
||||||
#TRADINGAGENTS_DEEP_THINK_LLM=gpt-5.4
|
#TRADINGAGENTS_DEEP_THINK_LLM=gpt-5.4
|
||||||
#TRADINGAGENTS_QUICK_THINK_LLM=gpt-5.4-mini
|
#TRADINGAGENTS_QUICK_THINK_LLM=gpt-5.4-mini
|
||||||
@@ -30,3 +51,7 @@ OPENROUTER_API_KEY=
|
|||||||
#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
|
#TRADINGAGENTS_MAX_DEBATE_ROUNDS=1
|
||||||
#TRADINGAGENTS_MAX_RISK_ROUNDS=1
|
#TRADINGAGENTS_MAX_RISK_ROUNDS=1
|
||||||
#TRADINGAGENTS_CHECKPOINT_ENABLED=false
|
#TRADINGAGENTS_CHECKPOINT_ENABLED=false
|
||||||
|
# Sampling temperature (lower = less run-to-run variation on models that
|
||||||
|
# honor it). Unset leaves each provider at its default. See the README
|
||||||
|
# "Reproducibility" note — no setting makes LLM output fully deterministic.
|
||||||
|
#TRADINGAGENTS_TEMPERATURE=0.0
|
||||||
|
|||||||
61
.github/workflows/ci.yml
vendored
Normal file
61
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ci-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: tests (py${{ matrix.python-version }})
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install (with dev extras)
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
- name: Run test suite
|
||||||
|
run: pytest -q
|
||||||
|
|
||||||
|
smoke-install:
|
||||||
|
name: clean-install smoke
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- name: Fresh install (no dev extras) and import
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install .
|
||||||
|
# Catches undeclared runtime deps (e.g. #994 python-dotenv): a bare
|
||||||
|
# install must import the package and the CLI module.
|
||||||
|
python -c "import tradingagents, cli.main; print('clean-install import OK')"
|
||||||
|
|
||||||
|
lint:
|
||||||
|
name: ruff (strict, full repo)
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- name: Install ruff
|
||||||
|
run: pip install "ruff>=0.15"
|
||||||
|
- name: Lint the repository
|
||||||
|
# The repo is fully clean under the strict select, so we lint everything
|
||||||
|
# (results/ and worklog/ are excluded via pyproject extend-exclude).
|
||||||
|
run: ruff check .
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -217,3 +217,7 @@ __marimo__/
|
|||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
**/data_cache/
|
**/data_cache/
|
||||||
|
|
||||||
|
# Enterprise env file (secrets) and generated run reports
|
||||||
|
.env.enterprise
|
||||||
|
reports/
|
||||||
|
|||||||
56
README.md
56
README.md
@@ -65,7 +65,7 @@ TradingAgents is a multi-agent trading framework that mirrors the dynamics of re
|
|||||||
|
|
||||||
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
> TradingAgents framework is designed for research purposes. Trading performance may vary based on many factors, including the chosen backbone language models, model temperature, trading periods, the quality of data, and other non-deterministic factors. [It is not intended as financial, investment, or trading advice.](https://tauric.ai/disclaimer/)
|
||||||
|
|
||||||
Our framework decomposes complex trading tasks into specialized roles. This ensures the system achieves a robust, scalable approach to market analysis and decision-making.
|
Our framework decomposes complex trading tasks into specialized roles.
|
||||||
|
|
||||||
### Analyst Team
|
### Analyst Team
|
||||||
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
- Fundamentals Analyst: Evaluates company financials and performance metrics, identifying intrinsic values and potential red flags.
|
||||||
@@ -85,7 +85,7 @@ Our framework decomposes complex trading tasks into specialized roles. This ensu
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
### Trader Agent
|
### Trader Agent
|
||||||
- Composes reports from the analysts and researchers to make informed trading decisions. It determines the timing and magnitude of trades based on comprehensive market insights.
|
- Composes reports from the analysts and researchers to make informed trading decisions, determining the timing and magnitude of trades.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/trader.png" width="70%" style="display: inline-block; margin: 0 2%;">
|
||||||
@@ -111,7 +111,7 @@ cd TradingAgents
|
|||||||
|
|
||||||
Create a virtual environment in any of your favorite environment managers:
|
Create a virtual environment in any of your favorite environment managers:
|
||||||
```bash
|
```bash
|
||||||
conda create -n tradingagents python=3.13
|
conda create -n tradingagents python=3.12
|
||||||
conda activate tradingagents
|
conda activate tradingagents
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -147,16 +147,20 @@ export DASHSCOPE_API_KEY=... # Qwen — International (dashscope-intl.aliy
|
|||||||
export DASHSCOPE_CN_API_KEY=... # Qwen — China (dashscope.aliyuncs.com)
|
export DASHSCOPE_CN_API_KEY=... # Qwen — China (dashscope.aliyuncs.com)
|
||||||
export ZHIPU_API_KEY=... # GLM via Z.AI (international)
|
export ZHIPU_API_KEY=... # GLM via Z.AI (international)
|
||||||
export ZHIPU_CN_API_KEY=... # GLM via BigModel (China, open.bigmodel.cn)
|
export ZHIPU_CN_API_KEY=... # GLM via BigModel (China, open.bigmodel.cn)
|
||||||
export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io, M2.x, 204K ctx)
|
export MINIMAX_API_KEY=... # MiniMax — Global (api.minimax.io)
|
||||||
export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com, M2.x, 204K ctx)
|
export MINIMAX_CN_API_KEY=... # MiniMax — China (api.minimaxi.com)
|
||||||
export OPENROUTER_API_KEY=... # OpenRouter
|
export OPENROUTER_API_KEY=... # OpenRouter
|
||||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
For Azure OpenAI, copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||||
|
|
||||||
|
For AWS Bedrock, install the extra with `pip install ".[bedrock]"`, set `llm_provider: "bedrock"`, configure AWS credentials (environment variables, `~/.aws/credentials`, or an IAM role) and `AWS_DEFAULT_REGION`, and use a Bedrock model ID, e.g. `us.anthropic.claude-opus-4-8-v1:0`.
|
||||||
|
|
||||||
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
||||||
|
|
||||||
|
For any other OpenAI-compatible server (vLLM, LM Studio, llama.cpp, or a custom relay), use `llm_provider: "openai_compatible"` and set the endpoint via `backend_url` (or `TRADINGAGENTS_LLM_BACKEND_URL`), e.g. `http://localhost:8000/v1` for vLLM or `http://localhost:1234/v1` for LM Studio. The model is whatever your server serves. No key is needed for local servers; set `OPENAI_COMPATIBLE_API_KEY` when the endpoint requires one.
|
||||||
|
|
||||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
@@ -171,6 +175,16 @@ python -m cli.main # alternative: run directly from source
|
|||||||
```
|
```
|
||||||
You will see a screen where you can select your desired tickers, analysis date, LLM provider, research depth, and more.
|
You will see a screen where you can select your desired tickers, analysis date, LLM provider, research depth, and more.
|
||||||
|
|
||||||
|
### Markets and tickers
|
||||||
|
|
||||||
|
TradingAgents works with any market Yahoo Finance covers, using the exchange-suffixed ticker. Company identity and the alpha benchmark resolve automatically per market.
|
||||||
|
|
||||||
|
- US: `AAPL`, `SPY`
|
||||||
|
- Hong Kong: `0700.HK` · Tokyo: `7203.T` · London: `AZN.L`
|
||||||
|
- India: `RELIANCE.NS`, `.BO` · Canada: `.TO` · Australia: `.AX`
|
||||||
|
- China A-shares: Shanghai `.SS`, Shenzhen `.SZ` (e.g. `600519.SS` for Kweichow Moutai)
|
||||||
|
- Crypto: `BTC-USD`, `ETH-USD`
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||||
</p>
|
</p>
|
||||||
@@ -213,8 +227,8 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, qwen-cn, glm, glm-cn, minimax, minimax-cn, openrouter, ollama, azure
|
config["llm_provider"] = "openai" # e.g. openai, google, anthropic, deepseek, groq, ollama; openai_compatible covers any OpenAI-compatible endpoint (vLLM, LM Studio, llama.cpp, ...)
|
||||||
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
|
config["deep_think_llm"] = "gpt-5.5" # Model for complex reasoning
|
||||||
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
||||||
config["max_debate_rounds"] = 2
|
config["max_debate_rounds"] = 2
|
||||||
|
|
||||||
@@ -253,11 +267,31 @@ ta = TradingAgentsGraph(config=config)
|
|||||||
_, decision = ta.propagate("NVDA", "2026-01-15")
|
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Reproducibility
|
||||||
|
|
||||||
|
TradingAgents is LLM-driven, so two runs of the same ticker and date can differ. This is expected for a research tool built on language models, not a defect. The variation comes from a few distinct sources, and it helps to separate them.
|
||||||
|
|
||||||
|
Language model sampling is non-deterministic. Even at a fixed temperature, providers do not guarantee byte-identical output across calls, and reasoning models (the default GPT-5.x family, and any thinking-mode model) vary the most because their internal reasoning is itself sampled.
|
||||||
|
|
||||||
|
Live data moves. News, StockTwits, and Reddit return different content as time passes, so a run today sees different inputs than a run last week even for the same historical trade date. Pin the analysis date to hold the price and indicator window fixed, but the social and news sources still reflect "now".
|
||||||
|
|
||||||
|
To reduce variation you can lower the sampling temperature. Set `temperature` in your config (or `TRADINGAGENTS_TEMPERATURE` in `.env`); lower values make models that honor it more repeatable. Reasoning models largely ignore temperature, so for tighter reproducibility pair a low temperature with a non-reasoning model such as `gpt-4.1`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = DEFAULT_CONFIG.copy()
|
||||||
|
config["llm_provider"] = "openai"
|
||||||
|
config["deep_think_llm"] = "gpt-4.1" # non-reasoning model honors temperature
|
||||||
|
config["quick_think_llm"] = "gpt-4.1"
|
||||||
|
config["temperature"] = 0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
What does not vary anymore: the analyzed company identity is resolved deterministically from the ticker before any agent runs, and the market analyst grounds exact price and indicator claims in a verified data snapshot. Earlier reports of "different companies" or fabricated price levels across runs are addressed by these two mechanisms.
|
||||||
|
|
||||||
|
Backtest results are not guaranteed to match any published figure. Returns depend on the model, the temperature, the date range, data quality, and the sampling above. Treat the framework as a research scaffold for studying multi-agent analysis, not as a strategy with a fixed, replicable return.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
Contributions are welcome: bug fixes, documentation, and feature ideas; past contributions are credited per release in [`CHANGELOG.md`](CHANGELOG.md).
|
||||||
|
|
||||||
Past contributions, including code, design feedback, and bug reports, are credited per release in [`CHANGELOG.md`](CHANGELOG.md).
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import getpass
|
import getpass
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|||||||
304
cli/main.py
304
cli/main.py
@@ -1,31 +1,53 @@
|
|||||||
from typing import Optional
|
|
||||||
import datetime
|
import datetime
|
||||||
import typer
|
import os
|
||||||
import questionary
|
|
||||||
from pathlib import Path
|
|
||||||
from functools import wraps
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.spinner import Spinner
|
|
||||||
from rich.live import Live
|
|
||||||
from rich.columns import Columns
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.layout import Layout
|
|
||||||
from rich.text import Text
|
|
||||||
from rich.table import Table
|
|
||||||
from collections import deque
|
|
||||||
import time
|
import time
|
||||||
from rich.tree import Tree
|
from collections import deque
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
from rich import box
|
from rich import box
|
||||||
from rich.align import Align
|
from rich.align import Align
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.layout import Layout
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
|
from rich.spinner import Spinner
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from cli.announcements import display_announcements, fetch_announcements
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from cli.models import AnalystType
|
|
||||||
from cli.utils import *
|
|
||||||
from cli.announcements import fetch_announcements, display_announcements
|
|
||||||
from cli.stats_handler import StatsCallbackHandler
|
from cli.stats_handler import StatsCallbackHandler
|
||||||
|
from cli.utils import (
|
||||||
|
ask_anthropic_effort,
|
||||||
|
ask_gemini_thinking_config,
|
||||||
|
ask_glm_region,
|
||||||
|
ask_minimax_region,
|
||||||
|
ask_openai_reasoning_effort,
|
||||||
|
ask_output_language,
|
||||||
|
ask_qwen_region,
|
||||||
|
confirm_ollama_endpoint,
|
||||||
|
detect_asset_type,
|
||||||
|
ensure_api_key,
|
||||||
|
get_ticker,
|
||||||
|
prompt_openai_compatible_url,
|
||||||
|
resolve_backend_url,
|
||||||
|
select_analysts,
|
||||||
|
select_deep_thinking_agent,
|
||||||
|
select_llm_provider,
|
||||||
|
select_research_depth,
|
||||||
|
select_shallow_thinking_agent,
|
||||||
|
)
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.graph.analyst_execution import (
|
||||||
|
AnalystWallTimeTracker,
|
||||||
|
build_analyst_execution_plan,
|
||||||
|
get_initial_analyst_node,
|
||||||
|
sync_analyst_tracker_from_chunk,
|
||||||
|
)
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -162,7 +184,7 @@ class MessageBuffer:
|
|||||||
if content is not None:
|
if content is not None:
|
||||||
latest_section = section
|
latest_section = section
|
||||||
latest_content = content
|
latest_content = content
|
||||||
|
|
||||||
if latest_section and latest_content:
|
if latest_section and latest_content:
|
||||||
# Format the current section for display
|
# Format the current section for display
|
||||||
section_titles = {
|
section_titles = {
|
||||||
@@ -459,7 +481,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
|||||||
def get_user_selections():
|
def get_user_selections():
|
||||||
"""Get all user selections before starting the analysis display."""
|
"""Get all user selections before starting the analysis display."""
|
||||||
# Display ASCII art welcome message
|
# Display ASCII art welcome message
|
||||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
|
with open(Path(__file__).parent / "static" / "welcome.txt", encoding="utf-8") as f:
|
||||||
welcome_ascii = f.read()
|
welcome_ascii = f.read()
|
||||||
|
|
||||||
# Create welcome box content
|
# Create welcome box content
|
||||||
@@ -499,11 +521,18 @@ def get_user_selections():
|
|||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 1: Ticker Symbol",
|
"Step 1: Ticker Symbol",
|
||||||
"Enter the exact ticker symbol to analyze, including exchange suffix when needed (examples: SPY, CNC.TO, 7203.T, 0700.HK)",
|
"Enter the ticker, with exchange suffix when needed (e.g. SPY, 0700.HK, BTC-USD)",
|
||||||
"SPY",
|
"SPY",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_ticker = get_ticker()
|
selected_ticker = get_ticker()
|
||||||
|
asset_type = detect_asset_type(selected_ticker)
|
||||||
|
# Only announce when it's not the default stock path, to avoid printing
|
||||||
|
# "stock" on every run.
|
||||||
|
if asset_type.value != "stock":
|
||||||
|
console.print(
|
||||||
|
f"[green]Detected asset type:[/green] {asset_type.value}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: Analysis date
|
# Step 2: Analysis date
|
||||||
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
@@ -516,14 +545,20 @@ def get_user_selections():
|
|||||||
)
|
)
|
||||||
analysis_date = get_analysis_date()
|
analysis_date = get_analysis_date()
|
||||||
|
|
||||||
# Step 3: Output language
|
# Step 3: Output language (skipped when set via TRADINGAGENTS_OUTPUT_LANGUAGE)
|
||||||
console.print(
|
if os.environ.get("TRADINGAGENTS_OUTPUT_LANGUAGE"):
|
||||||
create_question_box(
|
output_language = DEFAULT_CONFIG["output_language"]
|
||||||
"Step 3: Output Language",
|
console.print(
|
||||||
"Select the language for analyst reports and final decision"
|
f"[green]✓ Output language from environment:[/green] {output_language}"
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
output_language = ask_output_language()
|
console.print(
|
||||||
|
create_question_box(
|
||||||
|
"Step 3: Output Language",
|
||||||
|
"Select the language for analyst reports and final decision"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output_language = ask_output_language()
|
||||||
|
|
||||||
# Step 4: Select analysts
|
# Step 4: Select analysts
|
||||||
console.print(
|
console.print(
|
||||||
@@ -531,7 +566,7 @@ def get_user_selections():
|
|||||||
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
|
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_analysts = select_analysts()
|
selected_analysts = select_analysts(asset_type)
|
||||||
console.print(
|
console.print(
|
||||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||||
)
|
)
|
||||||
@@ -544,42 +579,75 @@ def get_user_selections():
|
|||||||
)
|
)
|
||||||
selected_research_depth = select_research_depth()
|
selected_research_depth = select_research_depth()
|
||||||
|
|
||||||
# Step 6: LLM Provider
|
# Step 6: LLM Provider (skipped when set via TRADINGAGENTS_LLM_PROVIDER).
|
||||||
console.print(
|
# The backend URL comes from TRADINGAGENTS_LLM_BACKEND_URL when set,
|
||||||
create_question_box(
|
# otherwise the provider's default endpoint — the same value the menu
|
||||||
"Step 6: LLM Provider", "Select your LLM provider"
|
# would have picked.
|
||||||
|
provider_from_env = bool(os.environ.get("TRADINGAGENTS_LLM_PROVIDER"))
|
||||||
|
if provider_from_env:
|
||||||
|
selected_llm_provider = DEFAULT_CONFIG["llm_provider"].lower()
|
||||||
|
backend_url = resolve_backend_url(
|
||||||
|
selected_llm_provider, env_url=DEFAULT_CONFIG["backend_url"]
|
||||||
)
|
)
|
||||||
)
|
console.print(f"[green]✓ LLM provider from environment:[/green] {selected_llm_provider}")
|
||||||
selected_llm_provider, backend_url = select_llm_provider()
|
console.print(f"[green]✓ Backend URL:[/green] {backend_url}")
|
||||||
|
# Still confirm/persist the API key so the run doesn't fail later.
|
||||||
# Providers with regional endpoints prompt for the region as a secondary
|
ensure_api_key(selected_llm_provider)
|
||||||
# step so the main dropdown stays clean (mainland China and international
|
else:
|
||||||
# accounts cannot share API keys).
|
console.print(
|
||||||
if selected_llm_provider == "qwen":
|
create_question_box(
|
||||||
selected_llm_provider, backend_url = ask_qwen_region()
|
"Step 6: LLM Provider", "Select your LLM provider"
|
||||||
elif selected_llm_provider == "minimax":
|
)
|
||||||
selected_llm_provider, backend_url = ask_minimax_region()
|
|
||||||
elif selected_llm_provider == "glm":
|
|
||||||
selected_llm_provider, backend_url = ask_glm_region()
|
|
||||||
|
|
||||||
# For Ollama, surface the resolved endpoint (OLLAMA_BASE_URL vs default)
|
|
||||||
# before model selection so it's obvious where we're connecting.
|
|
||||||
if selected_llm_provider == "ollama":
|
|
||||||
confirm_ollama_endpoint(backend_url)
|
|
||||||
|
|
||||||
# Confirm the provider's API key is present; prompt the user to paste
|
|
||||||
# one and persist it to .env if it's missing, so the analysis run
|
|
||||||
# doesn't fail later at the first API call.
|
|
||||||
ensure_api_key(selected_llm_provider)
|
|
||||||
|
|
||||||
# Step 7: Thinking agents
|
|
||||||
console.print(
|
|
||||||
create_question_box(
|
|
||||||
"Step 7: Thinking Agents", "Select your thinking agents for analysis"
|
|
||||||
)
|
)
|
||||||
)
|
selected_llm_provider, backend_url = select_llm_provider()
|
||||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
|
||||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
# Providers with regional endpoints prompt for the region as a secondary
|
||||||
|
# step so the main dropdown stays clean (mainland China and international
|
||||||
|
# accounts cannot share API keys).
|
||||||
|
if selected_llm_provider == "qwen":
|
||||||
|
selected_llm_provider, backend_url = ask_qwen_region()
|
||||||
|
elif selected_llm_provider == "minimax":
|
||||||
|
selected_llm_provider, backend_url = ask_minimax_region()
|
||||||
|
elif selected_llm_provider == "glm":
|
||||||
|
selected_llm_provider, backend_url = ask_glm_region()
|
||||||
|
|
||||||
|
# Honor an explicit env backend URL even when the provider was chosen
|
||||||
|
# interactively, so it isn't overwritten by the menu default (#978).
|
||||||
|
backend_url = resolve_backend_url(
|
||||||
|
selected_llm_provider, backend_url, env_url=DEFAULT_CONFIG["backend_url"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# The generic OpenAI-compatible endpoint has no default; ask for it if
|
||||||
|
# neither the menu nor the environment supplied one.
|
||||||
|
if selected_llm_provider == "openai_compatible" and not backend_url:
|
||||||
|
backend_url = prompt_openai_compatible_url()
|
||||||
|
|
||||||
|
# For Ollama, surface the resolved endpoint (OLLAMA_BASE_URL vs default)
|
||||||
|
# before model selection so it's obvious where we're connecting.
|
||||||
|
if selected_llm_provider == "ollama":
|
||||||
|
confirm_ollama_endpoint(backend_url)
|
||||||
|
|
||||||
|
# Confirm the provider's API key is present; prompt the user to paste
|
||||||
|
# one and persist it to .env if it's missing, so the analysis run
|
||||||
|
# doesn't fail later at the first API call.
|
||||||
|
ensure_api_key(selected_llm_provider)
|
||||||
|
|
||||||
|
# Step 7: Thinking agents (skipped when either model is set via environment)
|
||||||
|
if os.environ.get("TRADINGAGENTS_QUICK_THINK_LLM") or os.environ.get("TRADINGAGENTS_DEEP_THINK_LLM"):
|
||||||
|
selected_shallow_thinker = DEFAULT_CONFIG["quick_think_llm"]
|
||||||
|
selected_deep_thinker = DEFAULT_CONFIG["deep_think_llm"]
|
||||||
|
console.print(
|
||||||
|
f"[green]✓ Thinking agents from environment:[/green] "
|
||||||
|
f"quick={selected_shallow_thinker}, deep={selected_deep_thinker}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
create_question_box(
|
||||||
|
"Step 7: Thinking Agents", "Select your thinking agents for analysis"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||||
|
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||||
|
|
||||||
# Step 8: Provider-specific thinking configuration
|
# Step 8: Provider-specific thinking configuration
|
||||||
thinking_level = None
|
thinking_level = None
|
||||||
@@ -587,7 +655,14 @@ def get_user_selections():
|
|||||||
anthropic_effort = None
|
anthropic_effort = None
|
||||||
|
|
||||||
provider_lower = selected_llm_provider.lower()
|
provider_lower = selected_llm_provider.lower()
|
||||||
if provider_lower == "google":
|
# When the provider is configured via environment we keep the run fully
|
||||||
|
# non-interactive and use the config defaults (None = each provider's own
|
||||||
|
# default reasoning/thinking behavior) instead of prompting.
|
||||||
|
if provider_from_env:
|
||||||
|
thinking_level = DEFAULT_CONFIG["google_thinking_level"]
|
||||||
|
reasoning_effort = DEFAULT_CONFIG["openai_reasoning_effort"]
|
||||||
|
anthropic_effort = DEFAULT_CONFIG["anthropic_effort"]
|
||||||
|
elif provider_lower == "google":
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 8: Thinking Mode",
|
"Step 8: Thinking Mode",
|
||||||
@@ -614,6 +689,7 @@ def get_user_selections():
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"ticker": selected_ticker,
|
"ticker": selected_ticker,
|
||||||
|
"asset_type": asset_type.value,
|
||||||
"analysis_date": analysis_date,
|
"analysis_date": analysis_date,
|
||||||
"analysts": selected_analysts,
|
"analysts": selected_analysts,
|
||||||
"research_depth": selected_research_depth,
|
"research_depth": selected_research_depth,
|
||||||
@@ -628,29 +704,6 @@ def get_user_selections():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_ticker():
|
|
||||||
"""Get ticker symbol from user input, preserving exchange suffixes."""
|
|
||||||
# typer.prompt strips trailing dot-suffixes on some shells (e.g. 000404.SH
|
|
||||||
# collapses to 000404). questionary.text reads the raw line.
|
|
||||||
ticker = questionary.text(
|
|
||||||
"",
|
|
||||||
validate=lambda value: (
|
|
||||||
not value.strip()
|
|
||||||
or (
|
|
||||||
all(ch.isalnum() or ch in "._-^" for ch in value.strip())
|
|
||||||
and len(value.strip()) <= 32
|
|
||||||
)
|
|
||||||
)
|
|
||||||
or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK.",
|
|
||||||
).ask()
|
|
||||||
|
|
||||||
if ticker is None:
|
|
||||||
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
return (ticker.strip() or "SPY").upper()
|
|
||||||
|
|
||||||
|
|
||||||
def get_analysis_date():
|
def get_analysis_date():
|
||||||
"""Get the analysis date from user input."""
|
"""Get the analysis date from user input."""
|
||||||
while True:
|
while True:
|
||||||
@@ -844,7 +897,7 @@ ANALYST_REPORT_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def update_analyst_statuses(message_buffer, chunk):
|
def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
|
||||||
"""Update analyst statuses based on accumulated report state.
|
"""Update analyst statuses based on accumulated report state.
|
||||||
|
|
||||||
Logic:
|
Logic:
|
||||||
@@ -858,6 +911,9 @@ def update_analyst_statuses(message_buffer, chunk):
|
|||||||
selected = message_buffer.selected_analysts
|
selected = message_buffer.selected_analysts
|
||||||
found_active = False
|
found_active = False
|
||||||
|
|
||||||
|
if wall_time_tracker is not None:
|
||||||
|
sync_analyst_tracker_from_chunk(wall_time_tracker, chunk)
|
||||||
|
|
||||||
for analyst_key in ANALYST_ORDER:
|
for analyst_key in ANALYST_ORDER:
|
||||||
if analyst_key not in selected:
|
if analyst_key not in selected:
|
||||||
continue
|
continue
|
||||||
@@ -881,9 +937,12 @@ def update_analyst_statuses(message_buffer, chunk):
|
|||||||
message_buffer.update_agent_status(agent_name, "pending")
|
message_buffer.update_agent_status(agent_name, "pending")
|
||||||
|
|
||||||
# When all analysts complete, transition research team to in_progress
|
# When all analysts complete, transition research team to in_progress
|
||||||
if not found_active and selected:
|
if (
|
||||||
if message_buffer.agent_status.get("Bull Researcher") == "pending":
|
not found_active
|
||||||
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
and selected
|
||||||
|
and message_buffer.agent_status.get("Bull Researcher") == "pending"
|
||||||
|
):
|
||||||
|
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
||||||
|
|
||||||
def extract_content_string(content):
|
def extract_content_string(content):
|
||||||
"""Extract string content from various message formats.
|
"""Extract string content from various message formats.
|
||||||
@@ -985,6 +1044,11 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
|
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
|
||||||
selected_set = {analyst.value for analyst in selections["analysts"]}
|
selected_set = {analyst.value for analyst in selections["analysts"]}
|
||||||
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
||||||
|
analyst_execution_plan = build_analyst_execution_plan(
|
||||||
|
selected_analyst_keys,
|
||||||
|
concurrency_limit=config["analyst_concurrency_limit"],
|
||||||
|
)
|
||||||
|
analyst_wall_time_tracker = AnalystWallTimeTracker(analyst_execution_plan)
|
||||||
|
|
||||||
# Initialize the graph with callbacks bound to LLMs
|
# Initialize the graph with callbacks bound to LLMs
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
@@ -1018,7 +1082,7 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
with open(log_file, "a", encoding="utf-8") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def save_tool_call_decorator(obj, func_name):
|
def save_tool_call_decorator(obj, func_name):
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@@ -1051,12 +1115,14 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
layout = create_layout()
|
layout = create_layout()
|
||||||
|
|
||||||
with Live(layout, refresh_per_second=4) as live:
|
with Live(layout, refresh_per_second=4):
|
||||||
# Initial display
|
# Initial display
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Add initial messages
|
# Add initial messages
|
||||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||||
|
if selections["asset_type"] != "stock":
|
||||||
|
message_buffer.add_message("System", f"Detected asset type: {selections['asset_type']}")
|
||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System", f"Analysis date: {selections['analysis_date']}"
|
"System", f"Analysis date: {selections['analysis_date']}"
|
||||||
)
|
)
|
||||||
@@ -1067,8 +1133,9 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Update agent status to in_progress for the first analyst
|
# Update agent status to in_progress for the first analyst
|
||||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
first_analyst = get_initial_analyst_node(analyst_execution_plan)
|
||||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||||
|
analyst_wall_time_tracker.mark_started(selected_analyst_keys[0])
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Create spinner text
|
# Create spinner text
|
||||||
@@ -1077,9 +1144,18 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
)
|
)
|
||||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Initialize state and get graph args with callbacks
|
# Initialize state and get graph args with callbacks.
|
||||||
|
# Resolve the instrument identity once here so all agents anchor to
|
||||||
|
# the real company (#814); the CLI builds state directly rather than
|
||||||
|
# going through propagate(), so this must happen on the CLI path too.
|
||||||
|
instrument_context = graph.resolve_instrument_context(
|
||||||
|
selections["ticker"], selections["asset_type"]
|
||||||
|
)
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
selections["ticker"], selections["analysis_date"]
|
selections["ticker"],
|
||||||
|
selections["analysis_date"],
|
||||||
|
asset_type=selections["asset_type"],
|
||||||
|
instrument_context=instrument_context,
|
||||||
)
|
)
|
||||||
# Pass callbacks to graph config for tool execution tracking
|
# Pass callbacks to graph config for tool execution tracking
|
||||||
# (LLM tracking is handled separately via LLM constructor)
|
# (LLM tracking is handled separately via LLM constructor)
|
||||||
@@ -1108,7 +1184,11 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
|
|
||||||
# Update analyst statuses based on report state (runs on every chunk)
|
# Update analyst statuses based on report state (runs on every chunk)
|
||||||
update_analyst_statuses(message_buffer, chunk)
|
update_analyst_statuses(
|
||||||
|
message_buffer,
|
||||||
|
chunk,
|
||||||
|
wall_time_tracker=analyst_wall_time_tracker,
|
||||||
|
)
|
||||||
|
|
||||||
# Research Team - Handle Investment Debate State
|
# Research Team - Handle Investment Debate State
|
||||||
if chunk.get("investment_debate_state"):
|
if chunk.get("investment_debate_state"):
|
||||||
@@ -1170,16 +1250,15 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||||
)
|
)
|
||||||
if judge:
|
if judge and message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
message_buffer.update_report_section(
|
||||||
message_buffer.update_report_section(
|
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
)
|
||||||
)
|
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
|
||||||
|
|
||||||
# Update the display
|
# Update the display
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
@@ -1191,7 +1270,6 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
final_state = {}
|
final_state = {}
|
||||||
for chunk in trace:
|
for chunk in trace:
|
||||||
final_state.update(chunk)
|
final_state.update(chunk)
|
||||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
|
||||||
|
|
||||||
# Update all agent statuses to completed
|
# Update all agent statuses to completed
|
||||||
for agent in message_buffer.agent_status:
|
for agent in message_buffer.agent_status:
|
||||||
@@ -1200,9 +1278,10 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
message_buffer.add_message(
|
message_buffer.add_message(
|
||||||
"System", f"Completed analysis for {selections['analysis_date']}"
|
"System", f"Completed analysis for {selections['analysis_date']}"
|
||||||
)
|
)
|
||||||
|
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
for section in message_buffer.report_sections.keys():
|
for section in message_buffer.report_sections:
|
||||||
if section in final_state:
|
if section in final_state:
|
||||||
message_buffer.update_report_section(section, final_state[section])
|
message_buffer.update_report_section(section, final_state[section])
|
||||||
|
|
||||||
@@ -1210,6 +1289,7 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
|
|
||||||
# Post-analysis prompts (outside Live context for clean interaction)
|
# Post-analysis prompts (outside Live context for clean interaction)
|
||||||
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
||||||
|
console.print(f"[dim]{analyst_wall_time_tracker.format_summary()}[/dim]")
|
||||||
|
|
||||||
# Prompt to save report
|
# Prompt to save report
|
||||||
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
|
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Dict
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class AnalystType(str, Enum):
|
class AnalystType(str, Enum):
|
||||||
@@ -10,3 +8,8 @@ class AnalystType(str, Enum):
|
|||||||
SOCIAL = "social"
|
SOCIAL = "social"
|
||||||
NEWS = "news"
|
NEWS = "news"
|
||||||
FUNDAMENTALS = "fundamentals"
|
FUNDAMENTALS = "fundamentals"
|
||||||
|
|
||||||
|
|
||||||
|
class AssetType(str, Enum):
|
||||||
|
STOCK = "stock"
|
||||||
|
CRYPTO = "crypto"
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.outputs import LLMResult
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
|
||||||
class StatsCallbackHandler(BaseCallbackHandler):
|
class StatsCallbackHandler(BaseCallbackHandler):
|
||||||
@@ -19,8 +19,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Increment LLM call counter when an LLM starts."""
|
"""Increment LLM call counter when an LLM starts."""
|
||||||
@@ -29,8 +29,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
messages: List[List[Any]],
|
messages: list[list[Any]],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Increment LLM call counter when a chat model starts."""
|
"""Increment LLM call counter when a chat model starts."""
|
||||||
@@ -57,7 +57,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
input_str: str,
|
input_str: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -65,7 +65,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self.tool_calls += 1
|
self.tool_calls += 1
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""Return current statistics."""
|
"""Return current statistics."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return {
|
return {
|
||||||
|
|||||||
250
cli/utils.py
250
cli/utils.py
@@ -1,18 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict
|
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
from dotenv import find_dotenv, set_key
|
from dotenv import find_dotenv, set_key
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType, AssetType
|
||||||
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
TICKER_INPUT_EXAMPLES = "SPY, 0700.HK, BTC-USD"
|
||||||
|
|
||||||
ANALYST_ORDER = [
|
ANALYST_ORDER = [
|
||||||
("Market Analyst", AnalystType.MARKET),
|
("Market Analyst", AnalystType.MARKET),
|
||||||
@@ -21,12 +20,33 @@ ANALYST_ORDER = [
|
|||||||
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
CRYPTO_SUFFIXES = ("-USD", "-USDT", "-USDC", "-BTC", "-ETH")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_ticker_input(value: str) -> bool:
|
||||||
|
"""Whether a ticker entry is acceptable (charset + length).
|
||||||
|
|
||||||
|
Allows the characters Yahoo symbols use, including ``=`` for futures/forex
|
||||||
|
like ``GC=F`` and ``EURUSD=X`` (#980), and ``^`` for indices. Empty input is
|
||||||
|
allowed (it defaults to SPY downstream).
|
||||||
|
"""
|
||||||
|
v = value.strip()
|
||||||
|
return not v or (all(ch.isalnum() or ch in "._-^=" for ch in v) and len(v) <= 32)
|
||||||
|
|
||||||
|
|
||||||
def get_ticker() -> str:
|
def get_ticker() -> str:
|
||||||
"""Prompt the user to enter a ticker symbol."""
|
"""Prompt the user to enter a ticker symbol, preserving exchange suffixes.
|
||||||
|
|
||||||
|
Uses questionary.text (not typer.prompt, which strips trailing dot-suffixes
|
||||||
|
like ``000404.SH`` on some shells) and validates the symbol charset so an
|
||||||
|
obvious typo is caught before the run starts.
|
||||||
|
"""
|
||||||
ticker = questionary.text(
|
ticker = questionary.text(
|
||||||
f"Enter the exact ticker symbol to analyze ({TICKER_INPUT_EXAMPLES}):",
|
f"Enter ticker symbol (e.g. {TICKER_INPUT_EXAMPLES}):",
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.",
|
validate=lambda x: (
|
||||||
|
is_valid_ticker_input(x)
|
||||||
|
or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK, GC=F."
|
||||||
|
),
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
[
|
[
|
||||||
("text", "fg:green"),
|
("text", "fg:green"),
|
||||||
@@ -35,16 +55,48 @@ def get_ticker() -> str:
|
|||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if not ticker:
|
if ticker is None:
|
||||||
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
return normalize_ticker_symbol(ticker)
|
return normalize_ticker_symbol(ticker) if ticker.strip() else "SPY"
|
||||||
|
|
||||||
|
|
||||||
def normalize_ticker_symbol(ticker: str) -> str:
|
def normalize_ticker_symbol(ticker: str) -> str:
|
||||||
"""Normalize ticker input while preserving exchange suffixes."""
|
"""Resolve user input to its canonical Yahoo symbol (single source of truth).
|
||||||
return ticker.strip().upper()
|
|
||||||
|
Delegates to the data layer's ``normalize_symbol`` so the symbol the CLI
|
||||||
|
passes through the pipeline is exactly the one the data path will price
|
||||||
|
(e.g. ``BTCUSD`` -> ``BTC-USD``, ``XAUUSD`` -> ``GC=F``). Falls back to the
|
||||||
|
plain upper-case if the data layer is unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tradingagents.dataflows.symbol_utils import normalize_symbol
|
||||||
|
|
||||||
|
return normalize_symbol(ticker)
|
||||||
|
except Exception:
|
||||||
|
return ticker.strip().upper()
|
||||||
|
|
||||||
|
|
||||||
|
def detect_asset_type(ticker: str) -> AssetType:
|
||||||
|
"""Classify on the canonical symbol so e.g. BTCUSD and BTC-USDT both read as
|
||||||
|
crypto (#981/#982), matching what the data path will actually fetch."""
|
||||||
|
canonical = normalize_ticker_symbol(ticker)
|
||||||
|
if canonical.endswith(CRYPTO_SUFFIXES):
|
||||||
|
return AssetType.CRYPTO
|
||||||
|
return AssetType.STOCK
|
||||||
|
|
||||||
|
|
||||||
|
def filter_analysts_for_asset_type(
|
||||||
|
analysts: list[AnalystType], asset_type: AssetType
|
||||||
|
) -> list[AnalystType]:
|
||||||
|
if asset_type != AssetType.CRYPTO:
|
||||||
|
return analysts
|
||||||
|
return [
|
||||||
|
analyst
|
||||||
|
for analyst in analysts
|
||||||
|
if analyst != AnalystType.FUNDAMENTALS
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_analysis_date() -> str:
|
def get_analysis_date() -> str:
|
||||||
@@ -80,12 +132,18 @@ def get_analysis_date() -> str:
|
|||||||
return date.strip()
|
return date.strip()
|
||||||
|
|
||||||
|
|
||||||
def select_analysts() -> List[AnalystType]:
|
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> list[AnalystType]:
|
||||||
"""Select analysts using an interactive checkbox."""
|
"""Select analysts using an interactive checkbox."""
|
||||||
|
available_analysts = filter_analysts_for_asset_type(
|
||||||
|
[value for _, value in ANALYST_ORDER],
|
||||||
|
asset_type,
|
||||||
|
)
|
||||||
choices = questionary.checkbox(
|
choices = questionary.checkbox(
|
||||||
"Select Your [Analysts Team]:",
|
"Select Your [Analysts Team]:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=value) for display, value in ANALYST_ORDER
|
questionary.Choice(display, value=value)
|
||||||
|
for display, value in ANALYST_ORDER
|
||||||
|
if value in available_analysts
|
||||||
],
|
],
|
||||||
instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done",
|
instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done",
|
||||||
validate=lambda x: len(x) > 0 or "You must select at least one analyst.",
|
validate=lambda x: len(x) > 0 or "You must select at least one analyst.",
|
||||||
@@ -138,28 +196,74 @@ def select_research_depth() -> int:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
# Mainstream OpenRouter chat-LLM provider namespaces. We surface the newest
|
||||||
|
# models from these rather than the universal-newest, which is dominated by
|
||||||
|
# niche/experimental releases. These are the general-purpose chat providers;
|
||||||
|
# more enterprise/specialised namespaces (nvidia, cohere, amazon, ...) tend to
|
||||||
|
# ship research/safety variants as their newest, so they're left out of the
|
||||||
|
# shortlist. Provider names are stable (unlike model IDs), so this rarely needs
|
||||||
|
# touching; anything not here is still reachable via Custom ID.
|
||||||
|
_OPENROUTER_MAINSTREAM = {
|
||||||
|
"openai", "anthropic", "google", "deepseek", "qwen", "mistralai",
|
||||||
|
"meta-llama", "x-ai", "z-ai", "minimax", "moonshotai",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_openrouter_models() -> list[tuple[str, str]]:
|
||||||
"""Fetch available models from the OpenRouter API."""
|
"""Fetch available models from the OpenRouter API."""
|
||||||
import requests
|
import requests
|
||||||
try:
|
try:
|
||||||
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
|
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
models = resp.json().get("data", [])
|
models = resp.json().get("data", [])
|
||||||
|
# Newest first so the top-N shown really is the latest available — the
|
||||||
|
# API currently returns this order, but sort explicitly so the prompt's
|
||||||
|
# "latest available" label holds regardless of response ordering.
|
||||||
|
models.sort(key=lambda m: m.get("created") or 0, reverse=True)
|
||||||
return [(m.get("name") or m["id"], m["id"]) for m in models]
|
return [(m.get("name") or m["id"], m["id"]) for m in models]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
|
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def select_openrouter_model() -> str:
|
def _require_text(message: str, hint: str) -> str:
|
||||||
"""Select an OpenRouter model from the newest available, or enter a custom ID."""
|
"""Prompt for a required value; exit cleanly if the user cancels.
|
||||||
models = _fetch_openrouter_models()
|
|
||||||
|
|
||||||
choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
|
``questionary.text(...).ask()`` returns None on Ctrl-C/Esc; mirror the
|
||||||
|
exit-on-cancel behavior of the other required selections so a cancelled
|
||||||
|
prompt never returns an empty model/deployment that would fail downstream.
|
||||||
|
"""
|
||||||
|
response = questionary.text(
|
||||||
|
message,
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or hint,
|
||||||
|
).ask()
|
||||||
|
if response is None:
|
||||||
|
console.print("\n[red]Cancelled. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def select_openrouter_model(mode: str) -> str:
|
||||||
|
"""Select an OpenRouter model from the newest available, or enter a custom ID.
|
||||||
|
|
||||||
|
``mode`` ("quick"/"deep") labels the prompt so the two consecutive
|
||||||
|
OpenRouter selections are distinguishable, like the other providers (#1000).
|
||||||
|
"""
|
||||||
|
models = _fetch_openrouter_models() # newest first
|
||||||
|
# Prefer the newest from mainstream providers so the shortlist isn't crowded
|
||||||
|
# out by niche/experimental releases; fall back to all if none match.
|
||||||
|
mainstream = [
|
||||||
|
(name, mid) for name, mid in models
|
||||||
|
if not mid.startswith("~") # skip variant/alias duplicate routes
|
||||||
|
and mid.split("/", 1)[0] in _OPENROUTER_MAINSTREAM
|
||||||
|
]
|
||||||
|
top = (mainstream or models)[:5]
|
||||||
|
|
||||||
|
choices = [questionary.Choice(name, value=mid) for name, mid in top]
|
||||||
choices.append(questionary.Choice("Custom model ID", value="custom"))
|
choices.append(questionary.Choice("Custom model ID", value="custom"))
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select OpenRouter Model (latest available):",
|
f"Select Your [{mode.title()}-Thinking] OpenRouter Model (latest available):",
|
||||||
choices=choices,
|
choices=choices,
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style([
|
style=questionary.Style([
|
||||||
@@ -169,33 +273,32 @@ def select_openrouter_model() -> str:
|
|||||||
]),
|
]),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None or choice == "custom":
|
if choice is None:
|
||||||
return questionary.text(
|
console.print("\n[red]No model selected. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
if choice == "custom":
|
||||||
|
return _require_text(
|
||||||
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
|
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
"Please enter a model ID.",
|
||||||
).ask().strip()
|
)
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def _prompt_custom_model_id() -> str:
|
def _prompt_custom_model_id() -> str:
|
||||||
"""Prompt user to type a custom model ID."""
|
"""Prompt user to type a custom model ID."""
|
||||||
return questionary.text(
|
return _require_text("Enter model ID:", "Please enter a model ID.")
|
||||||
"Enter model ID:",
|
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
|
||||||
).ask().strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _select_model(provider: str, mode: str) -> str:
|
def _select_model(provider: str, mode: str) -> str:
|
||||||
"""Select a model for the given provider and mode (quick/deep)."""
|
"""Select a model for the given provider and mode (quick/deep)."""
|
||||||
if provider.lower() == "openrouter":
|
if provider.lower() == "openrouter":
|
||||||
return select_openrouter_model()
|
return select_openrouter_model(mode)
|
||||||
|
|
||||||
if provider.lower() == "azure":
|
if provider.lower() == "azure":
|
||||||
return questionary.text(
|
return _require_text(
|
||||||
f"Enter Azure deployment name ({mode}-thinking):",
|
f"Enter Azure deployment name ({mode}-thinking):",
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
"Please enter a deployment name.",
|
||||||
).ask().strip()
|
)
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||||
@@ -232,14 +335,17 @@ def select_deep_thinking_agent(provider) -> str:
|
|||||||
"""Select deep thinking llm engine using an interactive selection."""
|
"""Select deep thinking llm engine using an interactive selection."""
|
||||||
return _select_model(provider, "deep")
|
return _select_model(provider, "deep")
|
||||||
|
|
||||||
def select_llm_provider() -> tuple[str, str | None]:
|
def _llm_provider_table() -> list[tuple[str, str, str | None]]:
|
||||||
"""Select the LLM provider and its API endpoint."""
|
"""(display_name, provider_key, base_url) for every supported provider.
|
||||||
# Ollama users can point at a remote ollama-serve via OLLAMA_BASE_URL
|
|
||||||
# (convention from the broader Ollama ecosystem); falls back to the
|
Shared by the interactive picker and by env-driven configuration so an
|
||||||
# localhost default when unset.
|
env-set provider resolves to the same default endpoint the menu uses.
|
||||||
|
Ollama users can point at a remote ollama-serve via OLLAMA_BASE_URL
|
||||||
|
(convention from the broader Ollama ecosystem); falls back to the
|
||||||
|
localhost default when unset.
|
||||||
|
"""
|
||||||
ollama_url = os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
|
ollama_url = os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
|
||||||
# (display_name, provider_key, base_url)
|
return [
|
||||||
PROVIDERS = [
|
|
||||||
("OpenAI", "openai", "https://api.openai.com/v1"),
|
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||||
("Google", "google", None),
|
("Google", "google", None),
|
||||||
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||||
@@ -249,10 +355,57 @@ def select_llm_provider() -> tuple[str, str | None]:
|
|||||||
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
||||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
("Mistral", "mistral", "https://api.mistral.ai/v1"),
|
||||||
|
("Kimi (Moonshot)", "kimi", "https://api.moonshot.ai/v1"),
|
||||||
|
("Groq", "groq", "https://api.groq.com/openai/v1"),
|
||||||
|
("NVIDIA NIM", "nvidia", "https://integrate.api.nvidia.com/v1"),
|
||||||
("Azure OpenAI", "azure", None),
|
("Azure OpenAI", "azure", None),
|
||||||
|
("Amazon Bedrock", "bedrock", None),
|
||||||
("Ollama", "ollama", ollama_url),
|
("Ollama", "ollama", ollama_url),
|
||||||
|
("OpenAI-compatible (vLLM, LM Studio, llama.cpp, custom relay)", "openai_compatible", None),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def provider_default_url(provider_key: str) -> str | None:
|
||||||
|
"""Return the default backend URL for a provider key, or None if unknown."""
|
||||||
|
key = provider_key.lower()
|
||||||
|
for _, pk, url in _llm_provider_table():
|
||||||
|
if pk == key:
|
||||||
|
return url
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_backend_url(
|
||||||
|
provider: str, menu_url: str | None = None, env_url: str | None = None
|
||||||
|
) -> str | None:
|
||||||
|
"""Resolve the backend URL with the correct precedence.
|
||||||
|
|
||||||
|
An explicit env override (``env_url``, from ``TRADINGAGENTS_LLM_BACKEND_URL``
|
||||||
|
via ``DEFAULT_CONFIG['backend_url']``) is honored regardless of how the
|
||||||
|
provider was chosen — interactively or from the environment (#978).
|
||||||
|
Otherwise the menu/region URL, then the provider's default.
|
||||||
|
"""
|
||||||
|
return env_url or menu_url or provider_default_url(provider)
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_openai_compatible_url() -> str:
|
||||||
|
"""Prompt for a custom OpenAI-compatible endpoint base URL."""
|
||||||
|
url = questionary.text(
|
||||||
|
"Enter the OpenAI-compatible base URL "
|
||||||
|
"(e.g. http://localhost:8000/v1 for vLLM, http://localhost:1234/v1 for LM Studio):",
|
||||||
|
validate=lambda x: x.strip().startswith(("http://", "https://"))
|
||||||
|
or "Enter a URL starting with http:// or https://",
|
||||||
|
).ask()
|
||||||
|
if not url:
|
||||||
|
console.print("\n[red]No endpoint URL provided. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
return url.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def select_llm_provider() -> tuple[str, str | None]:
|
||||||
|
"""Select the LLM provider and its API endpoint."""
|
||||||
|
PROVIDERS = _llm_provider_table()
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select your LLM Provider:",
|
"Select your LLM Provider:",
|
||||||
choices=[
|
choices=[
|
||||||
@@ -268,7 +421,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
@@ -422,7 +575,7 @@ def confirm_ollama_endpoint(url: str) -> None:
|
|||||||
|
|
||||||
Surfaces three things the user benefits from seeing before model
|
Surfaces three things the user benefits from seeing before model
|
||||||
selection: which URL we'll actually hit, where it came from
|
selection: which URL we'll actually hit, where it came from
|
||||||
(\`OLLAMA_BASE_URL\` vs default), and a soft warning if the URL is
|
(`OLLAMA_BASE_URL` vs default), and a soft warning if the URL is
|
||||||
missing the scheme/port that ollama-serve expects. The warning is
|
missing the scheme/port that ollama-serve expects. The warning is
|
||||||
advisory only — we don't reject malformed input, since the user may
|
advisory only — we don't reject malformed input, since the user may
|
||||||
be doing something deliberately unusual (e.g. a reverse-proxy path).
|
be doing something deliberately unusual (e.g. a reverse-proxy path).
|
||||||
@@ -447,7 +600,7 @@ def confirm_ollama_endpoint(url: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_api_key(provider: str) -> Optional[str]:
|
def ensure_api_key(provider: str) -> str | None:
|
||||||
"""Make sure the API key for `provider` is available in the environment.
|
"""Make sure the API key for `provider` is available in the environment.
|
||||||
|
|
||||||
If the env var is already set, returns its value untouched. Otherwise
|
If the env var is already set, returns its value untouched. Otherwise
|
||||||
@@ -462,6 +615,13 @@ def ensure_api_key(provider: str) -> Optional[str]:
|
|||||||
if env_var is None:
|
if env_var is None:
|
||||||
return None # ollama / unknown — no key check possible
|
return None # ollama / unknown — no key check possible
|
||||||
|
|
||||||
|
# Key-optional providers (generic OpenAI-compatible / local servers) read the
|
||||||
|
# key when present but must never force an interactive prompt.
|
||||||
|
from tradingagents.llm_clients.openai_client import OPENAI_COMPATIBLE_PROVIDERS
|
||||||
|
spec = OPENAI_COMPATIBLE_PROVIDERS.get(provider.lower())
|
||||||
|
if spec is not None and spec.key_optional:
|
||||||
|
return os.environ.get(env_var)
|
||||||
|
|
||||||
existing = os.environ.get(env_var)
|
existing = os.environ.get(env_var)
|
||||||
if existing:
|
if existing:
|
||||||
return existing
|
return existing
|
||||||
@@ -515,10 +675,14 @@ def ask_output_language() -> str:
|
|||||||
]),
|
]),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
# Output language has a sensible default, so a cancel falls back to English
|
||||||
|
# rather than exiting the run (unlike the required model/provider prompts).
|
||||||
|
if choice is None:
|
||||||
|
return "English"
|
||||||
if choice == "custom":
|
if choice == "custom":
|
||||||
return questionary.text(
|
return (questionary.text(
|
||||||
"Enter language name (e.g. Turkish, Vietnamese, Thai, Indonesian):",
|
"Enter language name (e.g. Turkish, Vietnamese, Thai, Indonesian):",
|
||||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a language name.",
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a language name.",
|
||||||
).ask().strip()
|
).ask() or "").strip() or "English"
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|||||||
2
main.py
2
main.py
@@ -1,5 +1,5 @@
|
|||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
||||||
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ dependencies = [
|
|||||||
"langgraph-checkpoint-sqlite>=2.0.0",
|
"langgraph-checkpoint-sqlite>=2.0.0",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
"parsel>=1.10.0",
|
"parsel>=1.10.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
"pytz>=2025.2",
|
"pytz>=2025.2",
|
||||||
"questionary>=2.1.0",
|
"questionary>=2.1.0",
|
||||||
"redis>=6.2.0",
|
"redis>=6.2.0",
|
||||||
@@ -29,7 +30,19 @@ dependencies = [
|
|||||||
"stockstats>=0.6.5",
|
"stockstats>=0.6.5",
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"typing-extensions>=4.14.0",
|
"typing-extensions>=4.14.0",
|
||||||
"yfinance>=0.2.63",
|
"yfinance>=1.4.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"ruff>=0.15",
|
||||||
|
"pytest>=8.0",
|
||||||
|
"pytest-subtests>=0.13",
|
||||||
|
]
|
||||||
|
# Amazon Bedrock support (AWS SigV4 auth + boto3). Optional so the core install
|
||||||
|
# stays lean: pip install "tradingagents[bedrock]".
|
||||||
|
bedrock = [
|
||||||
|
"langchain-aws>=1.5.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
@@ -52,3 +65,24 @@ markers = [
|
|||||||
filterwarnings = [
|
filterwarnings = [
|
||||||
"ignore::DeprecationWarning",
|
"ignore::DeprecationWarning",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py310"
|
||||||
|
extend-exclude = ["results", "worklog"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
# Standard "good defaults" rule set (pyflakes + pycodestyle + isort + bugbear +
|
||||||
|
# pyupgrade + comprehensions/simplify). Line length (E501) and layout are owned
|
||||||
|
# by the formatter; whole-repo `ruff format` adoption is deferred until the
|
||||||
|
# open-PR backlog clears, to avoid mass merge conflicts.
|
||||||
|
select = ["E", "W", "F", "I", "B", "UP", "C4", "SIM"]
|
||||||
|
ignore = ["E501"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"**/__init__.py" = ["F401"] # intentional re-exports
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
# Keep multiple aliased names from one module in a single combined import block
|
||||||
|
# (e.g. the vendor re-exports in interface.py) instead of one statement per name.
|
||||||
|
combine-as-imports = true
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ added, plus the heuristic SignalProcessor.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
@@ -30,7 +29,6 @@ from tradingagents.agents.trader.trader import create_trader
|
|||||||
from tradingagents.graph.signal_processing import SignalProcessor
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
from tradingagents.llm_clients import create_llm_client
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_DEFAULTS = {
|
PROVIDER_DEFAULTS = {
|
||||||
"openai": ("gpt-5.4-mini", None),
|
"openai": ("gpt-5.4-mini", None),
|
||||||
"google": ("gemini-2.5-flash", None),
|
"google": ("gemini-2.5-flash", None),
|
||||||
|
|||||||
5
test.py
5
test.py
@@ -1,5 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
|
||||||
|
from tradingagents.dataflows.y_finance import (
|
||||||
|
get_stock_stats_indicators_window,
|
||||||
|
)
|
||||||
|
|
||||||
print("Testing optimized implementation with 30-day lookback:")
|
print("Testing optimized implementation with 30-day lookback:")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
@@ -35,6 +35,25 @@ def _dummy_api_keys(monkeypatch):
|
|||||||
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_config():
|
||||||
|
"""Reset the global dataflows config before and after each test.
|
||||||
|
|
||||||
|
``set_config`` merges (it never clears keys absent from the override), so a
|
||||||
|
test that sets e.g. ``tool_vendors`` would otherwise leak into later tests
|
||||||
|
and make routing behavior order-dependent. Replace the global outright so
|
||||||
|
every test starts from a clean DEFAULT_CONFIG.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
yield
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def mock_llm_client():
|
def mock_llm_client():
|
||||||
client = MagicMock()
|
client = MagicMock()
|
||||||
|
|||||||
54
tests/test_alpha_vantage_hardening.py
Normal file
54
tests/test_alpha_vantage_hardening.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Alpha Vantage request hardening.
|
||||||
|
|
||||||
|
Regressions for #990 (no request timeout -> can hang) and #991 (invalid-key
|
||||||
|
responses mislabeled as rate limits and silently treated as transient).
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.alpha_vantage_common as av
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_get(body, capture=None):
|
||||||
|
def fake_get(url, params=None, **kwargs):
|
||||||
|
if capture is not None:
|
||||||
|
capture.update(kwargs)
|
||||||
|
return _FakeResponse(body)
|
||||||
|
return fake_get
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_request_passes_timeout(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
monkeypatch.setattr(av.requests, "get", _patched_get("Date,Close\n2025-01-02,1.0", captured))
|
||||||
|
av._make_api_request("TIME_SERIES_DAILY", {"symbol": "AAPL"})
|
||||||
|
assert captured.get("timeout") == av.REQUEST_TIMEOUT # #990
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_rate_limit_detected(monkeypatch):
|
||||||
|
body = '{"Information": "Our standard API rate limit is 25 requests per day. ... your API key ..."}'
|
||||||
|
monkeypatch.setattr(av.requests, "get", _patched_get(body))
|
||||||
|
with pytest.raises(av.AlphaVantageRateLimitError):
|
||||||
|
av._make_api_request("TIME_SERIES_DAILY", {"symbol": "AAPL"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_invalid_key_not_mislabeled_as_rate_limit(monkeypatch):
|
||||||
|
# AV's invalid-key notice mentions "API key"; it must NOT be treated as a
|
||||||
|
# (transient) rate limit, but surface as a real configuration error (#991).
|
||||||
|
body = ('{"Information": "the parameter apikey is invalid or missing. '
|
||||||
|
'Please claim your free API key on (https://www.alphavantage.co/support/#api-key)."}')
|
||||||
|
monkeypatch.setattr(av.requests, "get", _patched_get(body))
|
||||||
|
with pytest.raises(av.AlphaVantageNotConfiguredError):
|
||||||
|
av._make_api_request("TIME_SERIES_DAILY", {"symbol": "AAPL"})
|
||||||
|
with pytest.raises(av.AlphaVantageRateLimitError): # sanity: rate-limit path still distinct
|
||||||
|
monkeypatch.setattr(av.requests, "get", _patched_get('{"Note": "API call frequency is 5 calls per minute."}'))
|
||||||
|
av._make_api_request("TIME_SERIES_DAILY", {"symbol": "AAPL"})
|
||||||
95
tests/test_analyst_execution.py
Normal file
95
tests/test_analyst_execution.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from tradingagents.graph.analyst_execution import (
|
||||||
|
AnalystWallTimeTracker,
|
||||||
|
build_analyst_execution_plan,
|
||||||
|
get_initial_analyst_node,
|
||||||
|
sync_analyst_tracker_from_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystExecutionPlanTests(unittest.TestCase):
|
||||||
|
def test_build_plan_preserves_selected_order(self):
|
||||||
|
plan = build_analyst_execution_plan(["news", "market"], concurrency_limit=2)
|
||||||
|
|
||||||
|
self.assertEqual([spec.key for spec in plan.specs], ["news", "market"])
|
||||||
|
self.assertEqual(plan.concurrency_limit, 2)
|
||||||
|
self.assertEqual(plan.specs[0].agent_node, "News Analyst")
|
||||||
|
self.assertEqual(plan.specs[0].tool_node, "tools_news")
|
||||||
|
self.assertEqual(plan.specs[0].clear_node, "Msg Clear News")
|
||||||
|
|
||||||
|
def test_rejects_unknown_analyst_keys(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
build_analyst_execution_plan(["market", "macro"])
|
||||||
|
|
||||||
|
def test_requires_positive_concurrency_limit(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
build_analyst_execution_plan(["market"], concurrency_limit=0)
|
||||||
|
|
||||||
|
def test_get_initial_analyst_node_uses_plan_metadata(self):
|
||||||
|
plan = build_analyst_execution_plan(["fundamentals", "news"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
get_initial_analyst_node(plan),
|
||||||
|
"Fundamentals Analyst",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_social_key_displays_as_sentiment_analyst(self):
|
||||||
|
# The wire key stays "social" for saved-config back-compat, but the
|
||||||
|
# user-visible agent_node label must match the v0.2.5 rename so the
|
||||||
|
# wall-time summary and any future consumer of agent_node says
|
||||||
|
# "Sentiment Analyst" rather than the legacy "Social Analyst".
|
||||||
|
plan = build_analyst_execution_plan(["social"])
|
||||||
|
spec = plan.specs[0]
|
||||||
|
self.assertEqual(spec.key, "social")
|
||||||
|
self.assertEqual(spec.agent_node, "Sentiment Analyst")
|
||||||
|
self.assertEqual(spec.report_key, "sentiment_report")
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystWallTimeTrackerTests(unittest.TestCase):
|
||||||
|
def test_records_wall_time_when_analyst_completes(self):
|
||||||
|
plan = build_analyst_execution_plan(["market", "news"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
tracker.mark_started("market", started_at=10.0)
|
||||||
|
tracker.mark_completed("market", completed_at=13.5)
|
||||||
|
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {"market": 3.5})
|
||||||
|
|
||||||
|
def test_formats_summary_in_plan_order(self):
|
||||||
|
plan = build_analyst_execution_plan(["news", "market"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
tracker.mark_started("market", started_at=20.0)
|
||||||
|
tracker.mark_completed("market", completed_at=22.25)
|
||||||
|
tracker.mark_started("news", started_at=10.0)
|
||||||
|
tracker.mark_completed("news", completed_at=14.0)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
tracker.format_summary(),
|
||||||
|
"Analyst wall time: News 4.00s | Market 2.25s",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_syncs_wall_time_from_sequential_chunks(self):
|
||||||
|
plan = build_analyst_execution_plan(["market", "news"])
|
||||||
|
tracker = AnalystWallTimeTracker(plan)
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(tracker, {}, now=10.0)
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {})
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(
|
||||||
|
tracker,
|
||||||
|
{"market_report": "done"},
|
||||||
|
now=13.0,
|
||||||
|
)
|
||||||
|
self.assertEqual(tracker.get_wall_times(), {"market": 3.0})
|
||||||
|
|
||||||
|
sync_analyst_tracker_from_chunk(
|
||||||
|
tracker,
|
||||||
|
{"market_report": "done", "news_report": "done"},
|
||||||
|
now=18.0,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
tracker.get_wall_times(),
|
||||||
|
{"market": 3.0, "news": 5.0},
|
||||||
|
)
|
||||||
84
tests/test_anthropic_effort.py
Normal file
84
tests/test_anthropic_effort.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for Anthropic effort-parameter gating (#831).
|
||||||
|
|
||||||
|
Haiku 4.5 (and current Haiku versions) reject the ``effort`` parameter
|
||||||
|
with a 400. Opus 4.5+ and Sonnet 4.5+ accept it. The gate uses a
|
||||||
|
forward-compat regex so future ``claude-{opus,sonnet}-X-Y`` releases
|
||||||
|
inherit support automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients import anthropic_client as mod
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_kwargs(monkeypatch):
|
||||||
|
captured: dict = {}
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mod, "NormalizedChatAnthropic",
|
||||||
|
lambda **kwargs: captured.setdefault("kwargs", kwargs),
|
||||||
|
)
|
||||||
|
return captured
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestEffortGate:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
["claude-haiku-4-5", "claude-haiku-5-0", "claude-haiku-4-7-preview"],
|
||||||
|
)
|
||||||
|
def test_haiku_does_not_receive_effort(self, monkeypatch, model):
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(model=model, effort="medium", api_key="x").get_llm()
|
||||||
|
assert "effort" not in captured["kwargs"]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"claude-opus-4-5", "claude-opus-4-6", "claude-opus-4-7",
|
||||||
|
"claude-sonnet-4-5", "claude-sonnet-4-6",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_current_opus_and_sonnet_receive_effort(self, monkeypatch, model):
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(model=model, effort="high", api_key="x").get_llm()
|
||||||
|
assert captured["kwargs"]["effort"] == "high"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
["claude-opus-5-0", "claude-opus-4-8", "claude-sonnet-5-0"],
|
||||||
|
)
|
||||||
|
def test_future_opus_sonnet_inherit_effort_via_pattern(self, monkeypatch, model):
|
||||||
|
"""Forward-compat: new Opus/Sonnet versions don't need a code change."""
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(model=model, effort="low", api_key="x").get_llm()
|
||||||
|
assert captured["kwargs"]["effort"] == "low"
|
||||||
|
|
||||||
|
def test_mythos_preview_receives_effort(self, monkeypatch):
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(
|
||||||
|
model="claude-mythos-preview", effort="medium", api_key="x"
|
||||||
|
).get_llm()
|
||||||
|
assert captured["kwargs"]["effort"] == "medium"
|
||||||
|
|
||||||
|
def test_unknown_anthropic_model_does_not_receive_effort(self, monkeypatch):
|
||||||
|
"""Default is conservative — unknown models don't get effort to avoid 400s."""
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(
|
||||||
|
model="claude-experimental-x", effort="medium", api_key="x"
|
||||||
|
).get_llm()
|
||||||
|
assert "effort" not in captured["kwargs"]
|
||||||
|
|
||||||
|
def test_other_kwargs_still_forwarded_when_effort_skipped(self, monkeypatch):
|
||||||
|
"""Skipping effort must not break other passthrough kwargs."""
|
||||||
|
captured = _capture_kwargs(monkeypatch)
|
||||||
|
mod.AnthropicClient(
|
||||||
|
model="claude-haiku-4-5",
|
||||||
|
effort="medium",
|
||||||
|
api_key="placeholder",
|
||||||
|
max_tokens=1024,
|
||||||
|
timeout=30,
|
||||||
|
).get_llm()
|
||||||
|
assert captured["kwargs"]["api_key"] == "placeholder"
|
||||||
|
assert captured["kwargs"]["max_tokens"] == 1024
|
||||||
|
assert captured["kwargs"]["timeout"] == 30
|
||||||
|
assert "effort" not in captured["kwargs"]
|
||||||
@@ -3,14 +3,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||||
|
|
||||||
|
|
||||||
# ---- Mapping coverage -----------------------------------------------------
|
# ---- Mapping coverage -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +69,7 @@ def test_case_insensitive_lookup():
|
|||||||
def cli_utils(monkeypatch):
|
def cli_utils(monkeypatch):
|
||||||
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import cli.utils as cli_utils_module
|
import cli.utils as cli_utils_module
|
||||||
return importlib.reload(cli_utils_module)
|
return importlib.reload(cli_utils_module)
|
||||||
|
|
||||||
|
|||||||
46
tests/test_bedrock_provider.py
Normal file
46
tests/test_bedrock_provider.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Amazon Bedrock — first-class native client via the optional langchain-aws extra.
|
||||||
|
|
||||||
|
Auth uses the AWS credential chain (no single key env); the model is a Bedrock
|
||||||
|
model ID / inference profile ID; langchain-aws is imported lazily with a clear
|
||||||
|
install hint when the [bedrock] extra is absent.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_factory_routes_bedrock():
|
||||||
|
client = create_llm_client("bedrock", "us.anthropic.claude-opus-4-8-v1:0")
|
||||||
|
assert type(client).__name__ == "BedrockClient"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_bedrock_any_model_and_no_key_env():
|
||||||
|
assert validate_model("bedrock", "any.model-id:0") is True
|
||||||
|
# Bedrock uses the AWS credential chain, so there is no single key env.
|
||||||
|
assert get_api_key_env("bedrock") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_helpful_error_when_langchain_aws_absent(monkeypatch):
|
||||||
|
import tradingagents.llm_clients.bedrock_client as bc
|
||||||
|
monkeypatch.setattr(bc, "_BEDROCK_CLASS", None)
|
||||||
|
monkeypatch.setitem(sys.modules, "langchain_aws", None) # force ImportError on import
|
||||||
|
with pytest.raises(ImportError, match=r"bedrock"):
|
||||||
|
create_llm_client("bedrock", "m").get_llm()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_construction_when_extra_installed(monkeypatch):
|
||||||
|
pytest.importorskip("langchain_aws")
|
||||||
|
import tradingagents.llm_clients.bedrock_client as bc
|
||||||
|
monkeypatch.setattr(bc, "_BEDROCK_CLASS", None)
|
||||||
|
monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
|
||||||
|
llm = create_llm_client("bedrock", "us.anthropic.claude-sonnet-4-6-v1:0").get_llm()
|
||||||
|
assert type(llm).__name__ == "NormalizedChatBedrockConverse"
|
||||||
|
assert llm.region_name == "eu-west-1"
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Unit tests for the LLM capability table."""
|
"""Unit tests for the LLM capability table."""
|
||||||
|
|
||||||
|
from dataclasses import FrozenInstanceError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.capabilities import (
|
from tradingagents.llm_clients.capabilities import (
|
||||||
ModelCapabilities,
|
|
||||||
get_capabilities,
|
get_capabilities,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,7 +48,7 @@ class TestPatternMatches:
|
|||||||
caps = get_capabilities("deepseek-reasoner-pro")
|
caps = get_capabilities("deepseek-reasoner-pro")
|
||||||
assert caps.supports_tool_choice is False
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
def test_future_minimax_m3_inherits_thinking_quirks(self):
|
def test_minimax_m3_inherits_thinking_quirks(self):
|
||||||
caps = get_capabilities("MiniMax-M3")
|
caps = get_capabilities("MiniMax-M3")
|
||||||
assert caps.supports_tool_choice is False
|
assert caps.supports_tool_choice is False
|
||||||
|
|
||||||
@@ -75,6 +76,22 @@ class TestMinimaxExactMatches:
|
|||||||
def test_m2_base_rejects_tool_choice(self):
|
def test_m2_base_rejects_tool_choice(self):
|
||||||
assert get_capabilities("MiniMax-M2").supports_tool_choice is False
|
assert get_capabilities("MiniMax-M2").supports_tool_choice is False
|
||||||
|
|
||||||
|
def test_m2_x_requires_reasoning_split(self):
|
||||||
|
# M2.x reasoning models need reasoning_split=True so <think> blocks
|
||||||
|
# land in reasoning_details instead of content (#826).
|
||||||
|
for model in ("MiniMax-M2.7", "MiniMax-M2.5-highspeed", "MiniMax-M2"):
|
||||||
|
assert get_capabilities(model).requires_reasoning_split is True
|
||||||
|
|
||||||
|
def test_future_m3_inherits_reasoning_split(self):
|
||||||
|
assert get_capabilities("MiniMax-M3-highspeed").requires_reasoning_split is True
|
||||||
|
|
||||||
|
def test_non_reasoning_minimax_does_not_get_reasoning_split(self):
|
||||||
|
# Coding Plan, MiniMax-Text-01, and any non-M2-prefixed MiniMax model
|
||||||
|
# reject the reasoning_split kwarg via the openai SDK's strict
|
||||||
|
# validation (#826). Default capability has it disabled.
|
||||||
|
for model in ("minimax-text-01", "MiniMax-Coding-Plan", "abab6.5-chat"):
|
||||||
|
assert get_capabilities(model).requires_reasoning_split is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestDefault:
|
class TestDefault:
|
||||||
@@ -103,5 +120,5 @@ class TestDefault:
|
|||||||
def test_capabilities_dataclass_is_frozen():
|
def test_capabilities_dataclass_is_frozen():
|
||||||
"""Capability rows are immutable so they can be safely shared."""
|
"""Capability rows are immutable so they can be safely shared."""
|
||||||
caps = get_capabilities("deepseek-chat")
|
caps = get_capabilities("deepseek-chat")
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(FrozenInstanceError):
|
||||||
caps.supports_tool_choice = False # type: ignore[misc]
|
caps.supports_tool_choice = False # type: ignore[misc]
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
||||||
from langgraph.graph import END, StateGraph
|
from langgraph.graph import END, StateGraph
|
||||||
|
|
||||||
from tradingagents.graph.checkpointer import (
|
from tradingagents.graph.checkpointer import (
|
||||||
|
|||||||
86
tests/test_cli_env_skip.py
Normal file
86
tests/test_cli_env_skip.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Tests for env-driven CLI behavior (#897, #873).
|
||||||
|
|
||||||
|
The config-layer override (TRADINGAGENTS_* -> DEFAULT_CONFIG) is covered by
|
||||||
|
test_env_overrides.py. These tests cover the CLI layer: an env-configured
|
||||||
|
provider/model/language must skip its interactive prompt and use the value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestProviderDefaultUrl(unittest.TestCase):
|
||||||
|
def test_known_providers_resolve(self):
|
||||||
|
from cli.utils import provider_default_url
|
||||||
|
self.assertEqual(provider_default_url("openai"), "https://api.openai.com/v1")
|
||||||
|
self.assertEqual(provider_default_url("DeepSeek"), "https://api.deepseek.com")
|
||||||
|
self.assertIsNone(provider_default_url("google")) # uses SDK default
|
||||||
|
|
||||||
|
def test_unknown_provider_returns_none(self):
|
||||||
|
from cli.utils import provider_default_url
|
||||||
|
self.assertIsNone(provider_default_url("not-a-provider"))
|
||||||
|
|
||||||
|
def test_ollama_honors_base_url_env(self):
|
||||||
|
from cli.utils import provider_default_url
|
||||||
|
with mock.patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://host:1234/v1"}):
|
||||||
|
self.assertEqual(provider_default_url("ollama"), "http://host:1234/v1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestCliSkipsPromptsFromEnv(unittest.TestCase):
|
||||||
|
def test_env_config_skips_llm_prompts(self):
|
||||||
|
import cli.main as m
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"TRADINGAGENTS_LLM_PROVIDER": "openai",
|
||||||
|
"TRADINGAGENTS_DEEP_THINK_LLM": "kimi-k2.5",
|
||||||
|
"TRADINGAGENTS_QUICK_THINK_LLM": "deepseek-v4-pro",
|
||||||
|
"TRADINGAGENTS_LLM_BACKEND_URL": "https://opencode.ai/zen/go/v1",
|
||||||
|
"TRADINGAGENTS_OUTPUT_LANGUAGE": "Japanese",
|
||||||
|
}
|
||||||
|
fake_cfg = dict(m.DEFAULT_CONFIG)
|
||||||
|
fake_cfg.update({
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"backend_url": "https://opencode.ai/zen/go/v1",
|
||||||
|
"quick_think_llm": "deepseek-v4-pro",
|
||||||
|
"deep_think_llm": "kimi-k2.5",
|
||||||
|
"output_language": "Japanese",
|
||||||
|
})
|
||||||
|
|
||||||
|
with mock.patch.dict(os.environ, env, clear=False), \
|
||||||
|
mock.patch.object(m, "DEFAULT_CONFIG", fake_cfg), \
|
||||||
|
mock.patch.object(m, "fetch_announcements", return_value=None), \
|
||||||
|
mock.patch.object(m, "display_announcements"), \
|
||||||
|
mock.patch.object(m, "get_ticker", return_value="AAPL"), \
|
||||||
|
mock.patch.object(m, "get_analysis_date", return_value="2026-05-29"), \
|
||||||
|
mock.patch.object(m, "select_analysts", return_value=[]), \
|
||||||
|
mock.patch.object(m, "select_research_depth", return_value=1), \
|
||||||
|
mock.patch.object(m, "ensure_api_key") as ensure_key, \
|
||||||
|
mock.patch.object(m, "select_llm_provider") as prompt_provider, \
|
||||||
|
mock.patch.object(m, "ask_output_language") as prompt_lang, \
|
||||||
|
mock.patch.object(m, "select_shallow_thinking_agent") as prompt_quick, \
|
||||||
|
mock.patch.object(m, "select_deep_thinking_agent") as prompt_deep:
|
||||||
|
sel = m.get_user_selections()
|
||||||
|
|
||||||
|
# None of the LLM selection prompts should have been shown.
|
||||||
|
prompt_provider.assert_not_called()
|
||||||
|
prompt_lang.assert_not_called()
|
||||||
|
prompt_quick.assert_not_called()
|
||||||
|
prompt_deep.assert_not_called()
|
||||||
|
# API key is still verified for the env-configured provider.
|
||||||
|
ensure_key.assert_called_once()
|
||||||
|
|
||||||
|
# The env values flow into the returned selections.
|
||||||
|
self.assertEqual(sel["llm_provider"], "openai")
|
||||||
|
self.assertEqual(sel["backend_url"], "https://opencode.ai/zen/go/v1")
|
||||||
|
self.assertEqual(sel["shallow_thinker"], "deepseek-v4-pro")
|
||||||
|
self.assertEqual(sel["deep_thinker"], "kimi-k2.5")
|
||||||
|
self.assertEqual(sel["output_language"], "Japanese")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
62
tests/test_cli_symbol_handling.py
Normal file
62
tests/test_cli_symbol_handling.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""CLI symbol validation/classification must agree with the data path.
|
||||||
|
|
||||||
|
Regressions for #980 (validation rejected GC=F), #981 (BTCUSD misclassified as
|
||||||
|
stock), #982 (BTC-USDT accepted but unpriceable on Yahoo).
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cli.models import AssetType
|
||||||
|
from cli.utils import detect_asset_type, is_valid_ticker_input, normalize_ticker_symbol
|
||||||
|
from tradingagents.dataflows.symbol_utils import normalize_symbol
|
||||||
|
|
||||||
|
|
||||||
|
# --- #982: stablecoin-quoted crypto normalizes to Yahoo's -USD pair ---
|
||||||
|
@pytest.mark.parametrize("raw,expected", [
|
||||||
|
("BTCUSD", "BTC-USD"),
|
||||||
|
("BTCUSDT", "BTC-USD"),
|
||||||
|
("BTC-USDT", "BTC-USD"),
|
||||||
|
("BTC-USDC", "BTC-USD"),
|
||||||
|
("ethusdt", "ETH-USD"),
|
||||||
|
# non-crypto must be untouched
|
||||||
|
("AAPL", "AAPL"),
|
||||||
|
("GC=F", "GC=F"),
|
||||||
|
("600519.SS", "600519.SS"),
|
||||||
|
("EURUSD", "EURUSD=X"),
|
||||||
|
])
|
||||||
|
def test_normalize_symbol_crypto_and_passthrough(raw, expected):
|
||||||
|
assert normalize_symbol(raw) == expected
|
||||||
|
|
||||||
|
|
||||||
|
# --- #980: validation accepts Yahoo futures/forex symbols ---
|
||||||
|
@pytest.mark.parametrize("value,ok", [
|
||||||
|
("GC=F", True),
|
||||||
|
("EURUSD=X", True),
|
||||||
|
("AAPL", True),
|
||||||
|
("0700.HK", True),
|
||||||
|
("^GSPC", True),
|
||||||
|
("", True), # empty -> defaults to SPY downstream
|
||||||
|
("bad symbol!", False), # space + '!' rejected
|
||||||
|
("A" * 40, False), # too long
|
||||||
|
])
|
||||||
|
def test_ticker_input_validation(value, ok):
|
||||||
|
assert is_valid_ticker_input(value) is ok
|
||||||
|
|
||||||
|
|
||||||
|
# --- #981/#982: asset-type classified on the canonical symbol ---
|
||||||
|
@pytest.mark.parametrize("raw,expected", [
|
||||||
|
("BTCUSD", AssetType.CRYPTO),
|
||||||
|
("BTC-USDT", AssetType.CRYPTO),
|
||||||
|
("BTC-USD", AssetType.CRYPTO),
|
||||||
|
("ETHUSD", AssetType.CRYPTO),
|
||||||
|
("AAPL", AssetType.STOCK),
|
||||||
|
("GC=F", AssetType.STOCK),
|
||||||
|
("600519.SS", AssetType.STOCK),
|
||||||
|
])
|
||||||
|
def test_detect_asset_type(raw, expected):
|
||||||
|
assert detect_asset_type(raw) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_normalize_delegates_to_data_layer():
|
||||||
|
# CLI must produce the same canonical symbol the data path will price.
|
||||||
|
for raw in ("XAUUSD", "BTCUSD", "btc-usdt", "AAPL"):
|
||||||
|
assert normalize_ticker_symbol(raw) == normalize_symbol(raw)
|
||||||
56
tests/test_crypto_asset_mode.py
Normal file
56
tests/test_crypto_asset_mode.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from cli.models import AnalystType, AssetType
|
||||||
|
from cli.utils import detect_asset_type, filter_analysts_for_asset_type
|
||||||
|
from tradingagents.graph.propagation import Propagator
|
||||||
|
|
||||||
|
|
||||||
|
class CryptoAssetModeTests(unittest.TestCase):
|
||||||
|
def test_detects_crypto_pair_symbols(self):
|
||||||
|
self.assertEqual(detect_asset_type("BTC-USD"), AssetType.CRYPTO)
|
||||||
|
self.assertEqual(detect_asset_type("eth-usd"), AssetType.CRYPTO)
|
||||||
|
|
||||||
|
def test_defaults_non_crypto_symbols_to_stock(self):
|
||||||
|
self.assertEqual(detect_asset_type("AAPL"), AssetType.STOCK)
|
||||||
|
self.assertEqual(detect_asset_type("SPY"), AssetType.STOCK)
|
||||||
|
|
||||||
|
def test_filters_out_fundamentals_analyst_for_crypto(self):
|
||||||
|
analysts = [
|
||||||
|
AnalystType.MARKET,
|
||||||
|
AnalystType.SOCIAL,
|
||||||
|
AnalystType.NEWS,
|
||||||
|
AnalystType.FUNDAMENTALS,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
filter_analysts_for_asset_type(analysts, AssetType.CRYPTO),
|
||||||
|
[
|
||||||
|
AnalystType.MARKET,
|
||||||
|
AnalystType.SOCIAL,
|
||||||
|
AnalystType.NEWS,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_keeps_all_analysts_for_stock(self):
|
||||||
|
analysts = [
|
||||||
|
AnalystType.MARKET,
|
||||||
|
AnalystType.SOCIAL,
|
||||||
|
AnalystType.NEWS,
|
||||||
|
AnalystType.FUNDAMENTALS,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
filter_analysts_for_asset_type(analysts, AssetType.STOCK),
|
||||||
|
analysts,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_propagator_includes_asset_type_in_initial_state(self):
|
||||||
|
state = Propagator().create_initial_state(
|
||||||
|
"BTC-USD", "2026-04-18", asset_type=AssetType.CRYPTO.value
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(state["asset_type"], AssetType.CRYPTO.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
61
tests/test_date_boundaries.py
Normal file
61
tests/test_date_boundaries.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""yfinance treats ``end`` as exclusive; we must request one extra day so the
|
||||||
|
requested end_date (and the current day) is actually included.
|
||||||
|
|
||||||
|
Regressions for #986 (current-day OHLCV excluded) and #987 (requested end_date
|
||||||
|
row omitted).
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.stockstats_utils as su
|
||||||
|
import tradingagents.dataflows.y_finance as yfin
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_get_yfin_requests_inclusive_end(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class FakeTicker:
|
||||||
|
def __init__(self, symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def history(self, start, end):
|
||||||
|
captured["start"] = start
|
||||||
|
captured["end"] = end
|
||||||
|
idx = pd.to_datetime(["2025-05-08", "2025-05-09"])
|
||||||
|
return pd.DataFrame(
|
||||||
|
{"Open": [1.0, 2.0], "High": [1.0, 2.0], "Low": [1.0, 2.0],
|
||||||
|
"Close": [1.0, 2.0], "Volume": [1, 2]},
|
||||||
|
index=idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(yfin.yf, "Ticker", FakeTicker)
|
||||||
|
out = yfin.get_YFin_data_online("AAPL", "2025-05-01", "2025-05-09")
|
||||||
|
|
||||||
|
# end is requested one day past end_date so 2025-05-09 is included (#987).
|
||||||
|
assert captured["end"] == "2025-05-10"
|
||||||
|
# Header still reflects the requested range, not the internal +1 day.
|
||||||
|
assert "to 2025-05-09" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_load_ohlcv_requests_inclusive_end(monkeypatch, tmp_path):
|
||||||
|
set_config({"data_cache_dir": str(tmp_path)})
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_download(symbol, start, end, **kwargs):
|
||||||
|
captured["end"] = end
|
||||||
|
idx = pd.to_datetime([pd.Timestamp.today().normalize()])
|
||||||
|
return pd.DataFrame(
|
||||||
|
{"Open": [100.0], "High": [100.0], "Low": [100.0],
|
||||||
|
"Close": [100.0], "Volume": [1]},
|
||||||
|
index=idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(su.yf, "download", fake_download)
|
||||||
|
today = pd.Timestamp.today().strftime("%Y-%m-%d")
|
||||||
|
su.load_ohlcv("AAPL", today)
|
||||||
|
|
||||||
|
expected_end = (pd.Timestamp.today() + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||||
|
assert captured["end"] == expected_end # tomorrow -> today's row included (#986)
|
||||||
@@ -24,7 +24,6 @@ from tradingagents.llm_clients.openai_client import (
|
|||||||
_input_to_messages,
|
_input_to_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||||
# (Gemini bot review note: non-list inputs must also work)
|
# (Gemini bot review note: non-list inputs must also work)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def _reload_with_env(monkeypatch, **overrides):
|
|||||||
def test_no_env_uses_built_in_defaults(monkeypatch):
|
def test_no_env_uses_built_in_defaults(monkeypatch):
|
||||||
dc = _reload_with_env(monkeypatch)
|
dc = _reload_with_env(monkeypatch)
|
||||||
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
||||||
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.4"
|
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.5"
|
||||||
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini"
|
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini"
|
||||||
assert dc.DEFAULT_CONFIG["backend_url"] is None
|
assert dc.DEFAULT_CONFIG["backend_url"] is None
|
||||||
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
||||||
|
|||||||
177
tests/test_fred.py
Normal file
177
tests/test_fred.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""FRED macro vendor: alias resolution, configuration errors, output formatting,
|
||||||
|
missing-value handling, lookahead-safe windowing, and router integration.
|
||||||
|
|
||||||
|
All API access is mocked, so these run without a network connection or a key.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows import fred, interface
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
|
||||||
|
# A small, stable set of observations to format against.
|
||||||
|
_META = {
|
||||||
|
"seriess": [
|
||||||
|
{
|
||||||
|
"title": "Unemployment Rate",
|
||||||
|
"units_short": "%",
|
||||||
|
"frequency": "Monthly",
|
||||||
|
"seasonal_adjustment_short": "SA",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
_OBS = {
|
||||||
|
"observations": [
|
||||||
|
{"date": "2025-06-01", "value": "4.1"},
|
||||||
|
{"date": "2025-07-01", "value": "4.3"},
|
||||||
|
{"date": "2025-08-01", "value": "."}, # missing -> skipped
|
||||||
|
{"date": "2025-09-01", "value": "4.4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _request_stub(meta=_META, obs=_OBS):
|
||||||
|
"""Build a _request replacement that dispatches on the endpoint path."""
|
||||||
|
def _impl(path, params):
|
||||||
|
if path == "series":
|
||||||
|
return meta
|
||||||
|
if path == "series/observations":
|
||||||
|
return obs
|
||||||
|
raise AssertionError(f"unexpected FRED path: {path}")
|
||||||
|
return _impl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class FredResolutionTests(unittest.TestCase):
|
||||||
|
def test_alias_maps_to_series_id(self):
|
||||||
|
self.assertEqual(fred._resolve_series_id("cpi"), "CPIAUCSL")
|
||||||
|
self.assertEqual(fred._resolve_series_id("unemployment"), "UNRATE")
|
||||||
|
|
||||||
|
def test_alias_is_case_and_separator_insensitive(self):
|
||||||
|
self.assertEqual(fred._resolve_series_id("Fed Funds Rate"), "FEDFUNDS")
|
||||||
|
self.assertEqual(fred._resolve_series_id("10y-treasury"), "DGS10")
|
||||||
|
|
||||||
|
def test_unknown_alias_is_treated_as_raw_series_id(self):
|
||||||
|
# Power users can pass any FRED series ID; we uppercase by convention.
|
||||||
|
self.assertEqual(fred._resolve_series_id("dgs30"), "DGS30")
|
||||||
|
self.assertEqual(fred._resolve_series_id("MyCustomSeries"), "MYCUSTOMSERIES")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class FredConfigTests(unittest.TestCase):
|
||||||
|
def test_missing_key_raises_not_configured(self):
|
||||||
|
with mock.patch.dict("os.environ", {}, clear=True), \
|
||||||
|
self.assertRaises(fred.FredNotConfiguredError):
|
||||||
|
fred.get_api_key()
|
||||||
|
|
||||||
|
def test_not_configured_is_a_value_error(self):
|
||||||
|
# Routing relies on this subclassing for "vendor unavailable" handling.
|
||||||
|
self.assertTrue(issubclass(fred.FredNotConfiguredError, ValueError))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class FredFormattingTests(unittest.TestCase):
|
||||||
|
def test_report_has_header_latest_change_and_table(self):
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_request_stub()):
|
||||||
|
out = fred.get_macro_data("unemployment", "2025-09-30", 365)
|
||||||
|
self.assertIn("## FRED: Unemployment Rate (UNRATE)", out)
|
||||||
|
self.assertIn("Units: %", out)
|
||||||
|
self.assertIn("Frequency: Monthly (SA)", out)
|
||||||
|
self.assertIn("**Latest:** 4.4 (2025-09-01)", out)
|
||||||
|
# change over the window: 4.4 - 4.1 = +0.30
|
||||||
|
self.assertIn("+0.30", out)
|
||||||
|
self.assertIn("| 2025-06-01 | 4.1 |", out)
|
||||||
|
|
||||||
|
def test_missing_value_is_skipped(self):
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_request_stub()):
|
||||||
|
out = fred.get_macro_data("unemployment", "2025-09-30", 365)
|
||||||
|
# the "." observation must not appear as a row
|
||||||
|
self.assertNotIn("2025-08-01", out)
|
||||||
|
|
||||||
|
def test_empty_window_reports_no_observations(self):
|
||||||
|
empty = {"observations": []}
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_request_stub(obs=empty)):
|
||||||
|
out = fred.get_macro_data("unemployment", "2025-09-30", 30)
|
||||||
|
self.assertIn("No observations", out)
|
||||||
|
|
||||||
|
def test_unknown_series_raises(self):
|
||||||
|
no_series = {"seriess": []}
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_request_stub(meta=no_series)), \
|
||||||
|
self.assertRaises(ValueError):
|
||||||
|
fred.get_macro_data("totally_unknown_xyz", "2025-09-30", 30)
|
||||||
|
|
||||||
|
def test_long_series_is_truncated_but_change_uses_full_range(self):
|
||||||
|
# Build > MAX_ROWS observations deterministically.
|
||||||
|
obs = {
|
||||||
|
"observations": [
|
||||||
|
{"date": f"2025-01-{(i % 28) + 1:02d}", "value": str(i)}
|
||||||
|
for i in range(fred.MAX_ROWS + 10)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_request_stub(obs=obs)):
|
||||||
|
out = fred.get_macro_data("unemployment", "2025-12-31", 365)
|
||||||
|
self.assertIn(f"most recent {fred.MAX_ROWS}", out)
|
||||||
|
# change-over-window must reference the true first (0) and last value
|
||||||
|
self.assertIn("from 0 ", out)
|
||||||
|
body_rows = [ln for ln in out.splitlines() if ln.startswith("| 2025")]
|
||||||
|
self.assertEqual(len(body_rows), fred.MAX_ROWS)
|
||||||
|
|
||||||
|
def test_window_is_lookahead_safe(self):
|
||||||
|
# observation_end must equal curr_date so a past date never pulls future data.
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def _capture(path, params):
|
||||||
|
captured[path] = params
|
||||||
|
return _META if path == "series" else _OBS
|
||||||
|
|
||||||
|
with mock.patch.object(fred, "_request", side_effect=_capture):
|
||||||
|
fred.get_macro_data("unemployment", "2025-09-30", 90)
|
||||||
|
obs_params = captured["series/observations"]
|
||||||
|
self.assertEqual(obs_params["observation_end"], "2025-09-30")
|
||||||
|
self.assertEqual(obs_params["observation_start"], "2025-07-02") # 90d back
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class FredRoutingTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def test_macro_category_routes_to_fred(self):
|
||||||
|
self.assertEqual(
|
||||||
|
interface.get_category_for_method("get_macro_indicators"), "macro_data"
|
||||||
|
)
|
||||||
|
set_config({"data_vendors": {"macro_data": "fred"}})
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_macro_indicators": {"fred": lambda *a, **k: "MACRO_OK"}},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
out = interface.route_to_vendor("get_macro_indicators", "cpi", "2026-06-01", 365)
|
||||||
|
self.assertEqual(out, "MACRO_OK")
|
||||||
|
|
||||||
|
def test_not_configured_surfaces_through_router(self):
|
||||||
|
# With only fred and no key, the router has no fallback and must surface
|
||||||
|
# the real "not configured" failure rather than masking it.
|
||||||
|
set_config({"data_vendors": {"macro_data": "fred"}})
|
||||||
|
|
||||||
|
def _unconfigured(*a, **k):
|
||||||
|
raise fred.FredNotConfiguredError("FRED_API_KEY not set")
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_macro_indicators": {"fred": _unconfigured}},
|
||||||
|
clear=False,
|
||||||
|
), self.assertRaises(fred.FredNotConfiguredError):
|
||||||
|
interface.route_to_vendor("get_macro_indicators", "cpi", "2026-06-01", 365)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
59
tests/test_i18n_coverage.py
Normal file
59
tests/test_i18n_coverage.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Every report-producing agent must apply the configured output language
|
||||||
|
(#740/#801).
|
||||||
|
|
||||||
|
A non-English run should produce a fully localized report, not a mix of
|
||||||
|
languages. The bug originally happened because several agents silently omitted
|
||||||
|
the instruction (fixed in 6b384f7); this test codifies the invariant so a future
|
||||||
|
refactor can't quietly drop it again.
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
||||||
|
|
||||||
|
_AGENTS_DIR = Path(__file__).resolve().parents[1] / "tradingagents" / "agents"
|
||||||
|
|
||||||
|
# Every node whose text reaches the saved report. If you add a report-producing
|
||||||
|
# agent, add it here — and make it call get_language_instruction().
|
||||||
|
REPORT_AGENTS = [
|
||||||
|
"analysts/market_analyst.py",
|
||||||
|
"analysts/news_analyst.py",
|
||||||
|
"analysts/fundamentals_analyst.py",
|
||||||
|
"analysts/sentiment_analyst.py",
|
||||||
|
"researchers/bull_researcher.py",
|
||||||
|
"researchers/bear_researcher.py",
|
||||||
|
"managers/research_manager.py",
|
||||||
|
"managers/portfolio_manager.py",
|
||||||
|
"risk_mgmt/aggressive_debator.py",
|
||||||
|
"risk_mgmt/conservative_debator.py",
|
||||||
|
"risk_mgmt/neutral_debator.py",
|
||||||
|
"trader/trader.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestLanguageInstruction:
|
||||||
|
def test_english_adds_no_tokens(self, monkeypatch):
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
set_config({"output_language": "English"})
|
||||||
|
assert get_language_instruction() == ""
|
||||||
|
|
||||||
|
def test_non_english_emits_directive(self):
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
set_config({"output_language": "中文"})
|
||||||
|
out = get_language_instruction()
|
||||||
|
assert "中文" in out
|
||||||
|
assert "entire response" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.parametrize("rel", REPORT_AGENTS)
|
||||||
|
def test_report_agent_applies_language_instruction(rel):
|
||||||
|
path = _AGENTS_DIR / rel
|
||||||
|
assert path.exists(), f"missing agent module: {rel}"
|
||||||
|
src = path.read_text(encoding="utf-8")
|
||||||
|
assert "get_language_instruction()" in src, (
|
||||||
|
f"{rel} does not apply get_language_instruction(); its output would "
|
||||||
|
f"ignore the configured output_language (#740/#801)."
|
||||||
|
)
|
||||||
170
tests/test_instrument_identity.py
Normal file
170
tests/test_instrument_identity.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""Tests for deterministic instrument-identity resolution (#814) and the
|
||||||
|
context-anchored message placeholder (#888)."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
build_instrument_context,
|
||||||
|
create_msg_delete,
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
resolve_instrument_identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class ResolveInstrumentIdentityTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
resolve_instrument_identity.cache_clear()
|
||||||
|
|
||||||
|
def test_resolves_company_metadata_from_yfinance(self):
|
||||||
|
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||||
|
mock.return_value.info = {
|
||||||
|
"longName": "TOTO LTD.",
|
||||||
|
"shortName": "TOTO",
|
||||||
|
"sector": "Industrials",
|
||||||
|
"industry": "Building Products & Equipment",
|
||||||
|
"exchange": "PNK",
|
||||||
|
"quoteType": "EQUITY",
|
||||||
|
}
|
||||||
|
identity = resolve_instrument_identity("totdy")
|
||||||
|
mock.assert_called_once_with("TOTDY")
|
||||||
|
self.assertEqual(identity["company_name"], "TOTO LTD.")
|
||||||
|
self.assertEqual(identity["sector"], "Industrials")
|
||||||
|
self.assertEqual(identity["industry"], "Building Products & Equipment")
|
||||||
|
self.assertEqual(identity["exchange"], "PNK")
|
||||||
|
|
||||||
|
def test_falls_back_to_short_name(self):
|
||||||
|
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||||
|
mock.return_value.info = {"shortName": "TOTO", "sector": "Industrials"}
|
||||||
|
identity = resolve_instrument_identity("TOTDY")
|
||||||
|
self.assertEqual(identity["company_name"], "TOTO")
|
||||||
|
|
||||||
|
def test_skips_placeholder_values(self):
|
||||||
|
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||||
|
mock.return_value.info = {"longName": " ", "sector": "None", "industry": "n/a"}
|
||||||
|
identity = resolve_instrument_identity("TOTDY")
|
||||||
|
self.assertEqual(identity, {})
|
||||||
|
|
||||||
|
def test_fails_open_on_exception(self):
|
||||||
|
with patch(
|
||||||
|
"tradingagents.agents.utils.agent_utils.yf.Ticker",
|
||||||
|
side_effect=RuntimeError("rate limited"),
|
||||||
|
):
|
||||||
|
self.assertEqual(resolve_instrument_identity("TOTDY"), {})
|
||||||
|
|
||||||
|
def test_result_is_cached(self):
|
||||||
|
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||||
|
mock.return_value.info = {"longName": "TOTO LTD."}
|
||||||
|
first = resolve_instrument_identity("TOTDY")
|
||||||
|
second = resolve_instrument_identity("TOTDY")
|
||||||
|
mock.assert_called_once() # second call served from cache
|
||||||
|
self.assertEqual(first, second)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class BuildInstrumentContextTests(unittest.TestCase):
|
||||||
|
def test_mentions_exact_symbol_without_identity(self):
|
||||||
|
context = build_instrument_context("7203.T")
|
||||||
|
self.assertIn("7203.T", context)
|
||||||
|
self.assertIn("exchange suffix", context)
|
||||||
|
self.assertNotIn("Resolved identity", context)
|
||||||
|
|
||||||
|
def test_injects_resolved_identity(self):
|
||||||
|
context = build_instrument_context(
|
||||||
|
"TOTDY", "stock",
|
||||||
|
{
|
||||||
|
"company_name": "TOTO LTD.",
|
||||||
|
"sector": "Industrials",
|
||||||
|
"industry": "Building Products & Equipment",
|
||||||
|
"exchange": "PNK",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertIn("Company: TOTO LTD.", context)
|
||||||
|
self.assertIn("Industrials / Building Products & Equipment", context)
|
||||||
|
self.assertIn("Exchange: PNK", context)
|
||||||
|
self.assertIn("Do not substitute a different company", context)
|
||||||
|
|
||||||
|
def test_crypto_uses_name_label_and_keeps_hint(self):
|
||||||
|
context = build_instrument_context(
|
||||||
|
"BTC-USD", "crypto", {"company_name": "Bitcoin USD"}
|
||||||
|
)
|
||||||
|
self.assertIn("Name: Bitcoin USD", context)
|
||||||
|
self.assertIn("crypto asset rather than a company", context)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class GetInstrumentContextFromStateTests(unittest.TestCase):
|
||||||
|
def test_prefers_precomputed_context(self):
|
||||||
|
state = {"company_of_interest": "TOTDY", "instrument_context": "PRECOMPUTED"}
|
||||||
|
self.assertEqual(get_instrument_context_from_state(state), "PRECOMPUTED")
|
||||||
|
|
||||||
|
def test_fallback_is_network_free_ticker_only(self):
|
||||||
|
# No instrument_context and no yfinance call — must not hit the network.
|
||||||
|
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||||
|
context = get_instrument_context_from_state(
|
||||||
|
{"company_of_interest": "NVDA", "asset_type": "stock"}
|
||||||
|
)
|
||||||
|
mock.assert_not_called()
|
||||||
|
self.assertIn("NVDA", context)
|
||||||
|
|
||||||
|
def test_fallback_respects_asset_type(self):
|
||||||
|
context = get_instrument_context_from_state(
|
||||||
|
{"company_of_interest": "BTC-USD", "asset_type": "crypto"}
|
||||||
|
)
|
||||||
|
self.assertIn("crypto asset", context)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class ContextAnchoredPlaceholderTests(unittest.TestCase):
|
||||||
|
"""#888 — the message-clear placeholder must not be a bare 'Continue'."""
|
||||||
|
|
||||||
|
def _run(self, state_extra):
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="old", id="h1"),
|
||||||
|
AIMessage(content="reply", id="a1"),
|
||||||
|
],
|
||||||
|
**state_extra,
|
||||||
|
}
|
||||||
|
return create_msg_delete()(state)
|
||||||
|
|
||||||
|
def test_placeholder_is_not_bare_continue(self):
|
||||||
|
result = self._run(
|
||||||
|
{"company_of_interest": "EC", "asset_type": "stock", "trade_date": "2026-05-28"}
|
||||||
|
)
|
||||||
|
placeholder = result["messages"][-1]
|
||||||
|
self.assertIsInstance(placeholder, HumanMessage)
|
||||||
|
self.assertNotEqual(placeholder.content.strip(), "Continue")
|
||||||
|
|
||||||
|
def test_placeholder_carries_resolved_identity(self):
|
||||||
|
result = self._run(
|
||||||
|
{
|
||||||
|
"company_of_interest": "EC",
|
||||||
|
"instrument_context": "The instrument to analyze is `EC`. Resolved identity: Company: Ecopetrol.",
|
||||||
|
"trade_date": "2026-05-28",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
content = result["messages"][-1].content
|
||||||
|
self.assertIn("Ecopetrol", content)
|
||||||
|
self.assertIn("2026-05-28", content)
|
||||||
|
|
||||||
|
def test_old_messages_are_removed(self):
|
||||||
|
result = self._run({"company_of_interest": "EC", "trade_date": "2026-05-28"})
|
||||||
|
removals = [m for m in result["messages"] if isinstance(m, RemoveMessage)]
|
||||||
|
humans = [m for m in result["messages"] if isinstance(m, HumanMessage)]
|
||||||
|
self.assertEqual(len(removals), 2)
|
||||||
|
self.assertEqual(len(humans), 1)
|
||||||
|
|
||||||
|
def test_safe_defaults_when_state_minimal(self):
|
||||||
|
result = create_msg_delete()({"messages": [], "company_of_interest": "EC"})
|
||||||
|
placeholder = result["messages"][-1]
|
||||||
|
self.assertNotEqual(placeholder.content.strip(), "Continue")
|
||||||
|
self.assertIn("EC", placeholder.content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
76
tests/test_market_data_validator.py
Normal file
76
tests/test_market_data_validator.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Tests for the deterministic market-data verification snapshot (#830/#881)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.market_data_validator as validator
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_ohlcv() -> pd.DataFrame:
|
||||||
|
dates = pd.bdate_range("2026-04-01", "2026-05-20")
|
||||||
|
closes = [100 + i for i in range(len(dates))]
|
||||||
|
return pd.DataFrame({
|
||||||
|
"Date": dates,
|
||||||
|
"Open": [c - 0.5 for c in closes],
|
||||||
|
"High": [c + 1.0 for c in closes],
|
||||||
|
"Low": [c - 1.0 for c in closes],
|
||||||
|
"Close": closes,
|
||||||
|
"Volume": [1_000_000 + i for i in range(len(dates))],
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestVerifiedSnapshot:
|
||||||
|
def test_excludes_future_rows(self, monkeypatch):
|
||||||
|
data = pd.concat([
|
||||||
|
_sample_ohlcv(),
|
||||||
|
pd.DataFrame({"Date": [pd.Timestamp("2026-06-01")], "Open": [999.0],
|
||||||
|
"High": [999.0], "Low": [999.0], "Close": [999.0], "Volume": [999]}),
|
||||||
|
], ignore_index=True)
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: data)
|
||||||
|
|
||||||
|
snap = validator.build_verified_market_snapshot("COF", "2026-05-13")
|
||||||
|
assert "Verified market data snapshot for COF" in snap
|
||||||
|
assert "Requested analysis date: 2026-05-13" in snap
|
||||||
|
assert "Latest trading row used: 2026-05-13" in snap
|
||||||
|
assert "999.00" not in snap # future row excluded
|
||||||
|
assert "boll_lb" in snap # indicators present
|
||||||
|
|
||||||
|
def test_uses_previous_trading_day_when_date_is_weekend(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: _sample_ohlcv())
|
||||||
|
# 2026-05-16 is a Saturday; latest row should be Fri 2026-05-15
|
||||||
|
snap = validator.build_verified_market_snapshot("COF", "2026-05-16")
|
||||||
|
assert "Latest trading row used: 2026-05-15" in snap
|
||||||
|
assert "Recent verified closes" in snap
|
||||||
|
|
||||||
|
def test_raises_when_no_rows_on_or_before_date(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: _sample_ohlcv())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validator.build_verified_market_snapshot("COF", "2020-01-01")
|
||||||
|
|
||||||
|
def test_raises_on_empty_data(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: pd.DataFrame())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validator.build_verified_market_snapshot("COF", "2026-05-13")
|
||||||
|
|
||||||
|
def test_look_back_window_capped_at_30(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: _sample_ohlcv())
|
||||||
|
snap = validator.build_verified_market_snapshot("COF", "2026-05-20", look_back_days=999)
|
||||||
|
# last-N closes table has at most 30 data rows
|
||||||
|
close_rows = [ln for ln in snap.splitlines() if ln.startswith("| 2026-")]
|
||||||
|
assert 0 < len(close_rows) <= 30
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTool:
|
||||||
|
def test_tool_delegates_to_builder(self, monkeypatch):
|
||||||
|
from tradingagents.agents.utils.market_data_validation_tools import (
|
||||||
|
get_verified_market_snapshot,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(validator, "load_ohlcv", lambda s, d: _sample_ohlcv())
|
||||||
|
out = get_verified_market_snapshot.invoke(
|
||||||
|
{"symbol": "COF", "curr_date": "2026-05-20"}
|
||||||
|
)
|
||||||
|
assert "Verified market data snapshot for COF" in out
|
||||||
23
tests/test_market_toolnode.py
Normal file
23
tests/test_market_toolnode.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""The market analyst is bound (and prompt-instructed) to call
|
||||||
|
get_verified_market_snapshot; if the executor ToolNode doesn't register it, the
|
||||||
|
call fails and the model reports the tool "unavailable" and skips verification.
|
||||||
|
|
||||||
|
Regression guard for that wiring gap (snapshot bound to the LLM but missing from
|
||||||
|
the market ToolNode).
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_market_toolnode_can_execute_verified_snapshot():
|
||||||
|
# _create_tool_nodes does not use self -> call unbound (avoids building LLMs).
|
||||||
|
nodes = TradingAgentsGraph._create_tool_nodes(None)
|
||||||
|
market_tools = set(nodes["market"].tools_by_name)
|
||||||
|
assert "get_verified_market_snapshot" in market_tools, (
|
||||||
|
"get_verified_market_snapshot is bound to the market analyst but not "
|
||||||
|
"registered in the market ToolNode, so the model's call fails."
|
||||||
|
)
|
||||||
|
# the other core market tools must remain too
|
||||||
|
assert {"get_stock_data", "get_indicators"} <= market_tools
|
||||||
@@ -1,15 +1,16 @@
|
|||||||
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pandas as pd
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||||
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.graph.propagation import Propagator
|
||||||
from tradingagents.graph.reflection import Reflector
|
from tradingagents.graph.reflection import Reflector
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.graph.propagation import Propagator
|
|
||||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
|
||||||
|
|
||||||
_SEP = TradingMemoryLog._SEPARATOR
|
_SEP = TradingMemoryLog._SEPARATOR
|
||||||
|
|
||||||
@@ -563,6 +564,16 @@ class TestDeferredReflection:
|
|||||||
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "RELIANCE.NS") == "^NSEI"
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "RELIANCE.NS") == "^NSEI"
|
||||||
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AZN.L") == "^FTSE"
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AZN.L") == "^FTSE"
|
||||||
|
|
||||||
|
def test_resolve_benchmark_china_a_shares(self):
|
||||||
|
"""A-share tickers route to their exchange composite (uses the real
|
||||||
|
default benchmark_map, since A-share support relies on it)."""
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
mock_graph.config = {"benchmark_ticker": None,
|
||||||
|
"benchmark_map": DEFAULT_CONFIG["benchmark_map"]}
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "600519.SS") == "000001.SS"
|
||||||
|
assert TradingAgentsGraph._resolve_benchmark(mock_graph, "000001.SZ") == "399001.SZ"
|
||||||
|
|
||||||
def test_resolve_benchmark_us_ticker_defaults_to_spy(self):
|
def test_resolve_benchmark_us_ticker_defaults_to_spy(self):
|
||||||
"""US tickers (no dotted suffix) take the empty-suffix entry."""
|
"""US tickers (no dotted suffix) take the empty-suffix entry."""
|
||||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||||
|
|||||||
@@ -25,22 +25,22 @@ def _client(model: str = "MiniMax-M2.7"):
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
class TestMinimaxReasoningSplit:
|
class TestMinimaxReasoningSplit:
|
||||||
def test_request_payload_sets_reasoning_split(self):
|
def test_reasoning_split_sent_via_extra_body_not_top_level(self):
|
||||||
|
# Must be in extra_body, not top-level: the openai SDK validates
|
||||||
|
# top-level params and rejects unknown ones like reasoning_split (#826).
|
||||||
payload = _client()._get_request_payload([HumanMessage(content="hi")])
|
payload = _client()._get_request_payload([HumanMessage(content="hi")])
|
||||||
assert payload.get("reasoning_split") is True
|
assert payload.get("extra_body", {}).get("reasoning_split") is True
|
||||||
|
assert "reasoning_split" not in payload # never top-level
|
||||||
|
|
||||||
def test_caller_supplied_reasoning_split_is_preserved(self):
|
def test_non_reasoning_minimax_does_not_inject_reasoning_split(self):
|
||||||
"""If the user explicitly sets reasoning_split, don't override it
|
"""Coding Plan / MiniMax-Text-01 / any non-M2-prefixed model must NOT
|
||||||
(setdefault semantics — caller wins)."""
|
receive reasoning_split at all (top-level or extra_body) (#826)."""
|
||||||
client = _client()
|
for model in ("minimax-text-01", "MiniMax-Coding-Plan"):
|
||||||
payload = client._get_request_payload(
|
payload = _client(model)._get_request_payload(
|
||||||
[HumanMessage(content="hi")],
|
[HumanMessage(content="hi")]
|
||||||
reasoning_split=False,
|
)
|
||||||
)
|
assert "reasoning_split" not in payload
|
||||||
# langchain may or may not surface that kwarg into the payload;
|
assert "reasoning_split" not in payload.get("extra_body", {})
|
||||||
# what matters is we don't blindly overwrite a non-default value
|
|
||||||
# the caller passed. setdefault leaves an existing value alone.
|
|
||||||
assert payload.get("reasoning_split") in (False, True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
|
|||||||
79
tests/test_news_lookahead.py
Normal file
79
tests/test_news_lookahead.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""yfinance news must not leak future-dated (or undated, in a backtest) articles
|
||||||
|
into a historical window.
|
||||||
|
|
||||||
|
Regressions for #992 (flat articles bypassed the date filter), #1007 (global
|
||||||
|
news injected future articles), #993 (empty-after-filter returned a blank body).
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.yfinance_news as ynews
|
||||||
|
|
||||||
|
|
||||||
|
def _epoch(date_str):
|
||||||
|
return int(time.mktime(datetime.strptime(date_str, "%Y-%m-%d").timetuple()))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_flat_article_publish_time_is_parsed():
|
||||||
|
# #992: flat articles now carry a pub_date (was always None -> unfilterable).
|
||||||
|
data = ynews._extract_article_data(
|
||||||
|
{"title": "X", "publisher": "P", "link": "l", "providerPublishTime": _epoch("2025-05-09")}
|
||||||
|
)
|
||||||
|
assert data["pub_date"] is not None
|
||||||
|
assert data["pub_date"].strftime("%Y-%m-%d") == "2025-05-09"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_window_excludes_future_and_undated_in_backtest():
|
||||||
|
start = datetime(2025, 5, 1)
|
||||||
|
end = datetime(2025, 5, 9) # historical window (well in the past)
|
||||||
|
inside = datetime(2025, 5, 5)
|
||||||
|
future = datetime(2025, 6, 1)
|
||||||
|
assert ynews._in_news_window(inside, start, end) is True
|
||||||
|
assert ynews._in_news_window(future, start, end) is False # look-ahead blocked
|
||||||
|
assert ynews._in_news_window(None, start, end) is False # undated -> excluded in backtest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_window_keeps_undated_in_live_window():
|
||||||
|
# Live window (reaches today): undated articles can't be "future", so keep them.
|
||||||
|
start = datetime.now()
|
||||||
|
end = datetime.now()
|
||||||
|
assert ynews._in_news_window(None, start, end) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_global_news_future_flat_article_excluded(monkeypatch):
|
||||||
|
# #1007: a flat, future-dated global article must not appear in a historical run.
|
||||||
|
future_article = {"title": "FUTURE EVENT", "publisher": "P", "link": "l",
|
||||||
|
"providerPublishTime": _epoch("2025-06-01")}
|
||||||
|
past_article = {"title": "PAST EVENT", "publisher": "P", "link": "l",
|
||||||
|
"providerPublishTime": _epoch("2025-05-05")}
|
||||||
|
|
||||||
|
class FakeSearch:
|
||||||
|
def __init__(self, *a, **k):
|
||||||
|
self.news = [future_article, past_article]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ynews.yf, "Search", FakeSearch)
|
||||||
|
out = ynews.get_global_news_yfinance("2025-05-09", look_back_days=7, limit=10)
|
||||||
|
assert "PAST EVENT" in out
|
||||||
|
assert "FUTURE EVENT" not in out # #1007
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_global_news_empty_after_filter_is_informative(monkeypatch):
|
||||||
|
# #993: everything filtered out -> a clear message, not a blank-bodied report.
|
||||||
|
only_future = {"title": "FUTURE", "publisher": "P", "link": "l",
|
||||||
|
"providerPublishTime": _epoch("2025-06-01")}
|
||||||
|
|
||||||
|
class FakeSearch:
|
||||||
|
def __init__(self, *a, **k):
|
||||||
|
self.news = [only_future]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ynews.yf, "Search", FakeSearch)
|
||||||
|
out = ynews.get_global_news_yfinance("2025-05-09", look_back_days=7, limit=10)
|
||||||
|
assert "No global news found" in out
|
||||||
|
assert "###" not in out # no empty article body
|
||||||
88
tests/test_no_data_handling.py
Normal file
88
tests/test_no_data_handling.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""Tests that empty vendor results never become fabricated data.
|
||||||
|
|
||||||
|
Covers two systematic fixes:
|
||||||
|
- load_ohlcv must not cache an empty download (cache poisoning), and must
|
||||||
|
raise NoMarketDataError instead of returning an empty frame.
|
||||||
|
- route_to_vendor must convert NoMarketDataError into a single explicit
|
||||||
|
"NO_DATA_AVAILABLE" sentinel after all vendors are exhausted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows import interface, stockstats_utils
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestLoadOhlcvNoPoison(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self._tmp = os.path.join(os.path.dirname(__file__), "_tmp_cache")
|
||||||
|
os.makedirs(self._tmp, exist_ok=True)
|
||||||
|
set_config({"data_cache_dir": self._tmp})
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
for f in os.listdir(self._tmp):
|
||||||
|
os.remove(os.path.join(self._tmp, f))
|
||||||
|
os.rmdir(self._tmp)
|
||||||
|
|
||||||
|
def test_empty_download_raises_and_does_not_cache(self):
|
||||||
|
empty = pd.DataFrame()
|
||||||
|
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty), \
|
||||||
|
self.assertRaises(NoMarketDataError):
|
||||||
|
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
||||||
|
# Nothing should have been written to the cache.
|
||||||
|
self.assertEqual(os.listdir(self._tmp), [])
|
||||||
|
|
||||||
|
# A second call must re-attempt the fetch (no poisoned cache served).
|
||||||
|
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty) as dl2:
|
||||||
|
with self.assertRaises(NoMarketDataError):
|
||||||
|
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
||||||
|
self.assertTrue(dl2.called)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRouteToVendorSentinel(unittest.TestCase):
|
||||||
|
def test_no_data_from_all_vendors_returns_sentinel(self):
|
||||||
|
def raises_no_data(symbol, *a, **k):
|
||||||
|
raise NoMarketDataError(symbol, "GC=F", "no rows")
|
||||||
|
|
||||||
|
patched = {"yfinance": raises_no_data, "alpha_vantage": raises_no_data}
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS, {"get_stock_data": patched}, clear=False
|
||||||
|
):
|
||||||
|
result = interface.route_to_vendor(
|
||||||
|
"get_stock_data", "XAUUSD+", "2026-01-01", "2026-01-10"
|
||||||
|
)
|
||||||
|
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||||
|
self.assertIn("XAUUSD+", result)
|
||||||
|
self.assertIn("GC=F", result)
|
||||||
|
self.assertIn("Do not estimate", result)
|
||||||
|
|
||||||
|
def test_unconfigured_fallback_does_not_mask_no_data(self):
|
||||||
|
# When the primary vendor reports no data and the fallback is simply
|
||||||
|
# unavailable (e.g. missing API key -> raises), the no-data sentinel
|
||||||
|
# must win rather than the fallback's incidental error crashing out.
|
||||||
|
def raises_no_data(symbol, *a, **k):
|
||||||
|
raise NoMarketDataError(symbol, symbol, "no rows")
|
||||||
|
|
||||||
|
def raises_unavailable(symbol, *a, **k):
|
||||||
|
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||||
|
|
||||||
|
patched = {"yfinance": raises_no_data, "alpha_vantage": raises_unavailable}
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS, {"get_stock_data": patched}, clear=False
|
||||||
|
):
|
||||||
|
result = interface.route_to_vendor(
|
||||||
|
"get_stock_data", "FAKE", "2026-01-01", "2026-01-10"
|
||||||
|
)
|
||||||
|
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -7,7 +7,24 @@ import importlib
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# ---- openai_client side: _resolve_provider_base_url -----------------------
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def _resync_reloaded_modules():
|
||||||
|
"""Restore module state after this file's importlib.reload() calls.
|
||||||
|
|
||||||
|
Several tests below reload ``cli.utils`` to re-evaluate OLLAMA_BASE_URL.
|
||||||
|
That leaves ``cli.main``'s star-imported names (e.g. get_ticker) bound to
|
||||||
|
the pre-reload module objects, which breaks identity checks in unrelated
|
||||||
|
tests that happen to run afterward. Re-sync once on teardown so the reload
|
||||||
|
doesn't leak across test modules.
|
||||||
|
"""
|
||||||
|
yield
|
||||||
|
import cli.main
|
||||||
|
import cli.utils
|
||||||
|
importlib.reload(cli.utils)
|
||||||
|
importlib.reload(cli.main)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- openai_client side: registry-driven base_url resolution --------------
|
||||||
|
|
||||||
|
|
||||||
def _reload_client():
|
def _reload_client():
|
||||||
@@ -15,16 +32,20 @@ def _reload_client():
|
|||||||
return importlib.reload(mod)
|
return importlib.reload(mod)
|
||||||
|
|
||||||
|
|
||||||
|
def _base_url(mod, provider, **kwargs):
|
||||||
|
return str(mod.OpenAIClient(model="m", provider=provider, **kwargs).get_llm().openai_api_base)
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
||||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
mod = _reload_client()
|
mod = _reload_client()
|
||||||
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
|
assert _base_url(mod, "ollama") == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_returns_env_when_set(monkeypatch):
|
def test_resolver_returns_env_when_set(monkeypatch):
|
||||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
||||||
mod = _reload_client()
|
mod = _reload_client()
|
||||||
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
|
assert _base_url(mod, "ollama") == "http://remote-ollama:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_evaluation_is_call_time(monkeypatch):
|
def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||||
@@ -32,15 +53,15 @@ def test_resolver_evaluation_is_call_time(monkeypatch):
|
|||||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
mod = _reload_client()
|
mod = _reload_client()
|
||||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
||||||
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
|
assert _base_url(mod, "ollama") == "http://late-set:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
||||||
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
||||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
||||||
mod = _reload_client()
|
mod = _reload_client()
|
||||||
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
|
assert _base_url(mod, "xai") == "https://api.x.ai/v1"
|
||||||
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
|
assert _base_url(mod, "deepseek") == "https://api.deepseek.com"
|
||||||
|
|
||||||
|
|
||||||
def test_client_get_llm_picks_up_env(monkeypatch):
|
def test_client_get_llm_picks_up_env(monkeypatch):
|
||||||
|
|||||||
74
tests/test_openai_compatible_provider.py
Normal file
74
tests/test_openai_compatible_provider.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Generic OpenAI-compatible provider (vLLM / LM Studio / llama.cpp / relays).
|
||||||
|
|
||||||
|
Verifies the user-supplied base_url is required and honored, the key is optional
|
||||||
|
(keyless local default), Chat Completions (not the Responses API) is used, any
|
||||||
|
model name is accepted, and the env backend URL precedence (#978).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
# Note: assert by class NAME, not isinstance — other tests reload the
|
||||||
|
# openai_client module, which would otherwise create a second class identity.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_factory_routes_to_openai_client():
|
||||||
|
client = create_llm_client(
|
||||||
|
provider="openai_compatible", model="my-model", base_url="http://localhost:8000/v1"
|
||||||
|
)
|
||||||
|
assert type(client).__name__ == "OpenAIClient"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_base_url_required(monkeypatch):
|
||||||
|
monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
|
||||||
|
with pytest.raises(ValueError, match="requires a base_url"):
|
||||||
|
create_llm_client(provider="openai_compatible", model="m").get_llm()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_keyless_local_uses_placeholder_and_chat_completions(monkeypatch):
|
||||||
|
monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
|
||||||
|
llm = create_llm_client(
|
||||||
|
provider="openai_compatible", model="qwen2.5", base_url="http://localhost:8000/v1"
|
||||||
|
).get_llm()
|
||||||
|
assert type(llm).__name__ == "NormalizedChatOpenAI"
|
||||||
|
assert str(llm.openai_api_base) == "http://localhost:8000/v1"
|
||||||
|
# keyless local servers: a placeholder key is sent
|
||||||
|
key = llm.openai_api_key.get_secret_value() if hasattr(llm.openai_api_key, "get_secret_value") else llm.openai_api_key
|
||||||
|
assert key == "EMPTY"
|
||||||
|
# must use Chat Completions, not OpenAI's Responses API
|
||||||
|
assert getattr(llm, "use_responses_api", False) in (False, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_optional_key_from_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_COMPATIBLE_API_KEY", "sk-relay-123")
|
||||||
|
llm = create_llm_client(
|
||||||
|
provider="openai_compatible", model="m", base_url="https://relay.example/v1"
|
||||||
|
).get_llm()
|
||||||
|
key = llm.openai_api_key.get_secret_value() if hasattr(llm.openai_api_key, "get_secret_value") else llm.openai_api_key
|
||||||
|
assert key == "sk-relay-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_any_model_accepted_no_forced_key():
|
||||||
|
assert validate_model("openai_compatible", "literally-anything") is True
|
||||||
|
# The key env exists (read for keyed relays) but the provider is marked
|
||||||
|
# key-optional, so the CLI never forces a prompt and keyless servers work.
|
||||||
|
assert get_api_key_env("openai_compatible") == "OPENAI_COMPATIBLE_API_KEY"
|
||||||
|
from tradingagents.llm_clients.openai_client import OPENAI_COMPATIBLE_PROVIDERS
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_env_backend_url_precedence():
|
||||||
|
# #978: explicit env URL wins over the menu/default regardless of provider source.
|
||||||
|
from cli.utils import resolve_backend_url
|
||||||
|
assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url="http://proxy/v1") == "http://proxy/v1"
|
||||||
|
assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url=None) == "https://api.openai.com/v1"
|
||||||
|
assert resolve_backend_url("deepseek", None, None) == "https://api.deepseek.com"
|
||||||
43
tests/test_openai_responses_base_url.py
Normal file
43
tests/test_openai_responses_base_url.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""The Responses API only exists on native OpenAI; a custom base_url on the
|
||||||
|
openai provider must fall back to Chat Completions (#1024)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.openai_client import (
|
||||||
|
OpenAIClient,
|
||||||
|
_is_native_openai_base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class NativeBaseUrlTests:
|
||||||
|
def test_unset_is_native(self):
|
||||||
|
assert _is_native_openai_base_url(None) is True
|
||||||
|
assert _is_native_openai_base_url("") is True
|
||||||
|
|
||||||
|
def test_openai_hosts_are_native(self):
|
||||||
|
assert _is_native_openai_base_url("https://api.openai.com/v1") is True
|
||||||
|
assert _is_native_openai_base_url("api.openai.com/v1") is True
|
||||||
|
|
||||||
|
def test_custom_endpoints_are_not_native(self):
|
||||||
|
assert _is_native_openai_base_url("http://localhost:1234/v1") is False
|
||||||
|
assert _is_native_openai_base_url("https://my-gateway.example.com/v1") is False
|
||||||
|
assert _is_native_openai_base_url("https://api.openai.com.evil.com/v1") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class ResponsesApiSelectionTests:
|
||||||
|
def test_native_openai_enables_responses_api(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||||
|
llm = OpenAIClient("gpt-5.5", provider="openai").get_llm()
|
||||||
|
assert getattr(llm, "use_responses_api", False) is True
|
||||||
|
|
||||||
|
def test_custom_base_url_disables_responses_api(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||||
|
llm = OpenAIClient(
|
||||||
|
"gpt-5.5", base_url="http://localhost:1234/v1", provider="openai"
|
||||||
|
).get_llm()
|
||||||
|
# use_responses_api should be absent/False so the client speaks Chat Completions.
|
||||||
|
assert getattr(llm, "use_responses_api", False) is False
|
||||||
122
tests/test_openrouter_model_select.py
Normal file
122
tests/test_openrouter_model_select.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""OpenRouter model selection: prompts are labeled by mode (#1000); required
|
||||||
|
prompts exit cleanly on cancel; the output-language prompt defaults to English
|
||||||
|
on cancel; and the OpenRouter list is newest-first."""
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cli import utils
|
||||||
|
|
||||||
|
|
||||||
|
def _asks(value):
|
||||||
|
return mock.Mock(ask=mock.Mock(return_value=value))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestOpenRouterPromptLabel:
|
||||||
|
@pytest.mark.parametrize("mode,label", [("quick", "Quick-Thinking"), ("deep", "Deep-Thinking")])
|
||||||
|
def test_prompt_states_the_mode(self, mode, label):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_select(message, **kwargs):
|
||||||
|
captured["message"] = message
|
||||||
|
return _asks("openrouter/some-model")
|
||||||
|
|
||||||
|
with mock.patch.object(utils, "_fetch_openrouter_models",
|
||||||
|
return_value=[("Some Model", "openrouter/some-model")]), \
|
||||||
|
mock.patch.object(utils.questionary, "select", side_effect=fake_select):
|
||||||
|
out = utils.select_openrouter_model(mode)
|
||||||
|
|
||||||
|
assert label in captured["message"]
|
||||||
|
assert out == "openrouter/some-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestOpenRouterLatestFirst:
|
||||||
|
def test_models_sorted_newest_first(self):
|
||||||
|
payload = {"data": [
|
||||||
|
{"id": "old/model", "name": "Old", "created": 1000},
|
||||||
|
{"id": "new/model", "name": "New", "created": 3000},
|
||||||
|
{"id": "mid/model", "name": "Mid", "created": 2000},
|
||||||
|
]}
|
||||||
|
resp = mock.Mock()
|
||||||
|
resp.json.return_value = payload
|
||||||
|
resp.raise_for_status = mock.Mock()
|
||||||
|
with mock.patch("requests.get", return_value=resp):
|
||||||
|
out = utils._fetch_openrouter_models()
|
||||||
|
assert [mid for _, mid in out] == ["new/model", "mid/model", "old/model"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestMainstreamFilter:
|
||||||
|
def test_dropdown_prefers_mainstream_over_niche(self):
|
||||||
|
# _fetch returns newest-first; the shortlist should drop niche namespaces.
|
||||||
|
models = [
|
||||||
|
("Fusion", "openrouter/fusion"),
|
||||||
|
("Niche", "nex-agi/nex-n2-pro:free"),
|
||||||
|
("Claude", "anthropic/claude-x"),
|
||||||
|
("GPT", "openai/gpt-x"),
|
||||||
|
]
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_select(message, **kwargs):
|
||||||
|
captured["values"] = [c.value for c in kwargs["choices"]]
|
||||||
|
return _asks("anthropic/claude-x")
|
||||||
|
|
||||||
|
with mock.patch.object(utils, "_fetch_openrouter_models", return_value=models), \
|
||||||
|
mock.patch.object(utils.questionary, "select", side_effect=fake_select):
|
||||||
|
utils.select_openrouter_model("quick")
|
||||||
|
|
||||||
|
assert "anthropic/claude-x" in captured["values"]
|
||||||
|
assert "openai/gpt-x" in captured["values"]
|
||||||
|
assert "openrouter/fusion" not in captured["values"]
|
||||||
|
assert "nex-agi/nex-n2-pro:free" not in captured["values"]
|
||||||
|
assert "custom" in captured["values"] # escape hatch preserved
|
||||||
|
|
||||||
|
def test_falls_back_to_all_when_no_mainstream(self):
|
||||||
|
models = [("Niche", "nex-agi/x"), ("Other", "thedrummer/y")]
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_select(message, **kwargs):
|
||||||
|
captured["values"] = [c.value for c in kwargs["choices"]]
|
||||||
|
return _asks("nex-agi/x")
|
||||||
|
|
||||||
|
with mock.patch.object(utils, "_fetch_openrouter_models", return_value=models), \
|
||||||
|
mock.patch.object(utils.questionary, "select", side_effect=fake_select):
|
||||||
|
utils.select_openrouter_model("deep")
|
||||||
|
|
||||||
|
assert "nex-agi/x" in captured["values"] # fallback keeps the list usable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestCancelExitsCleanly:
|
||||||
|
def test_dropdown_cancel_exits(self):
|
||||||
|
with mock.patch.object(utils, "_fetch_openrouter_models", return_value=[]), \
|
||||||
|
mock.patch.object(utils.questionary, "select", return_value=_asks(None)), \
|
||||||
|
pytest.raises(SystemExit):
|
||||||
|
utils.select_openrouter_model("quick")
|
||||||
|
|
||||||
|
def test_custom_id_cancel_exits(self):
|
||||||
|
with mock.patch.object(utils, "_fetch_openrouter_models", return_value=[]), \
|
||||||
|
mock.patch.object(utils.questionary, "select", return_value=_asks("custom")), \
|
||||||
|
mock.patch.object(utils.questionary, "text", return_value=_asks(None)), \
|
||||||
|
pytest.raises(SystemExit):
|
||||||
|
utils.select_openrouter_model("deep")
|
||||||
|
|
||||||
|
def test_prompt_custom_model_id_cancel_exits(self):
|
||||||
|
with mock.patch.object(utils.questionary, "text", return_value=_asks(None)), \
|
||||||
|
pytest.raises(SystemExit):
|
||||||
|
utils._prompt_custom_model_id()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestLanguageDefaultsToEnglish:
|
||||||
|
def test_select_cancel_defaults_english(self):
|
||||||
|
with mock.patch.object(utils.questionary, "select", return_value=_asks(None)):
|
||||||
|
assert utils.ask_output_language() == "English"
|
||||||
|
|
||||||
|
def test_custom_language_cancel_defaults_english(self):
|
||||||
|
with mock.patch.object(utils.questionary, "select", return_value=_asks("custom")), \
|
||||||
|
mock.patch.object(utils.questionary, "text", return_value=_asks(None)):
|
||||||
|
assert utils.ask_output_language() == "English"
|
||||||
129
tests/test_polymarket.py
Normal file
129
tests/test_polymarket.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Polymarket prediction-market vendor: forward-looking filtering, volume
|
||||||
|
ranking, formatting, graceful degradation, and router integration.
|
||||||
|
|
||||||
|
All API access is mocked, so these run without a network connection.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows import interface, polymarket
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
|
||||||
|
|
||||||
|
def _market(question, prob, *, volume, end_date, closed=False, wk=None):
|
||||||
|
return {
|
||||||
|
"question": question,
|
||||||
|
"outcomes": '["Yes", "No"]',
|
||||||
|
"outcomePrices": f'["{prob}", "{round(1 - prob, 4)}"]',
|
||||||
|
"volumeNum": volume,
|
||||||
|
"endDate": end_date,
|
||||||
|
"closed": closed,
|
||||||
|
"oneWeekPriceChange": wk,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# One event with a mix: a high-volume open market, a closed one, a past-dated
|
||||||
|
# one, and a lower-volume open one. Far-future / far-past dates keep the test
|
||||||
|
# independent of the real clock.
|
||||||
|
_SEARCH = {
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"markets": [
|
||||||
|
_market("Open big?", 0.76, volume=5_000_000, end_date="2030-12-31T00:00:00Z", wk=-0.045),
|
||||||
|
_market("Resolved already?", 1.0, volume=9_000_000, end_date="2030-12-31T00:00:00Z", closed=True),
|
||||||
|
_market("Past event?", 0.5, volume=8_000_000, end_date="2020-01-01T00:00:00Z"),
|
||||||
|
_market("Open small?", 0.30, volume=1_000, end_date="2030-06-30T00:00:00Z"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class PolymarketFilterTests(unittest.TestCase):
|
||||||
|
def test_closed_and_past_markets_are_excluded(self):
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value=_SEARCH):
|
||||||
|
out = polymarket.get_prediction_markets("anything", limit=10)
|
||||||
|
self.assertIn("Open big?", out)
|
||||||
|
self.assertIn("Open small?", out)
|
||||||
|
self.assertNotIn("Resolved already?", out) # closed
|
||||||
|
self.assertNotIn("Past event?", out) # endDate in the past
|
||||||
|
|
||||||
|
def test_ranked_by_volume(self):
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value=_SEARCH):
|
||||||
|
out = polymarket.get_prediction_markets("anything", limit=10)
|
||||||
|
self.assertLess(out.index("Open big?"), out.index("Open small?"))
|
||||||
|
|
||||||
|
def test_limit_caps_results(self):
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value=_SEARCH):
|
||||||
|
out = polymarket.get_prediction_markets("anything", limit=1)
|
||||||
|
self.assertIn("Open big?", out)
|
||||||
|
self.assertNotIn("Open small?", out)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class PolymarketFormatTests(unittest.TestCase):
|
||||||
|
def test_probability_volume_and_weekly_change_render(self):
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value=_SEARCH):
|
||||||
|
out = polymarket.get_prediction_markets("anything", limit=10)
|
||||||
|
self.assertIn("Yes 76%", out)
|
||||||
|
self.assertIn("$5,000,000 volume", out)
|
||||||
|
self.assertIn("resolves 2030-12-31", out)
|
||||||
|
self.assertIn("1-week -4.5pp", out) # -0.045 -> -4.5pp
|
||||||
|
|
||||||
|
def test_weekly_change_omitted_when_absent(self):
|
||||||
|
# "Open small?" has wk=None -> no 1-week clause on its line.
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value=_SEARCH):
|
||||||
|
out = polymarket.get_prediction_markets("anything", limit=10)
|
||||||
|
small_line = next(ln for ln in out.splitlines() if "Open small?" in ln)
|
||||||
|
self.assertNotIn("1-week", small_line)
|
||||||
|
|
||||||
|
def test_no_matches_reports_clearly(self):
|
||||||
|
with mock.patch.object(polymarket, "_request", return_value={"events": []}):
|
||||||
|
out = polymarket.get_prediction_markets("obscure ticker", limit=6)
|
||||||
|
self.assertIn("No open prediction markets", out)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class PolymarketResilienceTests(unittest.TestCase):
|
||||||
|
def test_network_error_degrades_gracefully(self):
|
||||||
|
# An external-service hiccup must not raise into the analyst.
|
||||||
|
with mock.patch.object(
|
||||||
|
polymarket, "_request", side_effect=requests.RequestException("boom")
|
||||||
|
):
|
||||||
|
out = polymarket.get_prediction_markets("Fed rate cut")
|
||||||
|
self.assertIn("unavailable", out.lower())
|
||||||
|
self.assertIn("Fed rate cut", out)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class PolymarketRoutingTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def test_category_routes_to_polymarket(self):
|
||||||
|
self.assertEqual(
|
||||||
|
interface.get_category_for_method("get_prediction_markets"),
|
||||||
|
"prediction_markets",
|
||||||
|
)
|
||||||
|
set_config({"data_vendors": {"prediction_markets": "polymarket"}})
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_prediction_markets": {"polymarket": lambda *a, **k: "POLY_OK"}},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
out = interface.route_to_vendor("get_prediction_markets", "fed", 5)
|
||||||
|
self.assertEqual(out, "POLY_OK")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
59
tests/test_provider_registry.py
Normal file
59
tests/test_provider_registry.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""The OpenAI-compatible provider registry is the single source of truth for the
|
||||||
|
family; this guards each provider's resolved config (base URL, subclass, auth,
|
||||||
|
Responses API) so a future edit can't silently break one.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.openai_client import (
|
||||||
|
OPENAI_COMPATIBLE_PROVIDERS,
|
||||||
|
DeepSeekChatOpenAI,
|
||||||
|
MinimaxChatOpenAI,
|
||||||
|
NormalizedChatOpenAI,
|
||||||
|
is_openai_compatible,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_registry_membership():
|
||||||
|
assert is_openai_compatible("openai")
|
||||||
|
assert is_openai_compatible("openai_compatible") # the generic endpoint
|
||||||
|
# native (different API) clients are intentionally NOT in the registry
|
||||||
|
assert not is_openai_compatible("anthropic")
|
||||||
|
assert not is_openai_compatible("google")
|
||||||
|
assert not is_openai_compatible("azure")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.parametrize("provider,base_url,chat_class,responses", [
|
||||||
|
("openai", None, NormalizedChatOpenAI, True),
|
||||||
|
("xai", "https://api.x.ai/v1", NormalizedChatOpenAI, False),
|
||||||
|
("deepseek", "https://api.deepseek.com", DeepSeekChatOpenAI, False),
|
||||||
|
("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
|
||||||
|
("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
|
||||||
|
("glm", "https://api.z.ai/api/paas/v4/", NormalizedChatOpenAI, False),
|
||||||
|
("glm-cn", "https://open.bigmodel.cn/api/paas/v4/", NormalizedChatOpenAI, False),
|
||||||
|
("minimax", "https://api.minimax.io/v1", MinimaxChatOpenAI, False),
|
||||||
|
("minimax-cn", "https://api.minimaxi.com/v1", MinimaxChatOpenAI, False),
|
||||||
|
("openrouter", "https://openrouter.ai/api/v1", NormalizedChatOpenAI, False),
|
||||||
|
("mistral", "https://api.mistral.ai/v1", NormalizedChatOpenAI, False),
|
||||||
|
("kimi", "https://api.moonshot.ai/v1", NormalizedChatOpenAI, False),
|
||||||
|
("groq", "https://api.groq.com/openai/v1", NormalizedChatOpenAI, False),
|
||||||
|
("nvidia", "https://integrate.api.nvidia.com/v1", NormalizedChatOpenAI, False),
|
||||||
|
("ollama", "http://localhost:11434/v1", NormalizedChatOpenAI, False),
|
||||||
|
])
|
||||||
|
def test_registry_spec(provider, base_url, chat_class, responses):
|
||||||
|
spec = OPENAI_COMPATIBLE_PROVIDERS[provider]
|
||||||
|
assert spec.base_url == base_url
|
||||||
|
assert spec.chat_class is chat_class
|
||||||
|
assert spec.use_responses_api is responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_key_optionality():
|
||||||
|
# Local/generic endpoints are key-optional; hosted APIs require a key.
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].key_optional is True
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].require_base_url is True
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["xai"].key_optional is False
|
||||||
|
# OLLAMA_BASE_URL is the only base-URL env override.
|
||||||
|
assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].base_url_env == "OLLAMA_BASE_URL"
|
||||||
192
tests/test_reddit_fallback.py
Normal file
192
tests/test_reddit_fallback.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""Tests for the RSS-first Reddit fetcher, its 429 backoff, the opt-in JSON
|
||||||
|
path's degradation (#862), and chunked-transfer error handling (#1024)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import http.client
|
||||||
|
from unittest.mock import patch
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows import reddit
|
||||||
|
|
||||||
|
_SAMPLE_ATOM = """<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<feed xmlns="http://www.w3.org/2005/Atom">
|
||||||
|
<entry>
|
||||||
|
<title>NVDA earnings beat, stock pops</title>
|
||||||
|
<published>2026-05-20T14:30:00+00:00</published>
|
||||||
|
<content type="html"><!-- SC_OFF --><div class="md"><p>Great <b>quarter</b> for NVDA&#39;s datacenter unit.</p></div><!-- SC_ON --></content>
|
||||||
|
</entry>
|
||||||
|
<entry>
|
||||||
|
<title>Is NVDA overvalued?</title>
|
||||||
|
<published>2026-05-19T09:00:00Z</published>
|
||||||
|
<content type="html"><p>Forward P/E discussion</p></content>
|
||||||
|
</entry>
|
||||||
|
</feed>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _resp(read_fn):
|
||||||
|
"""A minimal context-manager response whose read() runs ``read_fn``."""
|
||||||
|
class _Resp:
|
||||||
|
def __enter__(self_inner):
|
||||||
|
return self_inner
|
||||||
|
|
||||||
|
def __exit__(self_inner, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def read(self_inner):
|
||||||
|
return read_fn()
|
||||||
|
return _Resp()
|
||||||
|
|
||||||
|
|
||||||
|
def _atom_resp():
|
||||||
|
return _resp(lambda: _SAMPLE_ATOM.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _raise(exc):
|
||||||
|
def _r():
|
||||||
|
raise exc
|
||||||
|
return _resp(_r)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestIsoToTimestamp:
|
||||||
|
def test_parses_offset_and_z(self):
|
||||||
|
assert reddit._iso_to_timestamp("2026-05-20T14:30:00+00:00") > 0
|
||||||
|
assert reddit._iso_to_timestamp("2026-05-19T09:00:00Z") > 0
|
||||||
|
|
||||||
|
def test_none_and_garbage_return_none(self):
|
||||||
|
assert reddit._iso_to_timestamp(None) is None
|
||||||
|
assert reddit._iso_to_timestamp("not-a-date") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestStripHtml:
|
||||||
|
def test_extracts_between_sc_markers_and_unescapes(self):
|
||||||
|
raw = "<!-- SC_OFF --><div class=\"md\"><p>Great <b>quarter</b> & more</p></div><!-- SC_ON -->"
|
||||||
|
assert reddit._strip_html(raw) == "Great quarter & more"
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
assert reddit._strip_html("") == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRssParsing:
|
||||||
|
def test_parses_atom_entries(self):
|
||||||
|
with patch.object(reddit, "urlopen", return_value=_atom_resp()):
|
||||||
|
posts = reddit._fetch_subreddit_rss("NVDA", "stocks", limit=5, timeout=5.0)
|
||||||
|
assert len(posts) == 2
|
||||||
|
assert posts[0]["title"] == "NVDA earnings beat, stock pops"
|
||||||
|
assert posts[0]["source"] == "rss"
|
||||||
|
assert posts[0]["score"] is None
|
||||||
|
assert posts[0]["num_comments"] is None
|
||||||
|
assert posts[0]["created_utc"] > 0
|
||||||
|
assert "datacenter unit" in posts[0]["selftext"]
|
||||||
|
|
||||||
|
def test_malformed_xml_fails_open(self):
|
||||||
|
with patch.object(reddit, "urlopen", return_value=_resp(lambda: b"<<not xml>>")):
|
||||||
|
assert reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestFetchSubredditIsRssFirst:
|
||||||
|
"""The default per-subreddit fetch goes straight to RSS — it must not hit
|
||||||
|
the WAF-blocked JSON endpoint, which only burned rate-limit budget."""
|
||||||
|
|
||||||
|
def test_delegates_to_rss_without_touching_json(self):
|
||||||
|
sentinel = [{"title": "x", "source": "rss", "score": None,
|
||||||
|
"num_comments": None, "created_utc": None, "selftext": ""}]
|
||||||
|
with patch.object(reddit, "_fetch_subreddit_rss", return_value=sentinel) as rss, \
|
||||||
|
patch.object(reddit, "urlopen",
|
||||||
|
side_effect=AssertionError("JSON endpoint must not be called")):
|
||||||
|
out = reddit._fetch_subreddit("NVDA", "stocks", 5, 5.0)
|
||||||
|
rss.assert_called_once()
|
||||||
|
assert out is sentinel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestJsonPathFallsBackToRss:
|
||||||
|
"""The opt-in JSON path still degrades to RSS on a 403 (kept for #862)."""
|
||||||
|
|
||||||
|
def test_403_triggers_rss(self):
|
||||||
|
err = HTTPError("url", 403, "Blocked", {}, None)
|
||||||
|
rss_posts = [{"title": "x", "source": "rss", "score": None,
|
||||||
|
"num_comments": None, "created_utc": None, "selftext": ""}]
|
||||||
|
with patch.object(reddit, "urlopen", side_effect=err), \
|
||||||
|
patch.object(reddit, "_fetch_subreddit_rss", return_value=rss_posts) as rss:
|
||||||
|
out = reddit._fetch_subreddit_json("NVDA", "stocks", 5, 5.0)
|
||||||
|
rss.assert_called_once()
|
||||||
|
assert out and out[0]["source"] == "rss"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRss429Backoff:
|
||||||
|
def test_429_then_success_retries_once(self):
|
||||||
|
err = HTTPError("url", 429, "Too Many Requests", {}, None)
|
||||||
|
with patch.object(reddit, "urlopen", side_effect=[err, _atom_resp()]) as op, \
|
||||||
|
patch.object(reddit.time, "sleep") as slept:
|
||||||
|
posts = reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0)
|
||||||
|
assert op.call_count == 2 # original + exactly one retry
|
||||||
|
slept.assert_called_once() # backed off before retrying
|
||||||
|
assert len(posts) == 2
|
||||||
|
|
||||||
|
def test_429_twice_gives_up_after_one_retry(self):
|
||||||
|
err = HTTPError("url", 429, "Too Many Requests", {}, None)
|
||||||
|
with patch.object(reddit, "urlopen", side_effect=[err, err]) as op, \
|
||||||
|
patch.object(reddit.time, "sleep"):
|
||||||
|
posts = reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0)
|
||||||
|
assert op.call_count == 2 # one retry, then gives up cleanly
|
||||||
|
assert posts == []
|
||||||
|
|
||||||
|
def test_retry_after_header_is_honoured(self):
|
||||||
|
err = HTTPError("url", 429, "Too Many Requests", {"Retry-After": "12"}, None)
|
||||||
|
with patch.object(reddit, "urlopen", side_effect=[err, _atom_resp()]), \
|
||||||
|
patch.object(reddit.time, "sleep") as slept:
|
||||||
|
reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0)
|
||||||
|
slept.assert_called_once_with(12.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestChunkedTransferErrorsHandled:
|
||||||
|
"""IncompleteRead/RemoteDisconnected come from http.client and are NOT
|
||||||
|
OSErrors, so they were previously uncaught and crashed the pipeline (#1024)."""
|
||||||
|
|
||||||
|
def test_rss_incomplete_read_degrades_to_empty(self):
|
||||||
|
with patch.object(reddit, "urlopen", return_value=_raise(http.client.IncompleteRead(b""))):
|
||||||
|
assert reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0) == []
|
||||||
|
|
||||||
|
def test_json_incomplete_read_falls_back_to_rss(self):
|
||||||
|
with patch.object(reddit, "urlopen", return_value=_raise(http.client.IncompleteRead(b""))), \
|
||||||
|
patch.object(reddit, "_fetch_subreddit_rss", return_value=[]) as rss:
|
||||||
|
reddit._fetch_subreddit_json("NVDA", "stocks", 5, 5.0)
|
||||||
|
rss.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestFormatterHandlesRssPosts:
|
||||||
|
def test_rss_posts_omit_fake_counts_and_note_source(self):
|
||||||
|
rss_posts = [{
|
||||||
|
"title": "NVDA pops", "score": None, "num_comments": None,
|
||||||
|
"created_utc": reddit._iso_to_timestamp("2026-05-20T14:30:00Z"),
|
||||||
|
"selftext": "great quarter", "source": "rss",
|
||||||
|
}]
|
||||||
|
with patch.object(reddit, "_fetch_subreddit", return_value=rss_posts):
|
||||||
|
out = reddit.fetch_reddit_posts("NVDA", subreddits=("stocks",), inter_request_delay=0)
|
||||||
|
assert "via RSS feed" in out
|
||||||
|
assert "↑" not in out # no fake score arrow
|
||||||
|
assert "NVDA pops" in out
|
||||||
|
assert "great quarter" in out
|
||||||
|
|
||||||
|
def test_json_posts_still_show_counts(self):
|
||||||
|
json_posts = [{
|
||||||
|
"title": "NVDA pops", "score": 1234, "num_comments": 56,
|
||||||
|
"created_utc": reddit._iso_to_timestamp("2026-05-20T14:30:00Z"),
|
||||||
|
"selftext": "",
|
||||||
|
}]
|
||||||
|
with patch.object(reddit, "_fetch_subreddit", return_value=json_posts):
|
||||||
|
out = reddit.fetch_reddit_posts("NVDA", subreddits=("stocks",), inter_request_delay=0)
|
||||||
|
assert "1234↑" in out
|
||||||
|
assert "56c" in out
|
||||||
|
assert "via RSS" not in out
|
||||||
@@ -14,6 +14,11 @@ class TestSafeTickerComponent(unittest.TestCase):
|
|||||||
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
|
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
|
||||||
self.assertEqual(safe_ticker_component(ticker), ticker)
|
self.assertEqual(safe_ticker_component(ticker), ticker)
|
||||||
|
|
||||||
|
def test_accepts_futures_and_forex_formats(self):
|
||||||
|
# Futures use '=' (GC=F gold, CL=F crude), forex/CFD symbols use '+'.
|
||||||
|
for ticker in ("GC=F", "CL=F", "ES=F", "XAUUSD+", "EURUSD+"):
|
||||||
|
self.assertEqual(safe_ticker_component(ticker), ticker)
|
||||||
|
|
||||||
def test_rejects_path_separators(self):
|
def test_rejects_path_separators(self):
|
||||||
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
|
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import pytest
|
|||||||
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||||
from tradingagents.graph.signal_processing import SignalProcessor
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Heuristic parser
|
# Heuristic parser
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
70
tests/test_stockstats_date_column.py
Normal file
70
tests/test_stockstats_date_column.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""Tests for tolerating a non-`Date` index column in stockstats_utils (#890).
|
||||||
|
|
||||||
|
Guards against a download frame whose date column is `index` or `Datetime`
|
||||||
|
instead of `Date`, which would otherwise silently drop every indicator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows import stockstats_utils as su
|
||||||
|
|
||||||
|
|
||||||
|
def _ohlcv(date_col: str) -> pd.DataFrame:
|
||||||
|
"""OHLCV frame whose date column is named `date_col`."""
|
||||||
|
dates = pd.bdate_range("2026-04-01", periods=10)
|
||||||
|
return pd.DataFrame({
|
||||||
|
date_col: dates,
|
||||||
|
"Open": [100.0 + i for i in range(10)],
|
||||||
|
"High": [101.0 + i for i in range(10)],
|
||||||
|
"Low": [99.0 + i for i in range(10)],
|
||||||
|
"Close": [100.5 + i for i in range(10)],
|
||||||
|
"Volume": [1_000_000 + i for i in range(10)],
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestEnsureDateColumn:
|
||||||
|
def test_renames_index_column(self):
|
||||||
|
out = su._ensure_date_column(_ohlcv("index"))
|
||||||
|
assert "Date" in out.columns and "index" not in out.columns
|
||||||
|
|
||||||
|
def test_renames_datetime_and_date_variants(self):
|
||||||
|
assert "Date" in su._ensure_date_column(_ohlcv("Datetime")).columns
|
||||||
|
assert "Date" in su._ensure_date_column(_ohlcv("date")).columns
|
||||||
|
|
||||||
|
def test_leaves_existing_date_untouched(self):
|
||||||
|
df = _ohlcv("Date")
|
||||||
|
assert su._ensure_date_column(df) is df # no-op short-circuit
|
||||||
|
|
||||||
|
def test_no_datelike_column_is_left_alone(self):
|
||||||
|
df = pd.DataFrame({"Close": [1, 2, 3]})
|
||||||
|
out = su._ensure_date_column(df)
|
||||||
|
assert "Date" not in out.columns # nothing to rename; caller handles
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestCleanDataframeAcrossVersions:
|
||||||
|
def test_clean_handles_index_column(self):
|
||||||
|
"""A frame with `index` instead of `Date` must still clean to a
|
||||||
|
usable, date-parsed frame (was KeyError: 'Date')."""
|
||||||
|
cleaned = su._clean_dataframe(_ohlcv("index"))
|
||||||
|
assert "Date" in cleaned.columns
|
||||||
|
assert pd.api.types.is_datetime64_any_dtype(cleaned["Date"])
|
||||||
|
assert len(cleaned) == 10
|
||||||
|
|
||||||
|
def test_clean_handles_legacy_date_column(self):
|
||||||
|
cleaned = su._clean_dataframe(_ohlcv("Date"))
|
||||||
|
assert len(cleaned) == 10
|
||||||
|
|
||||||
|
def test_indicators_compute_after_index_rename(self):
|
||||||
|
"""stockstats must compute indicators on a frame whose date column
|
||||||
|
arrived as `index`, instead of erroring per indicator."""
|
||||||
|
from stockstats import wrap
|
||||||
|
cleaned = su._clean_dataframe(_ohlcv("index"))
|
||||||
|
df = wrap(cleaned)
|
||||||
|
df["close_5_sma"] # triggers calculation
|
||||||
|
assert "close_5_sma" in df.columns
|
||||||
|
assert df["close_5_sma"].notna().any()
|
||||||
42
tests/test_stocktwits_resilience.py
Normal file
42
tests/test_stocktwits_resilience.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""StockTwits fetch degrades (never raises) on transport errors, including the
|
||||||
|
http.client chunked-transfer exceptions that are not OSErrors (#1024)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import http.client
|
||||||
|
from unittest.mock import patch
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows import stocktwits
|
||||||
|
|
||||||
|
|
||||||
|
def _raise(exc):
|
||||||
|
class _Resp:
|
||||||
|
def __enter__(self_inner):
|
||||||
|
return self_inner
|
||||||
|
|
||||||
|
def __exit__(self_inner, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def read(self_inner):
|
||||||
|
raise exc
|
||||||
|
return _Resp()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class StockTwitsResilienceTests:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"exc",
|
||||||
|
[
|
||||||
|
http.client.IncompleteRead(b""),
|
||||||
|
HTTPError("url", 503, "down", {}, None),
|
||||||
|
TimeoutError("slow"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_transport_errors_return_placeholder(self, exc):
|
||||||
|
with patch.object(stocktwits, "urlopen", return_value=_raise(exc)):
|
||||||
|
out = stocktwits.fetch_stocktwits_messages("NVDA")
|
||||||
|
assert "unavailable" in out.lower()
|
||||||
|
assert out.startswith("<stocktwits unavailable")
|
||||||
@@ -1,28 +1,32 @@
|
|||||||
"""Tests for structured-output agents (Trader and Research Manager).
|
"""Tests for structured-output agents (Trader, Research Manager, Sentiment Analyst).
|
||||||
|
|
||||||
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
||||||
(which exercises the full memory-log → PM injection cycle). This file
|
(which exercises the full memory-log → PM injection cycle). This file
|
||||||
covers the parallel schemas, render functions, and graceful-fallback
|
covers the parallel schemas, render functions, and graceful-fallback
|
||||||
behavior we added for the Trader and Research Manager so all three
|
behavior we added for the Trader, Research Manager, and Sentiment Analyst
|
||||||
decision-making agents share the same shape.
|
so they share the same deterministic output shape.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst
|
||||||
from tradingagents.agents.managers.research_manager import create_research_manager
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
from tradingagents.agents.schemas import (
|
from tradingagents.agents.schemas import (
|
||||||
PortfolioRating,
|
PortfolioRating,
|
||||||
ResearchPlan,
|
ResearchPlan,
|
||||||
|
SentimentBand,
|
||||||
|
SentimentReport,
|
||||||
TraderAction,
|
TraderAction,
|
||||||
TraderProposal,
|
TraderProposal,
|
||||||
render_research_plan,
|
render_research_plan,
|
||||||
|
render_sentiment_report,
|
||||||
render_trader_proposal,
|
render_trader_proposal,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.trader.trader import create_trader
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Render functions
|
# Render functions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -230,3 +234,126 @@ class TestResearchManagerAgent:
|
|||||||
rm = create_research_manager(llm)
|
rm = create_research_manager(llm)
|
||||||
result = rm(_make_rm_state())
|
result = rm(_make_rm_state())
|
||||||
assert result["investment_plan"] == plain_response
|
assert result["investment_plan"] == plain_response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentiment Analyst: schema, render, structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderSentimentReport:
|
||||||
|
def test_header_contains_band_and_score(self):
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH,
|
||||||
|
overall_score=7.2,
|
||||||
|
confidence="high",
|
||||||
|
narrative="Source breakdown here.",
|
||||||
|
)
|
||||||
|
md = render_sentiment_report(report)
|
||||||
|
assert "**Overall Sentiment:** **Bullish**" in md
|
||||||
|
assert "(Score: 7.2/10)" in md
|
||||||
|
|
||||||
|
def test_header_contains_confidence(self):
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.NEUTRAL,
|
||||||
|
overall_score=5.0,
|
||||||
|
confidence="low",
|
||||||
|
narrative="Limited data.",
|
||||||
|
)
|
||||||
|
assert "**Confidence:** Low" in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_narrative_preserved_in_output(self):
|
||||||
|
narrative = "## Breakdown\n\nStockTwits: 70% bullish.\n\n| Signal | Direction |\n|---|---|\n| News | Neutral |"
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.MILDLY_BULLISH,
|
||||||
|
overall_score=6.0,
|
||||||
|
confidence="medium",
|
||||||
|
narrative=narrative,
|
||||||
|
)
|
||||||
|
assert narrative in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_all_six_bands_render(self):
|
||||||
|
for band in SentimentBand:
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=band, overall_score=5.0,
|
||||||
|
confidence="medium", narrative="n",
|
||||||
|
)
|
||||||
|
assert band.value in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_score_out_of_range_rejected(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH, overall_score=11.0,
|
||||||
|
confidence="high", narrative="n",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sentiment_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"trade_date": "2026-01-15",
|
||||||
|
"asset_type": "stock",
|
||||||
|
"messages": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_sentiment_llm(captured: dict, report: SentimentReport | None = None):
|
||||||
|
"""MagicMock LLM whose structured binding captures the prompt and returns
|
||||||
|
a real SentimentReport so render_sentiment_report works."""
|
||||||
|
if report is None:
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH, overall_score=7.5,
|
||||||
|
confidence="high",
|
||||||
|
narrative="StockTwits 75% bullish. News constructive. Reddit upbeat.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or report
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSentimentAnalystAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.MILDLY_BEARISH, overall_score=4.0,
|
||||||
|
confidence="medium", narrative="Mixed signals across sources.",
|
||||||
|
)
|
||||||
|
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured, report))
|
||||||
|
sr = analyst(_make_sentiment_state())["sentiment_report"]
|
||||||
|
assert "**Overall Sentiment:** **Mildly Bearish**" in sr
|
||||||
|
assert "(Score: 4.0/10)" in sr
|
||||||
|
assert "Mixed signals across sources." in sr
|
||||||
|
|
||||||
|
def test_sentiment_report_also_in_messages(self):
|
||||||
|
captured = {}
|
||||||
|
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured))
|
||||||
|
result = analyst(_make_sentiment_state())
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
assert result["sentiment_report"] == result["messages"][0].content
|
||||||
|
|
||||||
|
def test_prompt_contains_ticker(self):
|
||||||
|
captured = {}
|
||||||
|
create_sentiment_analyst(_structured_sentiment_llm(captured))(_make_sentiment_state())
|
||||||
|
assert any("NVDA" in str(m) for m in captured["prompt"])
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain = "**Overall Sentiment:** **Bearish** (Score: 3.0/10)\n**Confidence:** Low\n\nLimited data."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain)
|
||||||
|
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_call_fails(self):
|
||||||
|
plain = "Fallback free-text sentiment."
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = ValueError("bad JSON from model")
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain)
|
||||||
|
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain
|
||||||
|
|||||||
54
tests/test_symbol_normalization_paths.py
Normal file
54
tests/test_symbol_normalization_paths.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Symbol normalization must apply on every yfinance path, not just price fetch.
|
||||||
|
|
||||||
|
Regression tests for #983 (instrument identity) and #984 (reflection returns):
|
||||||
|
a broker symbol like XAUUSD must resolve to the same Yahoo symbol (GC=F) that
|
||||||
|
the price path uses, so identity and realized-return lookups hit the right
|
||||||
|
instrument instead of failing/mismatching.
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import tradingagents.agents.utils.agent_utils as au
|
||||||
|
import tradingagents.graph.trading_graph as tg
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_lookup_normalizes_symbol(monkeypatch):
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
class FakeTicker:
|
||||||
|
def __init__(self, symbol):
|
||||||
|
seen["symbol"] = symbol
|
||||||
|
|
||||||
|
@property
|
||||||
|
def info(self):
|
||||||
|
return {"longName": "Gold Futures", "quoteType": "FUTURE"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(au.yf, "Ticker", FakeTicker)
|
||||||
|
au.resolve_instrument_identity.cache_clear()
|
||||||
|
|
||||||
|
identity = au.resolve_instrument_identity("XAUUSD")
|
||||||
|
|
||||||
|
assert seen["symbol"] == "GC=F" # normalized, not the raw broker symbol
|
||||||
|
assert identity.get("company_name") == "Gold Futures"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_returns_normalizes_symbol(monkeypatch):
|
||||||
|
queried = []
|
||||||
|
|
||||||
|
class FakeTicker:
|
||||||
|
def __init__(self, symbol):
|
||||||
|
queried.append(symbol)
|
||||||
|
|
||||||
|
def history(self, *args, **kwargs):
|
||||||
|
return pd.DataFrame({"Close": [100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0]})
|
||||||
|
|
||||||
|
monkeypatch.setattr(tg.yf, "Ticker", FakeTicker)
|
||||||
|
|
||||||
|
# _fetch_returns does not use ``self``; call unbound to avoid building the graph.
|
||||||
|
raw, alpha, days = TradingAgentsGraph._fetch_returns(
|
||||||
|
None, "XAUUSD", "2025-01-02", holding_days=5, benchmark="SPY"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert queried[0] == "GC=F" # stock symbol normalized (#984)
|
||||||
|
assert queried[1] == "SPY" # benchmark left as the canonical symbol
|
||||||
|
assert raw is not None and days is not None
|
||||||
81
tests/test_symbol_utils.py
Normal file
81
tests/test_symbol_utils.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Tests for symbol normalization and the no-data routing sentinel."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows.symbol_utils import (
|
||||||
|
NoMarketDataError,
|
||||||
|
is_yahoo_safe,
|
||||||
|
normalize_symbol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestNormalizeSymbol(unittest.TestCase):
|
||||||
|
def test_plain_equities_unchanged(self):
|
||||||
|
for sym in ("AAPL", "MSFT", "TSM", "BRK.B", "0700.HK", "^GSPC", "GC=F"):
|
||||||
|
self.assertEqual(normalize_symbol(sym), sym)
|
||||||
|
|
||||||
|
def test_lowercases_are_upper(self):
|
||||||
|
self.assertEqual(normalize_symbol("aapl"), "AAPL")
|
||||||
|
self.assertEqual(normalize_symbol(" msft "), "MSFT")
|
||||||
|
|
||||||
|
def test_metal_aliases_map_to_futures(self):
|
||||||
|
self.assertEqual(normalize_symbol("XAUUSD"), "GC=F")
|
||||||
|
self.assertEqual(normalize_symbol("XAUUSD+"), "GC=F") # broker CFD suffix
|
||||||
|
self.assertEqual(normalize_symbol("xauusd+"), "GC=F")
|
||||||
|
self.assertEqual(normalize_symbol("GOLD"), "GC=F")
|
||||||
|
self.assertEqual(normalize_symbol("XAGUSD"), "SI=F")
|
||||||
|
|
||||||
|
def test_energy_and_index_aliases(self):
|
||||||
|
self.assertEqual(normalize_symbol("USOIL"), "CL=F")
|
||||||
|
self.assertEqual(normalize_symbol("SPX500"), "^GSPC")
|
||||||
|
self.assertEqual(normalize_symbol("NAS100"), "^NDX")
|
||||||
|
self.assertEqual(normalize_symbol("US30"), "^DJI")
|
||||||
|
|
||||||
|
def test_forex_pairs_get_x_suffix(self):
|
||||||
|
self.assertEqual(normalize_symbol("EURUSD"), "EURUSD=X")
|
||||||
|
self.assertEqual(normalize_symbol("GBPJPY"), "GBPJPY=X")
|
||||||
|
self.assertEqual(normalize_symbol("eurusd"), "EURUSD=X")
|
||||||
|
|
||||||
|
def test_crypto_pairs_get_dash_usd(self):
|
||||||
|
self.assertEqual(normalize_symbol("BTCUSD"), "BTC-USD")
|
||||||
|
self.assertEqual(normalize_symbol("ETHUSD"), "ETH-USD")
|
||||||
|
|
||||||
|
def test_six_letter_non_currency_left_alone(self):
|
||||||
|
# GOOGLE-style 6-letter tickers that aren't two currency codes
|
||||||
|
# must not be mangled into a fake forex pair.
|
||||||
|
self.assertEqual(normalize_symbol("ABCDEF"), "ABCDEF")
|
||||||
|
|
||||||
|
def test_empty_input_passthrough(self):
|
||||||
|
self.assertEqual(normalize_symbol(""), "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestNoMarketDataError(unittest.TestCase):
|
||||||
|
def test_message_includes_resolution(self):
|
||||||
|
err = NoMarketDataError("XAUUSD+", "GC=F", "no rows")
|
||||||
|
self.assertIn("XAUUSD+", str(err))
|
||||||
|
self.assertIn("GC=F", str(err))
|
||||||
|
self.assertEqual(err.symbol, "XAUUSD+")
|
||||||
|
self.assertEqual(err.canonical, "GC=F")
|
||||||
|
|
||||||
|
def test_canonical_defaults_to_symbol(self):
|
||||||
|
err = NoMarketDataError("FOOBAR")
|
||||||
|
self.assertEqual(err.canonical, "FOOBAR")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestIsYahooSafe(unittest.TestCase):
|
||||||
|
def test_accepts_structural_chars(self):
|
||||||
|
for sym in ("AAPL", "GC=F", "^GSPC", "BRK.B", "BTC-USD"):
|
||||||
|
self.assertTrue(is_yahoo_safe(sym))
|
||||||
|
|
||||||
|
def test_rejects_slash_and_space(self):
|
||||||
|
for sym in ("a/b", "AA PL", ""):
|
||||||
|
self.assertFalse(is_yahoo_safe(sym))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
80
tests/test_temperature_config.py
Normal file
80
tests/test_temperature_config.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Tests for the configurable sampling temperature (#178/#168).
|
||||||
|
|
||||||
|
Temperature is a cross-provider knob: when set it must reach the underlying
|
||||||
|
chat client; when unset the provider keeps its own default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTemperatureForwarding:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider,model",
|
||||||
|
[
|
||||||
|
("openai", "gpt-4.1"),
|
||||||
|
("anthropic", "claude-sonnet-4-6"),
|
||||||
|
("google", "gemini-2.5-flash"),
|
||||||
|
("deepseek", "deepseek-chat"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_temperature_reaches_client_when_set(self, provider, model):
|
||||||
|
llm = create_llm_client(
|
||||||
|
provider=provider, model=model, temperature=0.0, api_key="placeholder"
|
||||||
|
).get_llm()
|
||||||
|
assert llm.temperature == 0.0
|
||||||
|
|
||||||
|
def test_temperature_omitted_leaves_provider_default(self):
|
||||||
|
# Not passing temperature must not force it to a value.
|
||||||
|
llm = create_llm_client(
|
||||||
|
provider="openai", model="gpt-4.1", api_key="placeholder"
|
||||||
|
).get_llm()
|
||||||
|
# langchain's default is unset/None, not 0.0
|
||||||
|
assert llm.temperature is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestTemperatureEnvOverlay:
|
||||||
|
def test_env_sets_temperature(self, monkeypatch):
|
||||||
|
import tradingagents.default_config as dc
|
||||||
|
monkeypatch.setenv("TRADINGAGENTS_TEMPERATURE", "0.2")
|
||||||
|
importlib.reload(dc)
|
||||||
|
# Stored on config (string from env is fine; consumed via float()).
|
||||||
|
assert dc.DEFAULT_CONFIG["temperature"] in ("0.2", 0.2)
|
||||||
|
assert float(dc.DEFAULT_CONFIG["temperature"]) == 0.2
|
||||||
|
monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
|
||||||
|
importlib.reload(dc)
|
||||||
|
|
||||||
|
def test_default_temperature_is_none(self, monkeypatch):
|
||||||
|
import tradingagents.default_config as dc
|
||||||
|
monkeypatch.delenv("TRADINGAGENTS_TEMPERATURE", raising=False)
|
||||||
|
importlib.reload(dc)
|
||||||
|
assert dc.DEFAULT_CONFIG["temperature"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestProviderKwargsTemperature:
|
||||||
|
"""_get_provider_kwargs float-coerces and forwards temperature, or omits it."""
|
||||||
|
|
||||||
|
def _kwargs_for(self, temperature):
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
# Call the method without constructing the full graph.
|
||||||
|
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||||
|
graph.config = {"llm_provider": "openai", "temperature": temperature}
|
||||||
|
return TradingAgentsGraph._get_provider_kwargs(graph)
|
||||||
|
|
||||||
|
def test_float_string_coerced(self):
|
||||||
|
assert self._kwargs_for("0.3")["temperature"] == 0.3
|
||||||
|
|
||||||
|
def test_float_passthrough(self):
|
||||||
|
assert self._kwargs_for(0.0)["temperature"] == 0.0
|
||||||
|
|
||||||
|
def test_none_omitted(self):
|
||||||
|
assert "temperature" not in self._kwargs_for(None)
|
||||||
|
|
||||||
|
def test_empty_string_omitted(self):
|
||||||
|
assert "temperature" not in self._kwargs_for("")
|
||||||
@@ -16,6 +16,14 @@ class TickerSymbolHandlingTests(unittest.TestCase):
|
|||||||
self.assertIn("7203.T", context)
|
self.assertIn("7203.T", context)
|
||||||
self.assertIn("exchange suffix", context)
|
self.assertIn("exchange suffix", context)
|
||||||
|
|
||||||
|
def test_single_get_ticker_no_shadow(self):
|
||||||
|
# Regression: cli/main.py had a duplicate get_ticker with an empty
|
||||||
|
# questionary prompt (rendered as a bare "?") that shadowed the
|
||||||
|
# descriptive one in cli/utils. Keep a single canonical definition.
|
||||||
|
import cli.main
|
||||||
|
import cli.utils
|
||||||
|
self.assertIs(cli.main.get_ticker, cli.utils.get_ticker)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
105
tests/test_vendor_errors.py
Normal file
105
tests/test_vendor_errors.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""The vendor data-error hierarchy: every "vendor couldn't return usable data"
|
||||||
|
condition derives from VendorError, so the router catches base types and any
|
||||||
|
vendor slots in without new handling.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows import interface
|
||||||
|
from tradingagents.dataflows.alpha_vantage_common import (
|
||||||
|
AlphaVantageNotConfiguredError,
|
||||||
|
AlphaVantageRateLimitError,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.dataflows.errors import (
|
||||||
|
NoMarketDataError,
|
||||||
|
VendorError,
|
||||||
|
VendorNotConfiguredError,
|
||||||
|
VendorRateLimitError,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.fred import FredNotConfiguredError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class HierarchyTests(unittest.TestCase):
|
||||||
|
def test_all_conditions_derive_from_vendor_error(self):
|
||||||
|
for cls in (NoMarketDataError, VendorRateLimitError, VendorNotConfiguredError):
|
||||||
|
self.assertTrue(issubclass(cls, VendorError))
|
||||||
|
|
||||||
|
def test_not_configured_is_still_a_value_error(self):
|
||||||
|
# Back-compat: existing `except ValueError` callers keep working.
|
||||||
|
self.assertTrue(issubclass(VendorNotConfiguredError, ValueError))
|
||||||
|
|
||||||
|
def test_vendor_named_errors_subclass_the_generic_bases(self):
|
||||||
|
self.assertTrue(issubclass(AlphaVantageRateLimitError, VendorRateLimitError))
|
||||||
|
self.assertTrue(issubclass(AlphaVantageNotConfiguredError, VendorNotConfiguredError))
|
||||||
|
self.assertTrue(issubclass(FredNotConfiguredError, VendorNotConfiguredError))
|
||||||
|
# ... and therefore still ValueErrors
|
||||||
|
self.assertTrue(issubclass(FredNotConfiguredError, ValueError))
|
||||||
|
|
||||||
|
def test_symbol_utils_reexports_no_market_data_error(self):
|
||||||
|
from tradingagents.dataflows.symbol_utils import (
|
||||||
|
NoMarketDataError as ReExported,
|
||||||
|
)
|
||||||
|
self.assertIs(ReExported, NoMarketDataError)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class RouterHandlesBaseTypesTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def test_rate_limit_subclass_caught_by_base(self):
|
||||||
|
# A vendor-named rate-limit error skips to the next vendor in the chain.
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "alpha_vantage,yfinance"}})
|
||||||
|
|
||||||
|
def _throttled(*a, **k):
|
||||||
|
raise AlphaVantageRateLimitError("slow down")
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_stock_data": {"alpha_vantage": _throttled, "yfinance": lambda *a, **k: "YF"}},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
out = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertEqual(out, "YF")
|
||||||
|
|
||||||
|
def test_not_configured_falls_through_to_next_vendor(self):
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "alpha_vantage,yfinance"}})
|
||||||
|
|
||||||
|
def _unconfigured(*a, **k):
|
||||||
|
raise AlphaVantageNotConfiguredError("no key")
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_stock_data": {"alpha_vantage": _unconfigured, "yfinance": lambda *a, **k: "YF"}},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
out = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertEqual(out, "YF")
|
||||||
|
|
||||||
|
def test_sole_unconfigured_vendor_surfaces_the_error(self):
|
||||||
|
# With no fallback, the not-configured condition must surface (not vanish).
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "alpha_vantage"}})
|
||||||
|
|
||||||
|
def _unconfigured(*a, **k):
|
||||||
|
raise AlphaVantageNotConfiguredError("no key")
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_stock_data": {"alpha_vantage": _unconfigured}},
|
||||||
|
clear=False,
|
||||||
|
), self.assertRaises(AlphaVantageNotConfiguredError):
|
||||||
|
interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
101
tests/test_vendor_routing.py
Normal file
101
tests/test_vendor_routing.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""Vendor router must respect the configured chain and never silently hide a
|
||||||
|
broken primary.
|
||||||
|
|
||||||
|
Regressions for #988 (explicit single-vendor config still fell back to others),
|
||||||
|
#289 (fallback ran for unchosen vendors), and #989 (serious primary failures
|
||||||
|
were swallowed without a trace).
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows import interface
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_config():
|
||||||
|
# Hard reset: set_config() merges, so empty DEFAULT dicts (e.g. tool_vendors)
|
||||||
|
# don't clear keys leaked by other tests. Replace the global outright.
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
|
def _no_data(symbol, *a, **k):
|
||||||
|
raise NoMarketDataError(symbol, symbol, "no rows")
|
||||||
|
|
||||||
|
|
||||||
|
def _returns(value):
|
||||||
|
def impl(symbol, *a, **k):
|
||||||
|
return value
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
def _raises(exc):
|
||||||
|
def impl(symbol, *a, **k):
|
||||||
|
raise exc
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class VendorRoutingTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
_reset_config()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
_reset_config()
|
||||||
|
|
||||||
|
def _route(self, vendors_for_get_stock_data):
|
||||||
|
return mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_stock_data": vendors_for_get_stock_data},
|
||||||
|
clear=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_explicit_single_vendor_does_not_fall_back(self):
|
||||||
|
# #988: with yfinance pinned, a healthy alpha_vantage must NOT be used.
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "yfinance"}})
|
||||||
|
av = mock.Mock(side_effect=_returns("AV_DATA"))
|
||||||
|
with self._route({"yfinance": _no_data, "alpha_vantage": av}):
|
||||||
|
result = interface.route_to_vendor("get_stock_data", "FAKE", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||||
|
av.assert_not_called() # the unchosen vendor was never tried
|
||||||
|
|
||||||
|
def test_explicit_multi_vendor_falls_back_within_chain(self):
|
||||||
|
# Listing both vendors opts in to ordered fallback.
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
|
||||||
|
with self._route({"yfinance": _no_data, "alpha_vantage": _returns("AV_DATA")}):
|
||||||
|
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertEqual(result, "AV_DATA")
|
||||||
|
|
||||||
|
def test_primary_error_is_logged_not_masked(self):
|
||||||
|
# #989: primary errors + fallback no-data -> NO_DATA, but the failure
|
||||||
|
# must be visible in logs (broken primary not hidden).
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
|
||||||
|
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}), \
|
||||||
|
self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
|
||||||
|
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||||
|
joined = "\n".join(cm.output)
|
||||||
|
self.assertIn("boom", joined) # the real error surfaced in logs
|
||||||
|
self.assertIn("yfinance", joined)
|
||||||
|
|
||||||
|
def test_unknown_configured_vendor_raises(self):
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "bogus_vendor"}})
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertIn("bogus_vendor", str(ctx.exception))
|
||||||
|
|
||||||
|
def test_default_sentinel_uses_all_vendors(self):
|
||||||
|
# No explicit choice ("default") keeps the resilient full-chain behavior.
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "default"}})
|
||||||
|
with self._route({"yfinance": _no_data, "alpha_vantage": _returns("AV_DATA")}):
|
||||||
|
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
|
self.assertEqual(result, "AV_DATA")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
113
tests/test_yfinance_stale_ohlcv_guard.py
Normal file
113
tests/test_yfinance_stale_ohlcv_guard.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Stale OHLCV guard (#1021): a vendor returning a year-old partial frame must
|
||||||
|
be rejected, not fed into the report as if it were current.
|
||||||
|
|
||||||
|
The guard raises NoMarketDataError with a stale-specific detail, so the router's
|
||||||
|
existing try-next-vendor + single-sentinel handling applies and the sentinel
|
||||||
|
surfaces the reason.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as config_module
|
||||||
|
import tradingagents.dataflows.y_finance as y_finance
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
|
from tradingagents.dataflows import interface
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.dataflows.stockstats_utils import _assert_ohlcv_not_stale
|
||||||
|
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
||||||
|
|
||||||
|
|
||||||
|
def _frame(date):
|
||||||
|
return pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Date": [pd.Timestamp(date)],
|
||||||
|
"Open": [330.0],
|
||||||
|
"High": [332.0],
|
||||||
|
"Low": [328.0],
|
||||||
|
"Close": [330.58],
|
||||||
|
"Volume": [1_000_000],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class StaleGuardUnitTests(unittest.TestCase):
|
||||||
|
def test_recent_prior_trading_day_is_accepted(self):
|
||||||
|
# 1 day before curr_date — well within the freshness window.
|
||||||
|
_assert_ohlcv_not_stale(_frame("2026-06-10"), "2026-06-11", "CB")
|
||||||
|
|
||||||
|
def test_year_old_row_is_rejected_with_detail(self):
|
||||||
|
with self.assertRaises(NoMarketDataError) as ctx:
|
||||||
|
_assert_ohlcv_not_stale(_frame("2025-06-11"), "2026-06-11", "CB", "CB")
|
||||||
|
msg = str(ctx.exception)
|
||||||
|
self.assertIn("2025-06-11", msg)
|
||||||
|
self.assertIn("2026-06-11", msg)
|
||||||
|
self.assertIn("stale", msg)
|
||||||
|
|
||||||
|
def test_empty_frame_is_left_to_caller(self):
|
||||||
|
# Empty is a no-data condition handled elsewhere, not a staleness one.
|
||||||
|
_assert_ohlcv_not_stale(
|
||||||
|
pd.DataFrame(columns=["Date", "Close"]), "2026-06-11", "X"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_long_holiday_gap_within_threshold_is_accepted(self):
|
||||||
|
_assert_ohlcv_not_stale(_frame("2026-06-02"), "2026-06-11", "X") # 9 days
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class StaleGuardPropagationTests(unittest.TestCase):
|
||||||
|
def test_get_yfin_data_online_raises_on_stale_frame(self):
|
||||||
|
stale = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Open": [280.0], "High": [286.0], "Low": [278.0],
|
||||||
|
"Close": [284.45], "Volume": [1_000_000],
|
||||||
|
},
|
||||||
|
index=pd.DatetimeIndex([pd.Timestamp("2025-06-11")], name="Date"),
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyTicker:
|
||||||
|
def __init__(self, symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def history(self, start, end):
|
||||||
|
return stale
|
||||||
|
|
||||||
|
with mock.patch.object(y_finance.yf, "Ticker", DummyTicker), \
|
||||||
|
self.assertRaises(NoMarketDataError):
|
||||||
|
y_finance.get_YFin_data_online("CB", "2026-06-01", "2026-06-11")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class StaleGuardRoutingTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
config_module._config = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
def test_router_sentinel_surfaces_stale_reason(self):
|
||||||
|
set_config({"data_vendors": {"core_stock_apis": "yfinance"}})
|
||||||
|
|
||||||
|
def _stale(symbol, *a, **k):
|
||||||
|
raise NoMarketDataError(
|
||||||
|
symbol, symbol, "latest row is 2025-06-11, 365 days before ... (stale)"
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
interface.VENDOR_METHODS,
|
||||||
|
{"get_stock_data": {"yfinance": _stale}},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
out = interface.route_to_vendor(
|
||||||
|
"get_stock_data", "CB", "2026-06-01", "2026-06-11"
|
||||||
|
)
|
||||||
|
self.assertIn("NO_DATA_AVAILABLE", out)
|
||||||
|
self.assertIn("stale", out) # the typed detail is surfaced to the agent
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import contextlib
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
||||||
@@ -20,10 +21,8 @@ except ImportError:
|
|||||||
# subclassed warning categories. To suppress a specific warning we must
|
# subclassed warning categories. To suppress a specific warning we must
|
||||||
# install our filter AFTER langchain-core has installed its own, so import
|
# install our filter AFTER langchain-core has installed its own, so import
|
||||||
# it first. The package is a guaranteed transitive dep via langgraph.
|
# it first. The package is a guaranteed transitive dep via langgraph.
|
||||||
try:
|
with contextlib.suppress(ImportError):
|
||||||
import langchain_core # noqa: F401
|
import langchain_core # noqa: F401
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
||||||
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
from .utils.agent_utils import create_msg_delete
|
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
from .analysts.news_analyst import create_news_analyst
|
from .analysts.news_analyst import create_news_analyst
|
||||||
@@ -8,18 +5,16 @@ from .analysts.sentiment_analyst import (
|
|||||||
create_sentiment_analyst,
|
create_sentiment_analyst,
|
||||||
create_social_media_analyst, # deprecated alias kept for back-compat
|
create_social_media_analyst, # deprecated alias kept for back-compat
|
||||||
)
|
)
|
||||||
|
from .managers.portfolio_manager import create_portfolio_manager
|
||||||
|
from .managers.research_manager import create_research_manager
|
||||||
from .researchers.bear_researcher import create_bear_researcher
|
from .researchers.bear_researcher import create_bear_researcher
|
||||||
from .researchers.bull_researcher import create_bull_researcher
|
from .researchers.bull_researcher import create_bull_researcher
|
||||||
|
|
||||||
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
||||||
from .risk_mgmt.conservative_debator import create_conservative_debator
|
from .risk_mgmt.conservative_debator import create_conservative_debator
|
||||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||||
|
|
||||||
from .managers.research_manager import create_research_manager
|
|
||||||
from .managers.portfolio_manager import create_portfolio_manager
|
|
||||||
|
|
||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
|
from .utils.agent_utils import create_msg_delete
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentState",
|
"AgentState",
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
get_cashflow,
|
get_cashflow,
|
||||||
get_fundamentals,
|
get_fundamentals,
|
||||||
get_income_statement,
|
get_income_statement,
|
||||||
get_insider_transactions,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_fundamentals_analyst(llm):
|
def create_fundamentals_analyst(llm):
|
||||||
def fundamentals_analyst_node(state):
|
def fundamentals_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
get_fundamentals,
|
get_fundamentals,
|
||||||
|
|||||||
@@ -1,22 +1,24 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
|
get_verified_market_snapshot,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_market_analyst(llm):
|
def create_market_analyst(llm):
|
||||||
|
|
||||||
def market_analyst_node(state):
|
def market_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
get_verified_market_snapshot,
|
||||||
]
|
]
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
@@ -44,7 +46,11 @@ Volatility Indicators:
|
|||||||
Volume-Based Indicators:
|
Volume-Based Indicators:
|
||||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||||
|
|
||||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
|
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names.
|
||||||
|
|
||||||
|
Before writing the final report, call get_verified_market_snapshot for this ticker and the current date, and treat it as the source of truth for any exact OHLCV, price-level, or indicator-value claim. If another tool's output conflicts with the verified snapshot, flag the discrepancy rather than inventing a reconciled number. Do not claim historical validation, support/resistance bounces, or exact percentage moves unless they are directly supported by tool output with concrete dates and prices.
|
||||||
|
|
||||||
|
Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
+ get_language_instruction()
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,25 +1,31 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
|
get_macro_indicators,
|
||||||
get_news,
|
get_news,
|
||||||
|
get_prediction_markets,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_news_analyst(llm):
|
def create_news_analyst(llm):
|
||||||
def news_analyst_node(state):
|
def news_analyst_node(state):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
asset_type = state.get("asset_type", "stock")
|
||||||
|
asset_label = "company" if asset_type == "stock" else "asset"
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
get_news,
|
get_news,
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
get_macro_indicators,
|
||||||
|
get_prediction_markets,
|
||||||
]
|
]
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
f"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for {asset_label}-specific or targeted news searches, get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news, get_macro_indicators(indicator, curr_date, look_back_days) to ground macro commentary in actual data from FRED (e.g. 'cpi', 'core_pce', 'unemployment', 'fed_funds_rate', '10y_treasury', 'yield_curve'), and get_prediction_markets(topic, limit) for live market-implied probabilities of forward-looking events (e.g. 'Fed rate cut', 'recession 2026', geopolitical or sector events). Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
|
||||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||||
+ get_language_instruction()
|
+ get_language_instruction()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,19 +14,31 @@ the LLM is invoked and injects them into the prompt as structured blocks:
|
|||||||
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
||||||
|
|
||||||
The agent does not use tool-calling; the data is in the prompt from
|
The agent does not use tool-calling; the data is in the prompt from
|
||||||
turn 0. The LLM produces the sentiment report in a single invocation.
|
turn 0. Output uses the structured-output pattern (json_schema for
|
||||||
|
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic), falling
|
||||||
|
back to free-text generation for providers that lack native support, so
|
||||||
|
the sentiment header (band + score + confidence) is deterministic across
|
||||||
|
runs and providers instead of free-form per-model prose.
|
||||||
|
|
||||||
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||||
|
See: https://github.com/TauricResearch/TradingAgents/issues/796
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import SentimentReport, render_sentiment_report
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
get_news,
|
get_news,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
||||||
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
||||||
|
|
||||||
@@ -39,15 +51,17 @@ def create_sentiment_analyst(llm):
|
|||||||
"""Create a sentiment analyst node for the trading graph.
|
"""Create a sentiment analyst node for the trading graph.
|
||||||
|
|
||||||
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
||||||
prompt as structured blocks, and produces a sentiment report in a
|
prompt as structured blocks, and produces a deterministic sentiment
|
||||||
single LLM call.
|
report via structured output (with a free-text fallback for providers
|
||||||
|
that do not support it).
|
||||||
"""
|
"""
|
||||||
|
structured_llm = bind_structured(llm, SentimentReport, "Sentiment Analyst")
|
||||||
|
|
||||||
def sentiment_analyst_node(state):
|
def sentiment_analyst_node(state):
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
end_date = state["trade_date"]
|
end_date = state["trade_date"]
|
||||||
start_date = _seven_days_back(end_date)
|
start_date = _seven_days_back(end_date)
|
||||||
instrument_context = build_instrument_context(ticker)
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
# Pre-fetch all three sources. Each fetcher degrades gracefully and
|
# Pre-fetch all three sources. Each fetcher degrades gracefully and
|
||||||
# returns a string (no exceptions surface from here), so the LLM
|
# returns a string (no exceptions surface from here), so the LLM
|
||||||
@@ -83,14 +97,22 @@ def create_sentiment_analyst(llm):
|
|||||||
prompt = prompt.partial(current_date=end_date)
|
prompt = prompt.partial(current_date=end_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
# No bind_tools — the data is already in the prompt; a single LLM
|
# Format the template into a concrete message list so the structured
|
||||||
# call produces the report directly.
|
# and free-text paths receive the same input. No bind_tools — the
|
||||||
chain = prompt | llm
|
# data is already in the prompt.
|
||||||
result = chain.invoke(state["messages"])
|
formatted_messages = prompt.format_messages(messages=state["messages"])
|
||||||
|
|
||||||
|
report_text = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
formatted_messages,
|
||||||
|
render_sentiment_report,
|
||||||
|
"Sentiment Analyst",
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [AIMessage(content=report_text)],
|
||||||
"sentiment_report": result.content,
|
"sentiment_report": report_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
return sentiment_analyst_node
|
return sentiment_analyst_node
|
||||||
@@ -143,21 +165,20 @@ Community discussion. Engagement signal via upvote score and comment count. Subr
|
|||||||
|
|
||||||
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
||||||
|
|
||||||
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
|
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this explicitly in the `confidence` field and the narrative. If the sources are silent on a given subreddit, say so.
|
||||||
|
|
||||||
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
||||||
|
|
||||||
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
||||||
|
|
||||||
## Output
|
## Output fields
|
||||||
|
|
||||||
Produce a sentiment report covering, in order:
|
Fill the following fields:
|
||||||
|
|
||||||
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
|
- **overall_band**: Exactly one of Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. Use Mixed when sources point in clearly different directions; Neutral only when all sources are genuinely silent.
|
||||||
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
|
- **overall_score**: A number from 0 (maximally bearish) to 10 (maximally bullish); 5 is neutral. Keep it consistent with overall_band.
|
||||||
3. **Divergences, alignments, and key narratives** across sources.
|
- **confidence**: low / medium / high, based on data quality and sample size.
|
||||||
4. **Catalysts and risks** surfaced by the data.
|
- **narrative**: Full source-by-source breakdown, divergences, dominant narrative themes, catalysts and risks, and a markdown summary table of key sentiment signals (direction, source, supporting evidence).
|
||||||
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
|
|
||||||
|
|
||||||
{get_language_instruction()}"""
|
{get_language_instruction()}"""
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision
|
from tradingagents.agents.schemas import PortfolioDecision, render_pm_decision
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.utils.structured import (
|
from tradingagents.agents.utils.structured import (
|
||||||
@@ -25,7 +25,7 @@ def create_portfolio_manager(llm):
|
|||||||
structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager")
|
structured_llm = bind_structured(llm, PortfolioDecision, "Portfolio Manager")
|
||||||
|
|
||||||
def portfolio_manager_node(state) -> dict:
|
def portfolio_manager_node(state) -> dict:
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
history = state["risk_debate_state"]["history"]
|
history = state["risk_debate_state"]["history"]
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
|
from tradingagents.agents.schemas import ResearchPlan, render_research_plan
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.utils.structured import (
|
from tradingagents.agents.utils.structured import (
|
||||||
@@ -17,7 +17,7 @@ def create_research_manager(llm):
|
|||||||
structured_llm = bind_structured(llm, ResearchPlan, "Research Manager")
|
structured_llm = bind_structured(llm, ResearchPlan, "Research Manager")
|
||||||
|
|
||||||
def research_manager_node(state) -> dict:
|
def research_manager_node(state) -> dict:
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
history = state["investment_debate_state"].get("history", "")
|
history = state["investment_debate_state"].get("history", "")
|
||||||
|
|
||||||
investment_debate_state = state["investment_debate_state"]
|
investment_debate_state = state["investment_debate_state"]
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_bear_researcher(llm):
|
def create_bear_researcher(llm):
|
||||||
@@ -12,8 +15,16 @@ def create_bear_researcher(llm):
|
|||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
asset_type = state.get("asset_type", "stock")
|
||||||
|
target_label = "stock" if asset_type == "stock" else "asset"
|
||||||
|
fundamentals_label = (
|
||||||
|
"Company fundamentals report"
|
||||||
|
if asset_type == "stock"
|
||||||
|
else "Asset fundamentals report (may be unavailable for crypto)"
|
||||||
|
)
|
||||||
|
|
||||||
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
prompt = f"""You are a Bear Analyst making the case against investing in the {target_label}. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
|
|
||||||
@@ -25,13 +36,14 @@ Key points to focus on:
|
|||||||
|
|
||||||
Resources available:
|
Resources available:
|
||||||
|
|
||||||
|
{instrument_context}
|
||||||
Market research report: {market_research_report}
|
Market research report: {market_research_report}
|
||||||
Social media sentiment report: {sentiment_report}
|
Social media sentiment report: {sentiment_report}
|
||||||
Latest world affairs news: {news_report}
|
Latest world affairs news: {news_report}
|
||||||
Company fundamentals report: {fundamentals_report}
|
{fundamentals_label}: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bull argument: {current_response}
|
Last bull argument: {current_response}
|
||||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock.
|
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the {target_label}.
|
||||||
""" + get_language_instruction()
|
""" + get_language_instruction()
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_bull_researcher(llm):
|
def create_bull_researcher(llm):
|
||||||
@@ -12,8 +15,16 @@ def create_bull_researcher(llm):
|
|||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
asset_type = state.get("asset_type", "stock")
|
||||||
|
target_label = "stock" if asset_type == "stock" else "asset"
|
||||||
|
fundamentals_label = (
|
||||||
|
"Company fundamentals report"
|
||||||
|
if asset_type == "stock"
|
||||||
|
else "Asset fundamentals report (may be unavailable for crypto)"
|
||||||
|
)
|
||||||
|
|
||||||
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
prompt = f"""You are a Bull Analyst advocating for investing in the {target_label}. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||||
|
|
||||||
Key points to focus on:
|
Key points to focus on:
|
||||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||||
@@ -23,10 +34,11 @@ Key points to focus on:
|
|||||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||||
|
|
||||||
Resources available:
|
Resources available:
|
||||||
|
{instrument_context}
|
||||||
Market research report: {market_research_report}
|
Market research report: {market_research_report}
|
||||||
Social media sentiment report: {sentiment_report}
|
Social media sentiment report: {sentiment_report}
|
||||||
Latest world affairs news: {news_report}
|
Latest world affairs news: {news_report}
|
||||||
Company fundamentals report: {fundamentals_report}
|
{fundamentals_label}: {fundamentals_report}
|
||||||
Conversation history of the debate: {history}
|
Conversation history of the debate: {history}
|
||||||
Last bear argument: {current_response}
|
Last bear argument: {current_response}
|
||||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
|
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position.
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_aggressive_debator(llm):
|
def create_aggressive_debator(llm):
|
||||||
@@ -14,6 +17,7 @@ def create_aggressive_debator(llm):
|
|||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
@@ -23,6 +27,7 @@ def create_aggressive_debator(llm):
|
|||||||
|
|
||||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||||
|
|
||||||
|
{instrument_context}
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_conservative_debator(llm):
|
def create_conservative_debator(llm):
|
||||||
@@ -14,6 +17,7 @@ def create_conservative_debator(llm):
|
|||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
@@ -23,6 +27,7 @@ def create_conservative_debator(llm):
|
|||||||
|
|
||||||
Your task is to actively counter the arguments of the Aggressive and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
Your task is to actively counter the arguments of the Aggressive and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||||
|
|
||||||
|
{instrument_context}
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from tradingagents.agents.utils.agent_utils import get_language_instruction
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
|
get_instrument_context_from_state,
|
||||||
|
get_language_instruction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_neutral_debator(llm):
|
def create_neutral_debator(llm):
|
||||||
@@ -14,6 +17,7 @@ def create_neutral_debator(llm):
|
|||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["fundamentals_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
|
|
||||||
trader_decision = state["trader_investment_plan"]
|
trader_decision = state["trader_investment_plan"]
|
||||||
|
|
||||||
@@ -23,6 +27,7 @@ def create_neutral_debator(llm):
|
|||||||
|
|
||||||
Your task is to challenge both the Aggressive and Conservative Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
Your task is to challenge both the Aggressive and Conservative Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||||
|
|
||||||
|
{instrument_context}
|
||||||
Market Research Report: {market_research_report}
|
Market Research Report: {market_research_report}
|
||||||
Social Media Sentiment Report: {sentiment_report}
|
Social Media Sentiment Report: {sentiment_report}
|
||||||
Latest World Affairs Report: {news_report}
|
Latest World Affairs Report: {news_report}
|
||||||
|
|||||||
@@ -19,11 +19,10 @@ so that:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Shared rating types
|
# Shared rating types
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -124,15 +123,15 @@ class TraderProposal(BaseModel):
|
|||||||
"the research plan. Two to four sentences."
|
"the research plan. Two to four sentences."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
entry_price: Optional[float] = Field(
|
entry_price: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional entry price target in the instrument's quote currency.",
|
description="Optional entry price target in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
stop_loss: Optional[float] = Field(
|
stop_loss: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional stop-loss price in the instrument's quote currency.",
|
description="Optional stop-loss price in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
position_sizing: Optional[str] = Field(
|
position_sizing: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
||||||
)
|
)
|
||||||
@@ -196,11 +195,11 @@ class PortfolioDecision(BaseModel):
|
|||||||
"incorporate them; otherwise rely solely on the current analysis."
|
"incorporate them; otherwise rely solely on the current analysis."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
price_target: Optional[float] = Field(
|
price_target: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional target price in the instrument's quote currency.",
|
description="Optional target price in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
time_horizon: Optional[str] = Field(
|
time_horizon: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional recommended holding period, e.g. '3-6 months'.",
|
description="Optional recommended holding period, e.g. '3-6 months'.",
|
||||||
)
|
)
|
||||||
@@ -226,3 +225,94 @@ def render_pm_decision(decision: PortfolioDecision) -> str:
|
|||||||
if decision.time_horizon:
|
if decision.time_horizon:
|
||||||
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentiment Analyst
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentBand(str, Enum):
|
||||||
|
"""Discrete sentiment direction produced by the Sentiment Analyst.
|
||||||
|
|
||||||
|
Six tiers keep the signal granular enough to be actionable while remaining
|
||||||
|
small enough for every provider to map reliably from its JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BULLISH = "Bullish"
|
||||||
|
MILDLY_BULLISH = "Mildly Bullish"
|
||||||
|
NEUTRAL = "Neutral"
|
||||||
|
MIXED = "Mixed"
|
||||||
|
MILDLY_BEARISH = "Mildly Bearish"
|
||||||
|
BEARISH = "Bearish"
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentReport(BaseModel):
|
||||||
|
"""Structured sentiment report produced by the Sentiment Analyst.
|
||||||
|
|
||||||
|
Replaces the previous free-form prose output so downstream consumers
|
||||||
|
(dashboards, audit logs, PDF renderers, other agents) can read
|
||||||
|
``overall_band`` and ``overall_score`` without maintaining fragile regex
|
||||||
|
fallbacks that drift with every model release. ``narrative`` preserves the
|
||||||
|
rich source-by-source analysis; ``render_sentiment_report`` prepends a
|
||||||
|
deterministic header so the saved report stays human-readable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
overall_band: SentimentBand = Field(
|
||||||
|
description=(
|
||||||
|
"Overall sentiment direction. Exactly one of: "
|
||||||
|
"Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. "
|
||||||
|
"Use Mixed when sources point in clearly different directions. "
|
||||||
|
"Use Neutral only when all sources are genuinely silent or non-committal."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
overall_score: float = Field(
|
||||||
|
ge=0.0,
|
||||||
|
le=10.0,
|
||||||
|
description=(
|
||||||
|
"Numeric sentiment intensity on a 0–10 scale. "
|
||||||
|
"0 = maximally bearish, 5 = neutral, 10 = maximally bullish. "
|
||||||
|
"Guideline for consistency with overall_band: "
|
||||||
|
"Bullish ~6.5–10, Mildly Bullish ~5.5–6.4, Neutral/Mixed ~4.5–5.5, "
|
||||||
|
"Mildly Bearish ~3.5–4.4, Bearish ~0–3.4. "
|
||||||
|
"Only the 0–10 bounds are enforced."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
confidence: Literal["low", "medium", "high"] = Field(
|
||||||
|
description=(
|
||||||
|
"Confidence in the assessment based on data quality and sample size. "
|
||||||
|
"Use 'low' when one or more sources returned a placeholder or fewer "
|
||||||
|
"than 5 data points; 'medium' when data is present but sparse; "
|
||||||
|
"'high' when all three sources returned substantive data."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
narrative: str = Field(
|
||||||
|
description=(
|
||||||
|
"Full sentiment report covering, in order: "
|
||||||
|
"(1) source-by-source breakdown with specific evidence (cite message "
|
||||||
|
"counts, ratios, notable posts); "
|
||||||
|
"(2) cross-source divergences and alignments; "
|
||||||
|
"(3) dominant narrative themes; "
|
||||||
|
"(4) catalysts and risks surfaced by the data; "
|
||||||
|
"(5) a markdown table summarising key sentiment signals, their "
|
||||||
|
"direction, source, and supporting evidence. "
|
||||||
|
"Keep it informative and substantive: develop each section thoroughly "
|
||||||
|
"with concrete evidence so every point adds new signal for the trader."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_sentiment_report(report: SentimentReport) -> str:
|
||||||
|
"""Render a SentimentReport to the markdown shape the rest of the system expects.
|
||||||
|
|
||||||
|
The structured header (band + score + confidence) is prepended to the
|
||||||
|
narrative so the saved report is both human-readable and machine-parseable
|
||||||
|
without regex.
|
||||||
|
"""
|
||||||
|
return "\n".join([
|
||||||
|
f"**Overall Sentiment:** **{report.overall_band.value}** "
|
||||||
|
f"(Score: {report.overall_score:.1f}/10)",
|
||||||
|
f"**Confidence:** {report.confidence.capitalize()}",
|
||||||
|
"",
|
||||||
|
report.narrative,
|
||||||
|
])
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
|
|||||||
|
|
||||||
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
|
from tradingagents.agents.schemas import TraderProposal, render_trader_proposal
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.utils.structured import (
|
from tradingagents.agents.utils.structured import (
|
||||||
@@ -22,7 +22,7 @@ def create_trader(llm):
|
|||||||
|
|
||||||
def trader_node(state, name):
|
def trader_node(state, name):
|
||||||
company_name = state["company_of_interest"]
|
company_name = state["company_of_interest"]
|
||||||
instrument_context = build_instrument_context(company_name)
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
investment_plan = state["investment_plan"]
|
investment_plan = state["investment_plan"]
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from langgraph.graph import MessagesState
|
from langgraph.graph import MessagesState
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
@@ -45,6 +46,8 @@ class RiskDebateState(TypedDict):
|
|||||||
|
|
||||||
class AgentState(MessagesState):
|
class AgentState(MessagesState):
|
||||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
||||||
|
asset_type: Annotated[str, "Asset type under analysis such as stock or crypto"]
|
||||||
|
instrument_context: Annotated[str, "Deterministic ticker identity resolved at run start"]
|
||||||
trade_date: Annotated[str, "What date we are trading at"]
|
trade_date: Annotated[str, "What date we are trading at"]
|
||||||
|
|
||||||
sender: Annotated[str, "Agent that sent this message"]
|
sender: Annotated[str, "Agent that sent this message"]
|
||||||
|
|||||||
@@ -1,23 +1,52 @@
|
|||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yfinance as yf
|
||||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||||
|
|
||||||
# Import tools from separate utility files
|
# Import tools from separate utility files
|
||||||
from tradingagents.agents.utils.core_stock_tools import (
|
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||||
get_stock_data
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.technical_indicators_tools import (
|
|
||||||
get_indicators
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.fundamental_data_tools import (
|
from tradingagents.agents.utils.fundamental_data_tools import (
|
||||||
get_fundamentals,
|
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
get_cashflow,
|
get_cashflow,
|
||||||
get_income_statement
|
get_fundamentals,
|
||||||
|
get_income_statement,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.macro_data_tools import get_macro_indicators
|
||||||
|
from tradingagents.agents.utils.market_data_validation_tools import get_verified_market_snapshot
|
||||||
from tradingagents.agents.utils.news_data_tools import (
|
from tradingagents.agents.utils.news_data_tools import (
|
||||||
get_news,
|
get_global_news,
|
||||||
get_insider_transactions,
|
get_insider_transactions,
|
||||||
get_global_news
|
get_news,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.prediction_markets_tools import get_prediction_markets
|
||||||
|
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||||
|
|
||||||
|
# Public surface: the data tools are imported here so agents and the graph
|
||||||
|
# import them from one place, plus the instrument/language helpers defined below.
|
||||||
|
__all__ = [
|
||||||
|
"get_stock_data",
|
||||||
|
"get_indicators",
|
||||||
|
"get_fundamentals",
|
||||||
|
"get_balance_sheet",
|
||||||
|
"get_cashflow",
|
||||||
|
"get_income_statement",
|
||||||
|
"get_news",
|
||||||
|
"get_global_news",
|
||||||
|
"get_insider_transactions",
|
||||||
|
"get_macro_indicators",
|
||||||
|
"get_prediction_markets",
|
||||||
|
"get_verified_market_snapshot",
|
||||||
|
"build_instrument_context",
|
||||||
|
"resolve_instrument_identity",
|
||||||
|
"get_instrument_context_from_state",
|
||||||
|
"get_language_instruction",
|
||||||
|
"create_msg_delete",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_language_instruction() -> str:
|
def get_language_instruction() -> str:
|
||||||
@@ -36,28 +65,153 @@ def get_language_instruction() -> str:
|
|||||||
return f" Write your entire response in {lang}."
|
return f" Write your entire response in {lang}."
|
||||||
|
|
||||||
|
|
||||||
def build_instrument_context(ticker: str) -> str:
|
def _clean_identity_value(value: Any) -> str | None:
|
||||||
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
"""Return a trimmed string, or None for empty / placeholder-ish values."""
|
||||||
return (
|
if not isinstance(value, str):
|
||||||
f"The instrument to analyze is `{ticker}`. "
|
return None
|
||||||
"Use this exact ticker in every tool call, report, and recommendation, "
|
cleaned = value.strip()
|
||||||
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`)."
|
if not cleaned or cleaned.lower() in {"none", "n/a", "nan", "null"}:
|
||||||
|
return None
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=256)
|
||||||
|
def resolve_instrument_identity(ticker: str) -> dict:
|
||||||
|
"""Resolve deterministic identity metadata (company name, sector, …) for a ticker.
|
||||||
|
|
||||||
|
This exists to stop the pipeline from hallucinating a *different* company
|
||||||
|
when a chart pattern suggests a different industry than the real one
|
||||||
|
(#814): without a ground-truth name, the market analyst would pattern-match
|
||||||
|
the price action to a narrative and invent an identity that then cascaded
|
||||||
|
through every downstream agent.
|
||||||
|
|
||||||
|
Best-effort by design: if yfinance is unavailable, rate-limited, or doesn't
|
||||||
|
recognise the ticker, we return ``{}`` and the caller falls back to
|
||||||
|
ticker-only context rather than failing before analysis starts. Cached so
|
||||||
|
the lookup happens at most once per ticker per process.
|
||||||
|
|
||||||
|
The symbol is normalized first (e.g. ``XAUUSD`` -> ``GC=F``) so identity
|
||||||
|
resolves for the same instrument the price path actually fetches (#983).
|
||||||
|
"""
|
||||||
|
from tradingagents.dataflows.symbol_utils import normalize_symbol
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = yf.Ticker(normalize_symbol(ticker)).info or {}
|
||||||
|
except Exception as exc: # noqa: BLE001 — fail open, never block the run
|
||||||
|
logger.debug("Could not resolve instrument identity for %s: %s", ticker, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
identity: dict[str, str] = {}
|
||||||
|
company_name = _clean_identity_value(info.get("longName")) or _clean_identity_value(
|
||||||
|
info.get("shortName")
|
||||||
)
|
)
|
||||||
|
if company_name:
|
||||||
|
identity["company_name"] = company_name
|
||||||
|
for source_key, target_key in (
|
||||||
|
("sector", "sector"),
|
||||||
|
("industry", "industry"),
|
||||||
|
("exchange", "exchange"),
|
||||||
|
("quoteType", "quote_type"),
|
||||||
|
):
|
||||||
|
value = _clean_identity_value(info.get(source_key))
|
||||||
|
if value:
|
||||||
|
identity[target_key] = value
|
||||||
|
return identity
|
||||||
|
|
||||||
|
|
||||||
|
def build_instrument_context(
|
||||||
|
ticker: str,
|
||||||
|
asset_type: str = "stock",
|
||||||
|
identity: Mapping[str, str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Describe the exact instrument so agents preserve identity and ticker.
|
||||||
|
|
||||||
|
When ``identity`` is provided (resolved deterministically via
|
||||||
|
:func:`resolve_instrument_identity`), the company name and business
|
||||||
|
classification are injected so agents anchor to the real company rather
|
||||||
|
than pattern-matching the price chart to a wrong one (#814).
|
||||||
|
"""
|
||||||
|
is_crypto = asset_type == "crypto"
|
||||||
|
instrument_label = "asset" if is_crypto else "instrument"
|
||||||
|
context = (
|
||||||
|
f"The {instrument_label} to analyze is `{ticker}`. "
|
||||||
|
"Use this exact ticker in every tool call, report, and recommendation, "
|
||||||
|
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`, `-USD`)."
|
||||||
|
)
|
||||||
|
|
||||||
|
details = []
|
||||||
|
if identity:
|
||||||
|
name = identity.get("company_name") or identity.get("name")
|
||||||
|
if name:
|
||||||
|
details.append(f"{'Name' if is_crypto else 'Company'}: {name}")
|
||||||
|
sector, industry = identity.get("sector"), identity.get("industry")
|
||||||
|
if sector and industry:
|
||||||
|
details.append(f"Business classification: {sector} / {industry}")
|
||||||
|
elif sector:
|
||||||
|
details.append(f"Sector: {sector}")
|
||||||
|
elif industry:
|
||||||
|
details.append(f"Industry: {industry}")
|
||||||
|
if identity.get("exchange"):
|
||||||
|
details.append(f"Exchange: {identity['exchange']}")
|
||||||
|
|
||||||
|
if details:
|
||||||
|
context += (
|
||||||
|
f" Resolved identity: {'; '.join(details)}. "
|
||||||
|
"Do not substitute a different company or ticker unless a tool "
|
||||||
|
"result explicitly disproves this resolved identity."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_crypto:
|
||||||
|
context += (
|
||||||
|
" Treat it as a crypto asset rather than a company, and do not "
|
||||||
|
"assume company fundamentals are available."
|
||||||
|
)
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def get_instrument_context_from_state(state: Mapping[str, Any]) -> str:
|
||||||
|
"""Return the instrument context for the current run.
|
||||||
|
|
||||||
|
Prefers the identity-resolved context computed once at run start and
|
||||||
|
stored on the state (see ``TradingAgentsGraph.resolve_instrument_context``).
|
||||||
|
Falls back to a ticker-only context — with no network lookup — when the
|
||||||
|
state was constructed without it (bare programmatic states, tests), so a
|
||||||
|
consumer is never forced to make a yfinance call mid-graph.
|
||||||
|
"""
|
||||||
|
context = state.get("instrument_context")
|
||||||
|
if isinstance(context, str) and context.strip():
|
||||||
|
return context
|
||||||
|
return build_instrument_context(
|
||||||
|
str(state["company_of_interest"]),
|
||||||
|
state.get("asset_type", "stock"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_msg_delete():
|
def create_msg_delete():
|
||||||
def delete_messages(state):
|
def delete_messages(state):
|
||||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
"""Clear messages and add a context-anchored placeholder.
|
||||||
messages = state["messages"]
|
|
||||||
|
|
||||||
# Remove all messages
|
The placeholder must not be a bare ``"Continue"``: some
|
||||||
|
OpenAI-compatible providers interpret that literally as the user task
|
||||||
|
and produce output about the word "continue" instead of analysing the
|
||||||
|
instrument (#888). Anchoring it to the resolved instrument context and
|
||||||
|
date keeps the next analyst on-task even if the provider treats the
|
||||||
|
placeholder as a standalone request.
|
||||||
|
"""
|
||||||
|
messages = state["messages"]
|
||||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||||
|
|
||||||
# Add a minimal placeholder message
|
instrument_context = get_instrument_context_from_state(state)
|
||||||
placeholder = HumanMessage(content="Continue")
|
trade_date = state.get("trade_date", "the requested date")
|
||||||
|
placeholder = HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"Proceed with your assigned analysis for this workflow. "
|
||||||
|
f"{instrument_context} The analysis date is {trade_date}."
|
||||||
|
)
|
||||||
|
)
|
||||||
return {"messages": removal_operations + [placeholder]}
|
return {"messages": removal_operations + [placeholder]}
|
||||||
|
|
||||||
return delete_messages
|
return delete_messages
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
@@ -74,4 +76,4 @@ def get_income_statement(
|
|||||||
Returns:
|
Returns:
|
||||||
str: A formatted report containing income statement data
|
str: A formatted report containing income statement data
|
||||||
"""
|
"""
|
||||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||||
|
|||||||
36
tradingagents/agents/utils/macro_data_tools.py
Normal file
36
tradingagents/agents/utils/macro_data_tools.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_macro_indicators(
|
||||||
|
indicator: Annotated[
|
||||||
|
str,
|
||||||
|
"Macro indicator: a friendly alias such as 'cpi', 'core_pce', "
|
||||||
|
"'unemployment', 'fed_funds_rate', '10y_treasury', 'yield_curve', "
|
||||||
|
"'real_gdp', 'vix', or a raw FRED series ID such as 'CPIAUCSL'.",
|
||||||
|
],
|
||||||
|
curr_date: Annotated[str, "Current date in yyyy-mm-dd format; the end of the window"],
|
||||||
|
look_back_days: Annotated[
|
||||||
|
int | None, "Trailing window length in days; omit for a 1-year window"
|
||||||
|
] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve a macroeconomic indicator time series from FRED (Federal Reserve
|
||||||
|
Economic Data): policy rates, Treasury yields, inflation, labor, and growth.
|
||||||
|
Returns the series title, units, frequency, the latest value, the change
|
||||||
|
over the window, and a recent observation table. Uses the configured
|
||||||
|
macro_data vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator (str): Friendly alias or raw FRED series ID
|
||||||
|
curr_date (str): Current date in yyyy-mm-dd format
|
||||||
|
look_back_days (int): Trailing window length; omit for a 1-year window
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted markdown report of the macro series
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_macro_indicators", indicator, curr_date, look_back_days)
|
||||||
23
tradingagents/agents/utils/market_data_validation_tools.py
Normal file
23
tradingagents/agents/utils/market_data_validation_tools.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from tradingagents.dataflows.market_data_validator import build_verified_market_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_verified_market_snapshot(
|
||||||
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
curr_date: Annotated[str, "the current trading date, YYYY-mm-dd"],
|
||||||
|
look_back_days: Annotated[
|
||||||
|
int, "number of recent trading rows to include for sanity-checking"
|
||||||
|
] = 30,
|
||||||
|
) -> str:
|
||||||
|
"""Deterministic verification snapshot for exact market-data claims.
|
||||||
|
|
||||||
|
Returns the latest OHLCV row on or before curr_date, common technical
|
||||||
|
indicators, and recent closes. Call this before making exact claims about
|
||||||
|
price levels, Bollinger bands, RSI, MACD, moving averages, support /
|
||||||
|
resistance, or historical comparisons, and treat it as the source of truth.
|
||||||
|
"""
|
||||||
|
return build_verified_market_snapshot(symbol, curr_date, look_back_days)
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Append-only markdown decision log for TradingAgents."""
|
"""Append-only markdown decision log for TradingAgents."""
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
import re
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from tradingagents.agents.utils.rating import parse_rating
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ class TradingMemoryLog:
|
|||||||
|
|
||||||
# --- Read path (Phase A) ---
|
# --- Read path (Phase A) ---
|
||||||
|
|
||||||
def load_entries(self) -> List[dict]:
|
def load_entries(self) -> list[dict]:
|
||||||
"""Parse all entries from log. Returns list of dicts."""
|
"""Parse all entries from log. Returns list of dicts."""
|
||||||
if not self._log_path or not self._log_path.exists():
|
if not self._log_path or not self._log_path.exists():
|
||||||
return []
|
return []
|
||||||
@@ -64,7 +63,7 @@ class TradingMemoryLog:
|
|||||||
entries.append(parsed)
|
entries.append(parsed)
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
def get_pending_entries(self) -> List[dict]:
|
def get_pending_entries(self) -> list[dict]:
|
||||||
"""Return entries with outcome:pending (for Phase B)."""
|
"""Return entries with outcome:pending (for Phase B)."""
|
||||||
return [e for e in self.load_entries() if e.get("pending")]
|
return [e for e in self.load_entries() if e.get("pending")]
|
||||||
|
|
||||||
@@ -162,7 +161,7 @@ class TradingMemoryLog:
|
|||||||
tmp_path.write_text(new_text, encoding="utf-8")
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
tmp_path.replace(self._log_path)
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
|
def batch_update_with_outcomes(self, updates: list[dict]) -> None:
|
||||||
"""Apply multiple outcome updates in a single read + atomic write.
|
"""Apply multiple outcome updates in a single read + atomic write.
|
||||||
|
|
||||||
Each element of updates must have keys: ticker, trade_date,
|
Each element of updates must have keys: ticker, trade_date,
|
||||||
@@ -218,7 +217,7 @@ class TradingMemoryLog:
|
|||||||
|
|
||||||
# --- Helpers ---
|
# --- Helpers ---
|
||||||
|
|
||||||
def _apply_rotation(self, blocks: List[str]) -> List[str]:
|
def _apply_rotation(self, blocks: list[str]) -> list[str]:
|
||||||
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
||||||
|
|
||||||
Pending blocks are always kept (they represent unprocessed work).
|
Pending blocks are always kept (they represent unprocessed work).
|
||||||
@@ -247,7 +246,7 @@ class TradingMemoryLog:
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
to_drop = resolved_count - self._max_entries
|
to_drop = resolved_count - self._max_entries
|
||||||
kept: List[str] = []
|
kept: list[str] = []
|
||||||
for block, is_resolved in decisions:
|
for block, is_resolved in decisions:
|
||||||
if is_resolved and to_drop > 0:
|
if is_resolved and to_drop > 0:
|
||||||
to_drop -= 1
|
to_drop -= 1
|
||||||
@@ -255,7 +254,7 @@ class TradingMemoryLog:
|
|||||||
kept.append(block)
|
kept.append(block)
|
||||||
return kept
|
return kept
|
||||||
|
|
||||||
def _parse_entry(self, raw: str) -> Optional[dict]:
|
def _parse_entry(self, raw: str) -> dict | None:
|
||||||
lines = raw.strip().splitlines()
|
lines = raw.strip().splitlines()
|
||||||
if not lines:
|
if not lines:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from typing import Annotated, Optional
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_news(
|
def get_news(
|
||||||
ticker: Annotated[str, "Ticker symbol"],
|
ticker: Annotated[str, "Ticker symbol"],
|
||||||
@@ -23,8 +26,8 @@ def get_news(
|
|||||||
@tool
|
@tool
|
||||||
def get_global_news(
|
def get_global_news(
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||||
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
|
look_back_days: Annotated[int | None, "Days to look back; omit to use the configured default"] = None,
|
||||||
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
|
limit: Annotated[int | None, "Max articles to return; omit to use the configured default"] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve global news data.
|
Retrieve global news data.
|
||||||
|
|||||||
31
tradingagents/agents/utils/prediction_markets_tools.py
Normal file
31
tradingagents/agents/utils/prediction_markets_tools.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_prediction_markets(
|
||||||
|
topic: Annotated[
|
||||||
|
str,
|
||||||
|
"Event topic/keyword, e.g. 'Fed rate cut', 'recession 2026', "
|
||||||
|
"'US election', or a sector/company event.",
|
||||||
|
],
|
||||||
|
limit: Annotated[int | None, "Max markets to return; omit for a default of 6"] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve live, market-implied probabilities for forward-looking events from
|
||||||
|
prediction markets (Polymarket): Fed decisions, recession, elections,
|
||||||
|
geopolitics, crypto. Returns the most-traded open markets matching the
|
||||||
|
topic, each with its implied probability, traded volume, resolution date,
|
||||||
|
and recent move. Uses the configured prediction_markets vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic (str): Event keyword(s) to search
|
||||||
|
limit (int): Max markets to return; omit for a default of 6
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted markdown report of matching prediction markets
|
||||||
|
"""
|
||||||
|
return route_to_vendor("get_prediction_markets", topic, limit)
|
||||||
@@ -12,11 +12,9 @@ Centralising it here avoids drift between those call sites.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
|
|
||||||
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
||||||
RATINGS_5_TIER: Tuple[str, ...] = (
|
RATINGS_5_TIER: tuple[str, ...] = (
|
||||||
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ all three agents log the same warnings when fallback fires.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from collections.abc import Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
|
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Any | None:
|
||||||
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
||||||
|
|
||||||
Logs a warning when the binding fails so the user understands the agent
|
Logs a warning when the binding fails so the user understands the agent
|
||||||
@@ -46,7 +47,7 @@ def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]
|
|||||||
|
|
||||||
|
|
||||||
def invoke_structured_or_freetext(
|
def invoke_structured_or_freetext(
|
||||||
structured_llm: Optional[Any],
|
structured_llm: Any | None,
|
||||||
plain_llm: Any,
|
plain_llm: Any,
|
||||||
prompt: Any,
|
prompt: Any,
|
||||||
render: Callable[[T], str],
|
render: Callable[[T], str],
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_indicators(
|
def get_indicators(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
@@ -29,4 +32,4 @@ def get_indicators(
|
|||||||
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
results.append(str(e))
|
results.append(str(e))
|
||||||
return "\n\n".join(results)
|
return "\n\n".join(results)
|
||||||
|
|||||||
@@ -1,5 +1,23 @@
|
|||||||
# Import functions from specialized modules
|
# Aggregates the per-category Alpha Vantage implementations into one module the
|
||||||
from .alpha_vantage_stock import get_stock
|
# vendor router imports from; the imports below are the public surface.
|
||||||
|
from .alpha_vantage_fundamentals import (
|
||||||
|
get_balance_sheet,
|
||||||
|
get_cashflow,
|
||||||
|
get_fundamentals,
|
||||||
|
get_income_statement,
|
||||||
|
)
|
||||||
from .alpha_vantage_indicator import get_indicator
|
from .alpha_vantage_indicator import get_indicator
|
||||||
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
from .alpha_vantage_news import get_global_news, get_insider_transactions, get_news
|
||||||
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
|
from .alpha_vantage_stock import get_stock
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_balance_sheet",
|
||||||
|
"get_cashflow",
|
||||||
|
"get_fundamentals",
|
||||||
|
"get_income_statement",
|
||||||
|
"get_indicator",
|
||||||
|
"get_global_news",
|
||||||
|
"get_insider_transactions",
|
||||||
|
"get_news",
|
||||||
|
"get_stock",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,17 +1,37 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
import pandas as pd
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from .errors import VendorNotConfiguredError, VendorRateLimitError
|
||||||
|
|
||||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||||
|
|
||||||
|
# Network timeout (seconds) so a stalled Alpha Vantage request can't hang the
|
||||||
|
# CLI/agents indefinitely (#990).
|
||||||
|
REQUEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
|
class AlphaVantageNotConfiguredError(VendorNotConfiguredError):
|
||||||
|
"""Raised when Alpha Vantage is selected but no API key is configured.
|
||||||
|
|
||||||
|
A VendorNotConfiguredError (and thus still a ValueError), so the routing
|
||||||
|
layer's "vendor unavailable" handling and existing ValueError callers both
|
||||||
|
keep working.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
"""Retrieve the API key for Alpha Vantage from environment variables."""
|
"""Retrieve the API key for Alpha Vantage from environment variables."""
|
||||||
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
raise AlphaVantageNotConfiguredError(
|
||||||
|
"ALPHA_VANTAGE_API_KEY environment variable is not set."
|
||||||
|
)
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
def format_datetime_for_api(date_input) -> str:
|
def format_datetime_for_api(date_input) -> str:
|
||||||
@@ -29,19 +49,19 @@ def format_datetime_for_api(date_input) -> str:
|
|||||||
dt = datetime.strptime(date_input, "%Y-%m-%d %H:%M")
|
dt = datetime.strptime(date_input, "%Y-%m-%d %H:%M")
|
||||||
return dt.strftime("%Y%m%dT%H%M")
|
return dt.strftime("%Y%m%dT%H%M")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"Unsupported date format: {date_input}")
|
raise ValueError(f"Unsupported date format: {date_input}") from None
|
||||||
elif isinstance(date_input, datetime):
|
elif isinstance(date_input, datetime):
|
||||||
return date_input.strftime("%Y%m%dT%H%M")
|
return date_input.strftime("%Y%m%dT%H%M")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")
|
raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")
|
||||||
|
|
||||||
class AlphaVantageRateLimitError(Exception):
|
class AlphaVantageRateLimitError(VendorRateLimitError):
|
||||||
"""Exception raised when Alpha Vantage API rate limit is exceeded."""
|
"""Raised when the Alpha Vantage API rate limit is exceeded."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _make_api_request(function_name: str, params: dict) -> dict | str:
|
def _make_api_request(function_name: str, params: dict) -> dict | str:
|
||||||
"""Helper function to make API requests and handle responses.
|
"""Helper function to make API requests and handle responses.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AlphaVantageRateLimitError: When API rate limit is exceeded
|
AlphaVantageRateLimitError: When API rate limit is exceeded
|
||||||
"""
|
"""
|
||||||
@@ -52,33 +72,42 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
|
|||||||
"apikey": get_api_key(),
|
"apikey": get_api_key(),
|
||||||
"source": "trading_agents",
|
"source": "trading_agents",
|
||||||
})
|
})
|
||||||
|
|
||||||
# Handle entitlement parameter if present in params or global variable
|
# Handle entitlement parameter if present in params or global variable
|
||||||
current_entitlement = globals().get('_current_entitlement')
|
current_entitlement = globals().get('_current_entitlement')
|
||||||
entitlement = api_params.get("entitlement") or current_entitlement
|
entitlement = api_params.get("entitlement") or current_entitlement
|
||||||
|
|
||||||
if entitlement:
|
if entitlement:
|
||||||
api_params["entitlement"] = entitlement
|
api_params["entitlement"] = entitlement
|
||||||
elif "entitlement" in api_params:
|
elif "entitlement" in api_params:
|
||||||
# Remove entitlement if it's None or empty
|
# Remove entitlement if it's None or empty
|
||||||
api_params.pop("entitlement", None)
|
api_params.pop("entitlement", None)
|
||||||
|
|
||||||
response = requests.get(API_BASE_URL, params=api_params)
|
response = requests.get(API_BASE_URL, params=api_params, timeout=REQUEST_TIMEOUT)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
|
|
||||||
# Check if response is JSON (error responses are typically JSON)
|
# Error responses are JSON; data responses are usually CSV (or data-keyed
|
||||||
|
# JSON). A non-JSON body is normal data.
|
||||||
try:
|
try:
|
||||||
response_json = json.loads(response_text)
|
response_json = json.loads(response_text)
|
||||||
# Check for rate limit error
|
|
||||||
if "Information" in response_json:
|
|
||||||
info_message = response_json["Information"]
|
|
||||||
if "rate limit" in info_message.lower() or "api key" in info_message.lower():
|
|
||||||
raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}")
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Response is not JSON (likely CSV data), which is normal
|
return response_text
|
||||||
pass
|
|
||||||
|
# Alpha Vantage reports problems via "Information" / "Note". Classify so a
|
||||||
|
# genuine rate limit and an invalid/missing key aren't conflated (#991):
|
||||||
|
# rate-limit phrasing is checked first because those notices also mention
|
||||||
|
# "API key" ("your API key ... 25 requests per day").
|
||||||
|
notice = response_json.get("Information") or response_json.get("Note")
|
||||||
|
if notice:
|
||||||
|
low = notice.lower()
|
||||||
|
if any(m in low for m in ("rate limit", "requests per day", "call frequency", "premium")):
|
||||||
|
raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {notice}")
|
||||||
|
if "api key" in low or "apikey" in low:
|
||||||
|
# Reuse the existing "not configured" error so a bad key surfaces as
|
||||||
|
# a real, actionable failure rather than a mislabeled rate limit (#991).
|
||||||
|
raise AlphaVantageNotConfiguredError(f"Alpha Vantage API key invalid or missing: {notice}")
|
||||||
|
|
||||||
return response_text
|
return response_text
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .alpha_vantage_common import _make_api_request
|
from .alpha_vantage_common import AlphaVantageNotConfiguredError, _make_api_request
|
||||||
|
|
||||||
|
|
||||||
def get_indicator(
|
def get_indicator(
|
||||||
symbol: str,
|
symbol: str,
|
||||||
@@ -25,6 +26,7 @@ def get_indicator(
|
|||||||
String containing indicator values and description
|
String containing indicator values and description
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
supported_indicators = {
|
supported_indicators = {
|
||||||
@@ -98,21 +100,7 @@ def get_indicator(
|
|||||||
"series_type": series_type,
|
"series_type": series_type,
|
||||||
"datatype": "csv"
|
"datatype": "csv"
|
||||||
})
|
})
|
||||||
elif indicator == "macd":
|
elif indicator == "macd" or indicator == "macds" or indicator == "macdh":
|
||||||
data = _make_api_request("MACD", {
|
|
||||||
"symbol": symbol,
|
|
||||||
"interval": interval,
|
|
||||||
"series_type": series_type,
|
|
||||||
"datatype": "csv"
|
|
||||||
})
|
|
||||||
elif indicator == "macds":
|
|
||||||
data = _make_api_request("MACD", {
|
|
||||||
"symbol": symbol,
|
|
||||||
"interval": interval,
|
|
||||||
"series_type": series_type,
|
|
||||||
"datatype": "csv"
|
|
||||||
})
|
|
||||||
elif indicator == "macdh":
|
|
||||||
data = _make_api_request("MACD", {
|
data = _make_api_request("MACD", {
|
||||||
"symbol": symbol,
|
"symbol": symbol,
|
||||||
"interval": interval,
|
"interval": interval,
|
||||||
@@ -217,6 +205,11 @@ def get_indicator(
|
|||||||
|
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
|
except AlphaVantageNotConfiguredError:
|
||||||
|
# Vendor unavailable (no API key). Let it propagate so the router can
|
||||||
|
# fall back / emit the no-data sentinel instead of returning this as a
|
||||||
|
# successful-looking error string.
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
|
print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
|
||||||
return f"Error retrieving {indicator} data: {str(e)}"
|
return f"Error retrieving {indicator} data: {str(e)}"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||||
|
|
||||||
|
|
||||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||||
|
|
||||||
@@ -68,4 +69,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
|||||||
"symbol": symbol,
|
"symbol": symbol,
|
||||||
}
|
}
|
||||||
|
|
||||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
|
||||||
|
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
|
||||||
|
|
||||||
|
|
||||||
def get_stock(
|
def get_stock(
|
||||||
symbol: str,
|
symbol: str,
|
||||||
@@ -35,4 +37,4 @@ def get_stock(
|
|||||||
|
|
||||||
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||||
|
|
||||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import tradingagents.default_config as default_config
|
import tradingagents.default_config as default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
@@ -14,7 +13,7 @@ def initialize_config():
|
|||||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: dict):
|
||||||
"""Update the configuration with custom values.
|
"""Update the configuration with custom values.
|
||||||
|
|
||||||
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||||
@@ -31,7 +30,7 @@ def set_config(config: Dict):
|
|||||||
_config[key] = value
|
_config[key] = value
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
|
|||||||
55
tradingagents/dataflows/errors.py
Normal file
55
tradingagents/dataflows/errors.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Vendor data-error taxonomy.
|
||||||
|
|
||||||
|
A single hierarchy so the routing layer reacts by *behavior*, not by vendor:
|
||||||
|
every condition where a vendor cannot return usable data derives from
|
||||||
|
``VendorError``, and the router catches the base types. A new vendor raises
|
||||||
|
these (or a thin vendor-named subclass) and needs no new ``except`` clause.
|
||||||
|
|
||||||
|
VendorError
|
||||||
|
├── NoMarketDataError no usable rows (empty result OR stale data)
|
||||||
|
├── VendorRateLimitError transient throttle -> skip to next vendor
|
||||||
|
└── VendorNotConfiguredError missing API key/config -> vendor unavailable
|
||||||
|
|
||||||
|
The number of types is the number of distinct router reactions, not the number
|
||||||
|
of human-describable causes: empty and stale data get identical handling, so
|
||||||
|
they share ``NoMarketDataError`` and differ only in the free-text ``detail``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class VendorError(Exception):
|
||||||
|
"""Base for any condition where a vendor could not return usable data."""
|
||||||
|
|
||||||
|
|
||||||
|
class NoMarketDataError(VendorError):
|
||||||
|
"""A vendor returned no usable rows for a symbol (empty result or stale data).
|
||||||
|
|
||||||
|
Carries both the symbol the user requested and the canonical symbol the
|
||||||
|
vendor was actually queried with, plus a free-text ``detail``, so callers
|
||||||
|
can build a clear message instead of emitting a vendor-specific empty
|
||||||
|
string into the data channel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, symbol: str, canonical: str | None = None, detail: str = ""):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.canonical = canonical or symbol
|
||||||
|
self.detail = detail
|
||||||
|
msg = f"No market data for {symbol!r}"
|
||||||
|
if canonical and canonical != symbol:
|
||||||
|
msg += f" (queried as {canonical!r})"
|
||||||
|
if detail:
|
||||||
|
msg += f": {detail}"
|
||||||
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
|
class VendorRateLimitError(VendorError):
|
||||||
|
"""A vendor throttled the request; the router skips to the next vendor."""
|
||||||
|
|
||||||
|
|
||||||
|
class VendorNotConfiguredError(VendorError, ValueError):
|
||||||
|
"""A vendor was selected but its API key/configuration is missing.
|
||||||
|
|
||||||
|
Also a ``ValueError`` so existing callers that catch ``ValueError`` keep
|
||||||
|
working while the routing layer can treat it as "vendor unavailable".
|
||||||
|
"""
|
||||||
217
tradingagents/dataflows/fred.py
Normal file
217
tradingagents/dataflows/fred.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""FRED (Federal Reserve Economic Data) macro vendor.
|
||||||
|
|
||||||
|
Fetches macroeconomic time series — policy rates, Treasury yields, inflation,
|
||||||
|
labor, growth — from the St. Louis Fed's free API. Used by the news analyst to
|
||||||
|
ground macro commentary in actual numbers rather than headlines alone.
|
||||||
|
|
||||||
|
A free API key (https://fred.stlouisfed.org/docs/api/api_key.html) is read from
|
||||||
|
``FRED_API_KEY``; if it is unset the vendor raises ``FredNotConfiguredError`` so
|
||||||
|
the routing layer treats it as "unavailable" rather than a hard crash.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from .errors import VendorNotConfiguredError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FRED_API_BASE = "https://api.stlouisfed.org/fred"
|
||||||
|
|
||||||
|
# Network timeout (seconds) so a stalled request can't hang the agents,
|
||||||
|
# mirroring the Alpha Vantage client.
|
||||||
|
REQUEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
# Default trailing window when the caller does not specify one. A year captures
|
||||||
|
# the trend and the year-over-year base for most monthly/quarterly series.
|
||||||
|
DEFAULT_LOOKBACK_DAYS = 365
|
||||||
|
|
||||||
|
# Rows cap for the rendered table: recent values matter most for a decision, and
|
||||||
|
# daily series (yields, VIX) over a long window would otherwise flood context.
|
||||||
|
MAX_ROWS = 40
|
||||||
|
|
||||||
|
# Curated human-friendly aliases -> FRED series IDs. Anything not listed is used
|
||||||
|
# verbatim as a raw FRED series ID, so power users are never limited to this set.
|
||||||
|
MACRO_SERIES = {
|
||||||
|
# Policy rate & Treasury yields
|
||||||
|
"fed_funds_rate": "FEDFUNDS",
|
||||||
|
"federal_funds_rate": "FEDFUNDS",
|
||||||
|
"fed_funds": "FEDFUNDS",
|
||||||
|
"2y_treasury": "DGS2",
|
||||||
|
"10y_treasury": "DGS10",
|
||||||
|
"30y_treasury": "DGS30",
|
||||||
|
"10y_2y_spread": "T10Y2Y",
|
||||||
|
"yield_curve": "T10Y2Y",
|
||||||
|
# Inflation
|
||||||
|
"cpi": "CPIAUCSL",
|
||||||
|
"core_cpi": "CPILFESL",
|
||||||
|
"pce": "PCEPI",
|
||||||
|
"core_pce": "PCEPILFE",
|
||||||
|
"inflation_expectations": "T10YIE",
|
||||||
|
# Growth & output
|
||||||
|
"real_gdp": "GDPC1",
|
||||||
|
"gdp": "GDP",
|
||||||
|
"industrial_production": "INDPRO",
|
||||||
|
# Labor
|
||||||
|
"unemployment_rate": "UNRATE",
|
||||||
|
"unemployment": "UNRATE",
|
||||||
|
"nonfarm_payrolls": "PAYEMS",
|
||||||
|
"payrolls": "PAYEMS",
|
||||||
|
"initial_claims": "ICSA",
|
||||||
|
# Money & markets
|
||||||
|
"m2": "M2SL",
|
||||||
|
"money_supply": "M2SL",
|
||||||
|
"vix": "VIXCLS",
|
||||||
|
"dollar_index": "DTWEXBGS",
|
||||||
|
# Sentiment & housing
|
||||||
|
"consumer_sentiment": "UMCSENT",
|
||||||
|
"housing_starts": "HOUST",
|
||||||
|
"retail_sales": "RSAFS",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FredNotConfiguredError(VendorNotConfiguredError):
|
||||||
|
"""Raised when FRED is selected but no API key is configured.
|
||||||
|
|
||||||
|
A VendorNotConfiguredError (and thus still a ValueError), so the routing
|
||||||
|
layer's "vendor unavailable" handling and existing ValueError callers both
|
||||||
|
keep working.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_key() -> str:
|
||||||
|
"""Retrieve the FRED API key from the environment."""
|
||||||
|
api_key = os.getenv("FRED_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise FredNotConfiguredError(
|
||||||
|
"FRED_API_KEY environment variable is not set. Get a free key at "
|
||||||
|
"https://fred.stlouisfed.org/docs/api/api_key.html."
|
||||||
|
)
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_series_id(indicator: str) -> str:
|
||||||
|
"""Map a friendly alias to a FRED series ID, or pass a raw ID through."""
|
||||||
|
key = indicator.strip().lower().replace(" ", "_").replace("-", "_")
|
||||||
|
if key in MACRO_SERIES:
|
||||||
|
return MACRO_SERIES[key]
|
||||||
|
# Not a known alias: treat the input as a raw FRED series ID (FRED IDs are
|
||||||
|
# conventionally uppercase, e.g. "DGS10", "CPIAUCSL").
|
||||||
|
return indicator.strip().upper()
|
||||||
|
|
||||||
|
|
||||||
|
def _request(path: str, params: dict) -> dict:
|
||||||
|
"""GET a FRED endpoint, surfacing FRED's JSON error body on a bad request."""
|
||||||
|
api_params = {**params, "api_key": get_api_key(), "file_type": "json"}
|
||||||
|
response = requests.get(
|
||||||
|
f"{FRED_API_BASE}/{path}", params=api_params, timeout=REQUEST_TIMEOUT
|
||||||
|
)
|
||||||
|
# FRED returns 400 with a JSON {"error_message": ...} for unknown series IDs
|
||||||
|
# or malformed params; turn that into a clear, actionable error.
|
||||||
|
if response.status_code == 400:
|
||||||
|
try:
|
||||||
|
message = response.json().get("error_message", response.text)
|
||||||
|
except ValueError:
|
||||||
|
message = response.text
|
||||||
|
raise ValueError(f"FRED request failed: {message}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def get_macro_data(
|
||||||
|
indicator: str,
|
||||||
|
curr_date: str,
|
||||||
|
look_back_days: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch a FRED macroeconomic series as a formatted markdown report.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator: A friendly alias (e.g. "cpi", "unemployment", "10y_treasury")
|
||||||
|
or a raw FRED series ID (e.g. "CPIAUCSL", "DGS10").
|
||||||
|
curr_date: End of the window (yyyy-mm-dd); no later observations are
|
||||||
|
returned, so a past date never leaks future data.
|
||||||
|
look_back_days: Trailing window length; ``None`` uses DEFAULT_LOOKBACK_DAYS.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A markdown report with the series title, units, frequency, the latest
|
||||||
|
value, the change over the window, and a recent observation table.
|
||||||
|
"""
|
||||||
|
if look_back_days is None:
|
||||||
|
look_back_days = DEFAULT_LOOKBACK_DAYS
|
||||||
|
|
||||||
|
end_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
start_date = (end_dt - timedelta(days=look_back_days)).strftime("%Y-%m-%d")
|
||||||
|
series_id = _resolve_series_id(indicator)
|
||||||
|
|
||||||
|
meta = _request("series", {"series_id": series_id}).get("seriess") or []
|
||||||
|
if not meta:
|
||||||
|
raise ValueError(
|
||||||
|
f"FRED series '{series_id}' not found. Pass a known alias "
|
||||||
|
f"(e.g. 'cpi', 'unemployment') or a valid FRED series ID."
|
||||||
|
)
|
||||||
|
info = meta[0]
|
||||||
|
title = info.get("title", series_id)
|
||||||
|
units = info.get("units_short") or info.get("units", "")
|
||||||
|
frequency = info.get("frequency", "")
|
||||||
|
seasonal = info.get("seasonal_adjustment_short", "")
|
||||||
|
|
||||||
|
observations = _request(
|
||||||
|
"series/observations",
|
||||||
|
{
|
||||||
|
"series_id": series_id,
|
||||||
|
"observation_start": start_date,
|
||||||
|
"observation_end": curr_date,
|
||||||
|
"sort_order": "asc",
|
||||||
|
},
|
||||||
|
).get("observations", [])
|
||||||
|
|
||||||
|
# FRED encodes a missing observation as ".".
|
||||||
|
points = [
|
||||||
|
(o["date"], o["value"])
|
||||||
|
for o in observations
|
||||||
|
if o.get("value") not in (".", None, "")
|
||||||
|
]
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f"## FRED: {title} ({series_id})\n"
|
||||||
|
f"- Units: {units}\n"
|
||||||
|
f"- Frequency: {frequency}"
|
||||||
|
f"{f' ({seasonal})' if seasonal else ''}\n"
|
||||||
|
f"- Window: {start_date} to {curr_date}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not points:
|
||||||
|
return header + (
|
||||||
|
f"\nNo observations for {series_id} in this window. The series may "
|
||||||
|
f"report less frequently than the window length; widen look_back_days."
|
||||||
|
)
|
||||||
|
|
||||||
|
first_date, first_val = points[0]
|
||||||
|
last_date, last_val = points[-1]
|
||||||
|
try:
|
||||||
|
delta = float(last_val) - float(first_val)
|
||||||
|
base = float(first_val)
|
||||||
|
pct = f" ({delta / base * 100:+.2f}%)" if base != 0 else ""
|
||||||
|
summary = (
|
||||||
|
f"\n**Latest:** {last_val} ({last_date}) | "
|
||||||
|
f"**Change over window:** {delta:+.2f}{pct} "
|
||||||
|
f"from {first_val} ({first_date})\n"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
summary = f"\n**Latest:** {last_val} ({last_date})\n"
|
||||||
|
|
||||||
|
shown = points
|
||||||
|
note = ""
|
||||||
|
if len(points) > MAX_ROWS:
|
||||||
|
shown = points[-MAX_ROWS:]
|
||||||
|
note = f"\n_(showing the most recent {MAX_ROWS} of {len(points)} observations)_\n"
|
||||||
|
|
||||||
|
table = (
|
||||||
|
"\n| Date | Value |\n| --- | --- |\n"
|
||||||
|
+ "\n".join(f"| {d} | {v} |" for d, v in shown)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return header + summary + note + table
|
||||||
@@ -1,31 +1,36 @@
|
|||||||
from typing import Annotated
|
import logging
|
||||||
|
|
||||||
# Import from vendor-specific modules
|
|
||||||
from .y_finance import (
|
|
||||||
get_YFin_data_online,
|
|
||||||
get_stock_stats_indicators_window,
|
|
||||||
get_fundamentals as get_yfinance_fundamentals,
|
|
||||||
get_balance_sheet as get_yfinance_balance_sheet,
|
|
||||||
get_cashflow as get_yfinance_cashflow,
|
|
||||||
get_income_statement as get_yfinance_income_statement,
|
|
||||||
get_insider_transactions as get_yfinance_insider_transactions,
|
|
||||||
)
|
|
||||||
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
|
||||||
from .alpha_vantage import (
|
from .alpha_vantage import (
|
||||||
get_stock as get_alpha_vantage_stock,
|
|
||||||
get_indicator as get_alpha_vantage_indicator,
|
|
||||||
get_fundamentals as get_alpha_vantage_fundamentals,
|
|
||||||
get_balance_sheet as get_alpha_vantage_balance_sheet,
|
get_balance_sheet as get_alpha_vantage_balance_sheet,
|
||||||
get_cashflow as get_alpha_vantage_cashflow,
|
get_cashflow as get_alpha_vantage_cashflow,
|
||||||
|
get_fundamentals as get_alpha_vantage_fundamentals,
|
||||||
|
get_global_news as get_alpha_vantage_global_news,
|
||||||
get_income_statement as get_alpha_vantage_income_statement,
|
get_income_statement as get_alpha_vantage_income_statement,
|
||||||
|
get_indicator as get_alpha_vantage_indicator,
|
||||||
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
||||||
get_news as get_alpha_vantage_news,
|
get_news as get_alpha_vantage_news,
|
||||||
get_global_news as get_alpha_vantage_global_news,
|
get_stock as get_alpha_vantage_stock,
|
||||||
)
|
)
|
||||||
from .alpha_vantage_common import AlphaVantageRateLimitError
|
|
||||||
|
|
||||||
# Configuration and routing logic
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
from .errors import (
|
||||||
|
NoMarketDataError,
|
||||||
|
VendorNotConfiguredError,
|
||||||
|
VendorRateLimitError,
|
||||||
|
)
|
||||||
|
from .fred import get_macro_data as get_fred_macro_data
|
||||||
|
from .polymarket import get_prediction_markets as get_polymarket_prediction_markets
|
||||||
|
from .y_finance import (
|
||||||
|
get_balance_sheet as get_yfinance_balance_sheet,
|
||||||
|
get_cashflow as get_yfinance_cashflow,
|
||||||
|
get_fundamentals as get_yfinance_fundamentals,
|
||||||
|
get_income_statement as get_yfinance_income_statement,
|
||||||
|
get_insider_transactions as get_yfinance_insider_transactions,
|
||||||
|
get_stock_stats_indicators_window,
|
||||||
|
get_YFin_data_online,
|
||||||
|
)
|
||||||
|
from .yfinance_news import get_global_news_yfinance, get_news_yfinance
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Tools organized by category
|
# Tools organized by category
|
||||||
TOOLS_CATEGORIES = {
|
TOOLS_CATEGORIES = {
|
||||||
@@ -57,11 +62,25 @@ TOOLS_CATEGORIES = {
|
|||||||
"get_global_news",
|
"get_global_news",
|
||||||
"get_insider_transactions",
|
"get_insider_transactions",
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"macro_data": {
|
||||||
|
"description": "Macroeconomic indicators (rates, inflation, labor, growth)",
|
||||||
|
"tools": [
|
||||||
|
"get_macro_indicators",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"prediction_markets": {
|
||||||
|
"description": "Market-implied probabilities for forward-looking events",
|
||||||
|
"tools": [
|
||||||
|
"get_prediction_markets",
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VENDOR_LIST = [
|
VENDOR_LIST = [
|
||||||
"yfinance",
|
"yfinance",
|
||||||
|
"fred",
|
||||||
|
"polymarket",
|
||||||
"alpha_vantage",
|
"alpha_vantage",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -107,6 +126,14 @@ VENDOR_METHODS = {
|
|||||||
"alpha_vantage": get_alpha_vantage_insider_transactions,
|
"alpha_vantage": get_alpha_vantage_insider_transactions,
|
||||||
"yfinance": get_yfinance_insider_transactions,
|
"yfinance": get_yfinance_insider_transactions,
|
||||||
},
|
},
|
||||||
|
# macro_data
|
||||||
|
"get_macro_indicators": {
|
||||||
|
"fred": get_fred_macro_data,
|
||||||
|
},
|
||||||
|
# prediction_markets
|
||||||
|
"get_prediction_markets": {
|
||||||
|
"polymarket": get_polymarket_prediction_markets,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_category_for_method(method: str) -> str:
|
def get_category_for_method(method: str) -> str:
|
||||||
@@ -140,23 +167,81 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||||||
if method not in VENDOR_METHODS:
|
if method not in VENDOR_METHODS:
|
||||||
raise ValueError(f"Method '{method}' not supported")
|
raise ValueError(f"Method '{method}' not supported")
|
||||||
|
|
||||||
# Build fallback chain: primary vendors first, then remaining available vendors
|
|
||||||
all_available_vendors = list(VENDOR_METHODS[method].keys())
|
all_available_vendors = list(VENDOR_METHODS[method].keys())
|
||||||
fallback_vendors = primary_vendors.copy()
|
|
||||||
for vendor in all_available_vendors:
|
|
||||||
if vendor not in fallback_vendors:
|
|
||||||
fallback_vendors.append(vendor)
|
|
||||||
|
|
||||||
for vendor in fallback_vendors:
|
# The configured vendor list IS the chain: we do NOT silently fall back to
|
||||||
if vendor not in VENDOR_METHODS[method]:
|
# vendors the user did not choose (#988/#289) — that returned data from an
|
||||||
continue
|
# unexpected source and caused cross-vendor inconsistencies. For multi-vendor
|
||||||
|
# fallback, list them in order, e.g. data_vendors="yfinance,alpha_vantage".
|
||||||
|
# The "default" sentinel (no explicit config) uses all available vendors.
|
||||||
|
explicit = [v for v in primary_vendors if v and v != "default"]
|
||||||
|
if explicit:
|
||||||
|
vendor_chain = [v for v in explicit if v in VENDOR_METHODS[method]]
|
||||||
|
if not vendor_chain:
|
||||||
|
raise ValueError(
|
||||||
|
f"Configured vendor(s) {explicit} not available for '{method}'. "
|
||||||
|
f"Available: {all_available_vendors}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vendor_chain = all_available_vendors
|
||||||
|
|
||||||
|
last_no_data: NoMarketDataError | None = None
|
||||||
|
first_error: Exception | None = None
|
||||||
|
for vendor in vendor_chain:
|
||||||
vendor_impl = VENDOR_METHODS[method][vendor]
|
vendor_impl = VENDOR_METHODS[method][vendor]
|
||||||
impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
|
impl_func = vendor_impl[0] if isinstance(vendor_impl, list) else vendor_impl
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return impl_func(*args, **kwargs)
|
return impl_func(*args, **kwargs)
|
||||||
except AlphaVantageRateLimitError:
|
except VendorRateLimitError:
|
||||||
continue # Only rate limits trigger fallback
|
logger.warning("Vendor %r rate-limited for %s; trying next vendor.", vendor, method)
|
||||||
|
continue
|
||||||
|
except VendorNotConfiguredError as e:
|
||||||
|
logger.warning("Vendor %r not configured for %s; trying next vendor.", vendor, method)
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e # Surface it if no other vendor can serve the call.
|
||||||
|
continue
|
||||||
|
except NoMarketDataError as e:
|
||||||
|
last_no_data = e # No data here; another configured vendor may have it
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# Don't let one vendor's failure crash the call when another can
|
||||||
|
# serve it, but never swallow silently: a broken primary must be
|
||||||
|
# visible in the logs (#989), not hidden behind a fallback's verdict.
|
||||||
|
logger.warning("Vendor %r failed for %s: %s", vendor, method, e)
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
continue
|
||||||
|
|
||||||
raise RuntimeError(f"No available vendor for '{method}'")
|
# If any vendor reported "no data", the symbol is genuinely unavailable.
|
||||||
|
# Return one explicit, instructive sentinel rather than a vendor-specific
|
||||||
|
# empty string, so the agent reports "unavailable" instead of inventing a
|
||||||
|
# value. This takes precedence over incidental fallback errors.
|
||||||
|
if last_no_data is not None:
|
||||||
|
if first_error is not None:
|
||||||
|
# A vendor also hit a real error; surface it in logs so the no-data
|
||||||
|
# verdict can't hide a broken primary (network/auth/etc.).
|
||||||
|
logger.warning(
|
||||||
|
"Returning NO_DATA for %s, but a vendor errored earlier: %s",
|
||||||
|
method, first_error,
|
||||||
|
)
|
||||||
|
sym = last_no_data.symbol
|
||||||
|
canonical = last_no_data.canonical
|
||||||
|
resolved = "" if canonical == sym else f" (resolved to '{canonical}')"
|
||||||
|
# Surface the typed error's detail (e.g. "latest row is 2025-06-11 ...
|
||||||
|
# stale") so the agent sees the specific reason — invalid symbol, no
|
||||||
|
# coverage, or stale data — not just a generic "unavailable".
|
||||||
|
reason = f" ({last_no_data.detail})" if last_no_data.detail else ""
|
||||||
|
return (
|
||||||
|
f"NO_DATA_AVAILABLE: No usable market data for '{sym}'{resolved} from "
|
||||||
|
f"any configured vendor{reason}. The symbol may be invalid, delisted, "
|
||||||
|
f"not covered, or the vendor returned stale data. Do not estimate or "
|
||||||
|
f"fabricate values — report that data is unavailable for this symbol."
|
||||||
|
)
|
||||||
|
|
||||||
|
# No vendor returned data and none reported clean "no data" — surface the
|
||||||
|
# first real error (e.g. the primary vendor's network failure).
|
||||||
|
if first_error is not None:
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
raise RuntimeError(f"No available vendor for '{method}'")
|
||||||
|
|||||||
123
tradingagents/dataflows/market_data_validator.py
Normal file
123
tradingagents/dataflows/market_data_validator.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Deterministic market-data verification snapshot.
|
||||||
|
|
||||||
|
The market analyst is an LLM that can confabulate exact numbers — citing a
|
||||||
|
Bollinger band or a "historically validated bounce" that the underlying data
|
||||||
|
doesn't support (#830). This module computes a ground-truth snapshot (latest
|
||||||
|
OHLCV row on or before the analysis date, common indicators, recent closes)
|
||||||
|
the analyst is told to treat as the source of truth for any exact numeric
|
||||||
|
claim. Deterministic, no LLM involved.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from stockstats import wrap
|
||||||
|
|
||||||
|
from tradingagents.dataflows.stockstats_utils import load_ohlcv
|
||||||
|
|
||||||
|
# A fixed, common indicator set so the snapshot is the same shape every run.
|
||||||
|
DEFAULT_SNAPSHOT_INDICATORS: tuple[str, ...] = (
|
||||||
|
"close_10_ema", "close_50_sma", "close_200_sma",
|
||||||
|
"rsi", "boll", "boll_ub", "boll_lb",
|
||||||
|
"macd", "macds", "macdh", "atr",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _verified_rows(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||||
|
"""OHLCV on or before curr_date, date-sorted. Raises if nothing usable.
|
||||||
|
|
||||||
|
``load_ohlcv`` already normalizes the Date column and filters out
|
||||||
|
look-ahead rows, but we re-apply the cutoff defensively — this is a
|
||||||
|
verification path, so it must not trust its input to be pre-filtered.
|
||||||
|
"""
|
||||||
|
data = load_ohlcv(symbol, curr_date)
|
||||||
|
if data is None or data.empty:
|
||||||
|
raise ValueError(f"No OHLCV data available for {symbol}.")
|
||||||
|
|
||||||
|
df = data.copy()
|
||||||
|
df["Date"] = pd.to_datetime(df["Date"], errors="coerce")
|
||||||
|
df = df.dropna(subset=["Date"])
|
||||||
|
df = df[df["Date"] <= pd.to_datetime(curr_date)].sort_values("Date")
|
||||||
|
if df.empty:
|
||||||
|
raise ValueError(f"No OHLCV rows on or before {curr_date} for {symbol}.")
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt(value) -> str:
|
||||||
|
if value is None or pd.isna(value):
|
||||||
|
return "N/A"
|
||||||
|
if isinstance(value, pd.Timestamp):
|
||||||
|
return value.strftime("%Y-%m-%d")
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return str(value)
|
||||||
|
if isinstance(value, (int,)):
|
||||||
|
return str(value)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return f"{value:.2f}"
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def build_verified_market_snapshot(
|
||||||
|
symbol: str,
|
||||||
|
curr_date: str,
|
||||||
|
look_back_days: int = 30,
|
||||||
|
indicators: Iterable[str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes."""
|
||||||
|
# `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/
|
||||||
|
# Volume); stockstats `wrap()` lowercases columns and adds indicator
|
||||||
|
# columns, so read raw prices from `df` and indicators from `stock_df`.
|
||||||
|
df = _verified_rows(symbol, curr_date)
|
||||||
|
stock_df = wrap(df.copy())
|
||||||
|
|
||||||
|
selected = tuple(indicators or DEFAULT_SNAPSHOT_INDICATORS)
|
||||||
|
indicator_values: dict[str, str] = {}
|
||||||
|
for name in selected:
|
||||||
|
try:
|
||||||
|
stock_df[name] # triggers stockstats calculation
|
||||||
|
indicator_values[name] = _fmt(stock_df.iloc[-1][name])
|
||||||
|
except Exception as exc: # noqa: BLE001 — one bad indicator shouldn't sink the snapshot
|
||||||
|
indicator_values[name] = f"N/A ({type(exc).__name__})"
|
||||||
|
|
||||||
|
latest = df.iloc[-1]
|
||||||
|
latest_date = _fmt(latest["Date"])
|
||||||
|
window = max(1, min(int(look_back_days), 30))
|
||||||
|
recent = df.tail(window)
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"## Verified market data snapshot for {symbol.upper()}",
|
||||||
|
"",
|
||||||
|
f"- Requested analysis date: {curr_date}",
|
||||||
|
f"- Latest trading row used: {latest_date}",
|
||||||
|
"- Rows after the requested analysis date are excluded before verification.",
|
||||||
|
"",
|
||||||
|
"### Latest verified OHLCV row",
|
||||||
|
"",
|
||||||
|
"| Field | Value |",
|
||||||
|
"|---|---:|",
|
||||||
|
]
|
||||||
|
for field in ("Open", "High", "Low", "Close", "Volume"):
|
||||||
|
lines.append(f"| {field} | {_fmt(latest.get(field))} |")
|
||||||
|
|
||||||
|
lines += ["", "### Verified technical indicators (latest row)", "",
|
||||||
|
"| Indicator | Value |", "|---|---:|"]
|
||||||
|
for name, value in indicator_values.items():
|
||||||
|
lines.append(f"| {name} | {value} |")
|
||||||
|
|
||||||
|
lines += ["", f"### Recent verified closes (last {len(recent)} rows)", "",
|
||||||
|
"| Date | Close |", "|---|---:|"]
|
||||||
|
for _, row in recent.iterrows():
|
||||||
|
lines.append(f"| {_fmt(row['Date'])} | {_fmt(row.get('Close'))} |")
|
||||||
|
|
||||||
|
lines += [
|
||||||
|
"",
|
||||||
|
"Use this snapshot as the source of truth for exact OHLCV, price-level, "
|
||||||
|
"and indicator-value claims. If another tool output conflicts with it, "
|
||||||
|
"flag the discrepancy rather than inventing a reconciled number. Do not "
|
||||||
|
"claim historical validation, support/resistance bounces, or exact "
|
||||||
|
"percentage moves unless directly supported by tool output with concrete "
|
||||||
|
"dates and prices.",
|
||||||
|
]
|
||||||
|
return "\n".join(lines)
|
||||||
139
tradingagents/dataflows/polymarket.py
Normal file
139
tradingagents/dataflows/polymarket.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Polymarket prediction-market vendor.
|
||||||
|
|
||||||
|
Surfaces live, market-implied probabilities for forward-looking events (Fed
|
||||||
|
decisions, recession, elections, geopolitics, crypto) to the news analyst, as a
|
||||||
|
complement to news (what happened) and FRED macro data (where things stand):
|
||||||
|
what the crowd actually prices to happen next.
|
||||||
|
|
||||||
|
Uses Polymarket's public Gamma API (https://gamma-api.polymarket.com) — no key,
|
||||||
|
no auth. Each market's ``outcomePrices`` are the implied probabilities of its
|
||||||
|
outcomes (a "Yes" at 0.76 means the market prices a 76% chance).
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GAMMA_BASE = "https://gamma-api.polymarket.com"
|
||||||
|
|
||||||
|
# Network timeout (seconds), consistent with the other vendors.
|
||||||
|
REQUEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
# Default number of markets to return, ranked by traded volume.
|
||||||
|
DEFAULT_LIMIT = 6
|
||||||
|
|
||||||
|
|
||||||
|
def _request(path: str, params: dict) -> dict:
|
||||||
|
response = requests.get(
|
||||||
|
f"{GAMMA_BASE}/{path}", params=params, timeout=REQUEST_TIMEOUT
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_list(value) -> list:
|
||||||
|
"""Gamma encodes ``outcomes``/``outcomePrices`` as JSON-string arrays."""
|
||||||
|
if isinstance(value, list):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _is_forward_looking(market: dict, now: datetime) -> bool:
|
||||||
|
"""Keep only open markets that resolve in the future.
|
||||||
|
|
||||||
|
``closed`` is the reliable resolved flag (``active`` stays True even for
|
||||||
|
settled markets), and a past ``endDate`` means the event already resolved —
|
||||||
|
either way it is not a forward-looking signal.
|
||||||
|
"""
|
||||||
|
if market.get("closed"):
|
||||||
|
return False
|
||||||
|
end_date = market.get("endDate")
|
||||||
|
if end_date:
|
||||||
|
try:
|
||||||
|
if datetime.fromisoformat(end_date.replace("Z", "+00:00")) < now:
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return bool(_parse_json_list(market.get("outcomePrices"))) and bool(
|
||||||
|
_parse_json_list(market.get("outcomes"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prediction_markets(topic: str, limit: int | None = None) -> str:
|
||||||
|
"""Return live prediction-market probabilities for an event topic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: Event keyword(s), e.g. "Fed rate cut", "recession 2026",
|
||||||
|
"US election", or a sector/company event.
|
||||||
|
limit: Max markets to return (ranked by traded volume); ``None`` uses
|
||||||
|
DEFAULT_LIMIT.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A markdown report of the most-traded open markets matching the topic,
|
||||||
|
each with its implied probability, traded volume, resolution date, and
|
||||||
|
recent (1-week) move.
|
||||||
|
"""
|
||||||
|
if limit is None:
|
||||||
|
limit = DEFAULT_LIMIT
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = _request("public-search", {"q": topic, "limit_per_type": 20})
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.warning("Polymarket search failed for %r: %s", topic, e)
|
||||||
|
return (
|
||||||
|
f"Polymarket data is currently unavailable (network error: {e}). "
|
||||||
|
f"Proceed without prediction-market signal for '{topic}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
candidates = [
|
||||||
|
m
|
||||||
|
for event in data.get("events", [])
|
||||||
|
for m in event.get("markets", [])
|
||||||
|
if _is_forward_looking(m, now)
|
||||||
|
]
|
||||||
|
candidates.sort(key=lambda m: m.get("volumeNum") or 0, reverse=True)
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f'## Polymarket prediction markets: "{topic}"\n'
|
||||||
|
f"Live, market-implied probabilities (higher traded volume = deeper, "
|
||||||
|
f"more reliable). A probability is the crowd's priced odds of the event, "
|
||||||
|
f"not a forecast you should take as certain.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return header + (
|
||||||
|
f"No open prediction markets matched '{topic}'. Polymarket coverage "
|
||||||
|
f"is concentrated in macro, political, geopolitical, and crypto "
|
||||||
|
f"events; a specific equity may have none."
|
||||||
|
)
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for m in candidates[:limit]:
|
||||||
|
prices = _parse_json_list(m.get("outcomePrices"))
|
||||||
|
outcomes = _parse_json_list(m.get("outcomes"))
|
||||||
|
try:
|
||||||
|
prob = float(prices[0])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
continue
|
||||||
|
label = outcomes[0] if outcomes else "Yes"
|
||||||
|
volume = m.get("volumeNum") or 0
|
||||||
|
end_date = (m.get("endDate") or "")[:10]
|
||||||
|
wk = m.get("oneWeekPriceChange")
|
||||||
|
wk_str = (
|
||||||
|
f", 1-week {wk * 100:+.1f}pp"
|
||||||
|
if isinstance(wk, (int, float)) and wk
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
f"- **{m.get('question')}** — {label} {prob:.0%} "
|
||||||
|
f"(${volume:,.0f} volume, resolves {end_date}{wk_str})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return header + "\n".join(lines) + "\n"
|
||||||
@@ -1,29 +1,45 @@
|
|||||||
"""Reddit search fetcher for ticker-specific discussion posts.
|
"""Reddit search fetcher for ticker-specific discussion posts.
|
||||||
|
|
||||||
Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``)
|
Default path is Reddit's public Atom/RSS search feed
|
||||||
which do not require an API key. Public throughput is ~10 requests per
|
(``reddit.com/r/{sub}/search.rss``). The richer JSON search endpoint
|
||||||
minute per IP, well within budget for a single agent run that queries
|
(``/search.json``) is reliably WAF-blocked (``HTTP 403``) for public clients
|
||||||
a handful of finance subreddits per ticker.
|
(issue #862), and probing it on every call only doubled our request volume
|
||||||
|
against Reddit's per-IP rate limit — tripping ``429`` on the RSS fallback — so
|
||||||
|
it is kept (``_fetch_subreddit_json``) but not used by default. On a 429 we back
|
||||||
|
off once (honouring ``Retry-After``). RSS lacks score / comment counts, so those
|
||||||
|
posts are marked and the formatter omits the metrics rather than printing fake
|
||||||
|
zeros.
|
||||||
|
|
||||||
Returns formatted plaintext blocks ready for prompt injection. Degrades
|
No API key required. Returns formatted plaintext blocks ready for prompt
|
||||||
gracefully — returns a placeholder string rather than raising, so callers
|
injection and degrades gracefully — returns a placeholder string rather than
|
||||||
never have to special-case missing data.
|
raising, so callers never special-case missing data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import html
|
||||||
|
import http.client
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Iterable
|
import xml.etree.ElementTree as ET
|
||||||
from urllib.error import HTTPError, URLError
|
from collections.abc import Iterable
|
||||||
|
from datetime import datetime
|
||||||
|
from urllib.error import HTTPError
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
|
_API = "https://www.reddit.com/r/{sub}/search.json?{qs}"
|
||||||
|
_RSS = "https://www.reddit.com/r/{sub}/search.rss?{qs}"
|
||||||
|
# A descriptive, identified User-Agent (per Reddit's API etiquette). Reddit
|
||||||
|
# blocks generic/anonymous tokens like bare "Mozilla/5.0" or "curl/…" but
|
||||||
|
# serves this one on both endpoints; the RSS feed accepts it even when the
|
||||||
|
# JSON search endpoint 403s, so no browser-spoofing is needed.
|
||||||
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
_UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
||||||
|
_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}
|
||||||
|
|
||||||
# Default subreddits ordered roughly by signal density for ticker-specific
|
# Default subreddits ordered roughly by signal density for ticker-specific
|
||||||
# discussion. wallstreetbets has the most volume but most noise; stocks /
|
# discussion. wallstreetbets has the most volume but most noise; stocks /
|
||||||
@@ -31,29 +47,143 @@ _UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)"
|
|||||||
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
|
DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing")
|
||||||
|
|
||||||
|
|
||||||
|
def _search_qs(ticker: str, limit: int) -> str:
|
||||||
|
return urlencode({
|
||||||
|
"q": ticker,
|
||||||
|
"restrict_sr": "on",
|
||||||
|
"sort": "new",
|
||||||
|
"t": "week", # last 7 days
|
||||||
|
"limit": limit,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_to_timestamp(iso_str: str | None) -> float | None:
|
||||||
|
"""Parse an Atom ``published`` timestamp to a UTC epoch, or None."""
|
||||||
|
if not iso_str:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
normalized = iso_str[:-1] + "+00:00" if iso_str.endswith("Z") else iso_str
|
||||||
|
return datetime.fromisoformat(normalized).timestamp()
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(content: str) -> str:
|
||||||
|
"""Reduce the HTML body Reddit embeds in an Atom entry to plain text."""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
# Reddit wraps the real selftext between SC_OFF / SC_ON markers.
|
||||||
|
if "<!-- SC_OFF -->" in content and "<!-- SC_ON -->" in content:
|
||||||
|
content = content.split("<!-- SC_OFF -->")[1].split("<!-- SC_ON -->")[0]
|
||||||
|
text = re.sub(r"<[^>]+>", " ", content)
|
||||||
|
return " ".join(html.unescape(text).split())
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_after_seconds(exc: HTTPError) -> float | None:
|
||||||
|
"""Seconds to wait from a 429's ``Retry-After`` header, capped at 30s."""
|
||||||
|
try:
|
||||||
|
val = exc.headers.get("Retry-After") if getattr(exc, "headers", None) else None
|
||||||
|
return min(float(val), 30.0) if val else None
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_subreddit_rss(
|
||||||
|
ticker: str,
|
||||||
|
sub: str,
|
||||||
|
limit: int,
|
||||||
|
timeout: float,
|
||||||
|
_retry: bool = True,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Default path: parse the public Atom search feed for a subreddit.
|
||||||
|
|
||||||
|
Carries no score / comment counts, so those fields are left None and the
|
||||||
|
post is tagged ``source="rss"`` for honest display. On a 429 (Reddit's
|
||||||
|
per-IP rate limit) we back off once — honouring ``Retry-After`` when
|
||||||
|
present — before giving up, so a transient burst doesn't blank the feed.
|
||||||
|
"""
|
||||||
|
url = _RSS.format(sub=sub, qs=_search_qs(ticker, limit))
|
||||||
|
req = Request(url, headers={"User-Agent": _UA})
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
root = ET.fromstring(resp.read())
|
||||||
|
except HTTPError as exc:
|
||||||
|
if exc.code == 429 and _retry:
|
||||||
|
wait = _retry_after_seconds(exc) or 5.0
|
||||||
|
logger.warning(
|
||||||
|
"Reddit RSS 429 for r/%s · %s — backing off %.1fs then retrying once",
|
||||||
|
sub, ticker, wait,
|
||||||
|
)
|
||||||
|
time.sleep(wait)
|
||||||
|
return _fetch_subreddit_rss(ticker, sub, limit, timeout, _retry=False)
|
||||||
|
logger.warning("Reddit RSS fetch failed for r/%s · %s: %s", sub, ticker, exc)
|
||||||
|
return []
|
||||||
|
except (OSError, http.client.HTTPException, ET.ParseError) as exc:
|
||||||
|
# OSError covers URLError/TimeoutError/connection resets; HTTPException
|
||||||
|
# covers chunked-transfer errors (IncompleteRead/BadStatusLine, #1024).
|
||||||
|
logger.warning("Reddit RSS fetch failed for r/%s · %s: %s", sub, ticker, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
posts = []
|
||||||
|
for entry in root.findall("atom:entry", _ATOM_NS)[:limit]:
|
||||||
|
title_el = entry.find("atom:title", _ATOM_NS)
|
||||||
|
published_el = entry.find("atom:published", _ATOM_NS)
|
||||||
|
content_el = entry.find("atom:content", _ATOM_NS)
|
||||||
|
posts.append({
|
||||||
|
"title": (title_el.text if title_el is not None else "") or "",
|
||||||
|
"score": None,
|
||||||
|
"num_comments": None,
|
||||||
|
"created_utc": _iso_to_timestamp(
|
||||||
|
published_el.text if published_el is not None else None
|
||||||
|
),
|
||||||
|
"selftext": _strip_html(content_el.text if content_el is not None else ""),
|
||||||
|
"source": "rss",
|
||||||
|
})
|
||||||
|
return posts
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_subreddit_json(
|
||||||
|
ticker: str,
|
||||||
|
sub: str,
|
||||||
|
limit: int,
|
||||||
|
timeout: float,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Richer JSON search path (carries score / comment counts).
|
||||||
|
|
||||||
|
Reddit's WAF currently returns ``403 Blocked`` on this endpoint for
|
||||||
|
non-OAuth clients (issue #862), so it is NOT used by default — calling it on
|
||||||
|
every request only doubled our volume against the per-IP rate limit and
|
||||||
|
triggered 429s on the RSS fallback. Kept for the day the WAF relaxes or an
|
||||||
|
OAuth token is wired in; degrades to RSS on failure.
|
||||||
|
"""
|
||||||
|
url = _API.format(sub=sub, qs=_search_qs(ticker, limit))
|
||||||
|
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
payload = json.loads(resp.read())
|
||||||
|
children = (payload.get("data") or {}).get("children") or []
|
||||||
|
return [c.get("data", {}) for c in children if isinstance(c, dict)]
|
||||||
|
except (OSError, http.client.HTTPException, json.JSONDecodeError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Reddit JSON fetch failed for r/%s · %s: %s — falling back to RSS feed.",
|
||||||
|
sub, ticker, exc,
|
||||||
|
)
|
||||||
|
return _fetch_subreddit_rss(ticker, sub, limit, timeout)
|
||||||
|
|
||||||
|
|
||||||
def _fetch_subreddit(
|
def _fetch_subreddit(
|
||||||
ticker: str,
|
ticker: str,
|
||||||
sub: str,
|
sub: str,
|
||||||
limit: int,
|
limit: int,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
qs = urlencode({
|
"""Fetch one subreddit, RSS-first.
|
||||||
"q": ticker,
|
|
||||||
"restrict_sr": "on",
|
The JSON search endpoint is reliably WAF-blocked (403) for public clients,
|
||||||
"sort": "new",
|
so we go straight to the RSS feed — which serves our identified User-Agent
|
||||||
"t": "week", # last 7 days
|
reliably — halving our request volume against Reddit's per-IP rate limit.
|
||||||
"limit": limit,
|
"""
|
||||||
})
|
return _fetch_subreddit_rss(ticker, sub, limit, timeout)
|
||||||
url = _API.format(sub=sub, qs=qs)
|
|
||||||
req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"})
|
|
||||||
try:
|
|
||||||
with urlopen(req, timeout=timeout) as resp:
|
|
||||||
payload = json.loads(resp.read())
|
|
||||||
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
|
||||||
logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc)
|
|
||||||
return []
|
|
||||||
children = (payload.get("data") or {}).get("children") or []
|
|
||||||
return [c.get("data", {}) for c in children if isinstance(c, dict)]
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_reddit_posts(
|
def fetch_reddit_posts(
|
||||||
@@ -61,13 +191,14 @@ def fetch_reddit_posts(
|
|||||||
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
|
subreddits: Iterable[str] = DEFAULT_SUBREDDITS,
|
||||||
limit_per_sub: int = 5,
|
limit_per_sub: int = 5,
|
||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
inter_request_delay: float = 0.4,
|
inter_request_delay: float = 1.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
|
"""Fetch recent Reddit posts mentioning ``ticker`` across finance
|
||||||
subreddits and return them as a formatted plaintext block.
|
subreddits and return them as a formatted plaintext block.
|
||||||
|
|
||||||
``inter_request_delay`` keeps us under Reddit's public rate limit
|
``inter_request_delay`` paces the (now RSS-only) per-subreddit requests to
|
||||||
(~10 req/min per IP) even if the caller queries many subreddits.
|
stay under Reddit's public per-IP rate limit; combined with the RSS-first
|
||||||
|
path it makes 429s rare even when several analyses run back-to-back.
|
||||||
"""
|
"""
|
||||||
blocks = []
|
blocks = []
|
||||||
total_posts = 0
|
total_posts = 0
|
||||||
@@ -80,20 +211,28 @@ def fetch_reddit_posts(
|
|||||||
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
|
blocks.append(f"r/{sub}: <no posts found mentioning {ticker.upper()} in the past 7 days>")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lines = [f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}:"]
|
via_rss = any(p.get("source") == "rss" for p in posts)
|
||||||
|
header = f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}"
|
||||||
|
header += " (via RSS feed; scores/comments unavailable):" if via_rss else ":"
|
||||||
|
lines = [header]
|
||||||
for p in posts:
|
for p in posts:
|
||||||
title = (p.get("title") or "").replace("\n", " ").strip()
|
title = (p.get("title") or "").replace("\n", " ").strip()
|
||||||
score = p.get("score", 0)
|
score = p.get("score")
|
||||||
comments = p.get("num_comments", 0)
|
comments = p.get("num_comments")
|
||||||
created = p.get("created_utc")
|
created = p.get("created_utc")
|
||||||
created_str = (
|
created_str = (
|
||||||
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
|
time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?"
|
||||||
)
|
)
|
||||||
|
# Score / comment counts are absent on the RSS fallback path —
|
||||||
|
# show them only when present rather than printing fake zeros.
|
||||||
|
meta = created_str
|
||||||
|
if score is not None and comments is not None:
|
||||||
|
meta += f" · {score:>4}↑ · {comments:>3}c"
|
||||||
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
|
selftext = (p.get("selftext") or "").replace("\n", " ").strip()
|
||||||
if len(selftext) > 240:
|
if len(selftext) > 240:
|
||||||
selftext = selftext[:240] + "…"
|
selftext = selftext[:240] + "…"
|
||||||
lines.append(
|
lines.append(
|
||||||
f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}"
|
f" [{meta}] {title}"
|
||||||
+ (f"\n body excerpt: {selftext}" if selftext else "")
|
+ (f"\n body excerpt: {selftext}" if selftext else "")
|
||||||
)
|
)
|
||||||
blocks.append("\n".join(lines))
|
blocks.append("\n".join(lines))
|
||||||
|
|||||||
@@ -1,17 +1,24 @@
|
|||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from yfinance.exceptions import YFRateLimitError
|
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
from typing import Annotated
|
from yfinance.exceptions import YFRateLimitError
|
||||||
import os
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
from .symbol_utils import NoMarketDataError, normalize_symbol
|
||||||
from .utils import safe_ticker_component
|
from .utils import safe_ticker_component
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# A vendor's latest OHLCV row this many calendar days before the requested date
|
||||||
|
# is treated as stale. Generous enough to span long holiday weekends, tight
|
||||||
|
# enough to catch the year-old frames yfinance occasionally returns (#1021).
|
||||||
|
MAX_OHLCV_STALE_DAYS = 10
|
||||||
|
|
||||||
|
|
||||||
def yf_retry(func, max_retries=3, base_delay=2.0):
|
def yf_retry(func, max_retries=3, base_delay=2.0):
|
||||||
"""Execute a yfinance call with exponential backoff on rate limits.
|
"""Execute a yfinance call with exponential backoff on rate limits.
|
||||||
@@ -32,8 +39,24 @@ def yf_retry(func, max_retries=3, base_delay=2.0):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_date_column(data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Normalize the date column to ``Date``.
|
||||||
|
|
||||||
|
Some yfinance builds leave the index unnamed (so ``reset_index()`` yields
|
||||||
|
``index``) or use ``Datetime`` for intraday data. Rename the first
|
||||||
|
date-like column so indicators don't silently drop when it isn't ``Date``.
|
||||||
|
"""
|
||||||
|
if "Date" in data.columns:
|
||||||
|
return data
|
||||||
|
for candidate in ("index", "Datetime", "date"):
|
||||||
|
if candidate in data.columns:
|
||||||
|
return data.rename(columns={candidate: "Date"})
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
||||||
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps."""
|
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps."""
|
||||||
|
data = _ensure_date_column(data)
|
||||||
data["Date"] = pd.to_datetime(data["Date"], errors="coerce")
|
data["Date"] = pd.to_datetime(data["Date"], errors="coerce")
|
||||||
data = data.dropna(subset=["Date"])
|
data = data.dropna(subset=["Date"])
|
||||||
|
|
||||||
@@ -45,25 +68,84 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_ohlcv_dates(data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""Return parsed dates from an OHLCV frame, whether Date is a column or the index."""
|
||||||
|
if "Date" in data.columns:
|
||||||
|
return pd.to_datetime(data["Date"], errors="coerce").dropna()
|
||||||
|
# yfinance keeps the dates in the index (a DatetimeIndex, sometimes unnamed).
|
||||||
|
if isinstance(data.index, pd.DatetimeIndex):
|
||||||
|
return pd.Series(pd.to_datetime(data.index, errors="coerce")).dropna()
|
||||||
|
# Fallback: expose the index and look for any date-like column.
|
||||||
|
df = data.reset_index()
|
||||||
|
for col in ("Date", "Datetime", "date", "index"):
|
||||||
|
if col in df.columns:
|
||||||
|
parsed = pd.to_datetime(df[col], errors="coerce").dropna()
|
||||||
|
if not parsed.empty:
|
||||||
|
return parsed
|
||||||
|
return pd.Series(dtype="datetime64[ns]")
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_ohlcv_not_stale(
|
||||||
|
data: pd.DataFrame,
|
||||||
|
curr_date: str,
|
||||||
|
symbol: str,
|
||||||
|
canonical: str | None = None,
|
||||||
|
*,
|
||||||
|
max_stale_days: int = MAX_OHLCV_STALE_DAYS,
|
||||||
|
) -> None:
|
||||||
|
"""Reject OHLCV whose latest row is far older than curr_date.
|
||||||
|
|
||||||
|
Raises NoMarketDataError (with a stale-specific detail) so the router treats
|
||||||
|
it like any other "no usable data from this vendor" — try the next vendor,
|
||||||
|
then emit one clear unavailable signal. Empty frames are left to the
|
||||||
|
caller's existing no-data handling; this guards only the dangerous case of
|
||||||
|
present-but-stale rows (a vendor returning a year-old frame that would
|
||||||
|
otherwise feed wrong prices to the agent, #1021).
|
||||||
|
"""
|
||||||
|
if data is None or data.empty:
|
||||||
|
return
|
||||||
|
requested = pd.to_datetime(curr_date, errors="coerce")
|
||||||
|
if pd.isna(requested):
|
||||||
|
return
|
||||||
|
requested = requested.normalize()
|
||||||
|
dates = _coerce_ohlcv_dates(data)
|
||||||
|
if dates.empty:
|
||||||
|
return
|
||||||
|
latest = dates.max().normalize()
|
||||||
|
stale_days = (requested - latest).days
|
||||||
|
if stale_days > max_stale_days:
|
||||||
|
raise NoMarketDataError(
|
||||||
|
symbol,
|
||||||
|
canonical,
|
||||||
|
f"latest row is {latest.date()}, {stale_days} days before the "
|
||||||
|
f"requested {requested.date()} (stale) — refusing to use it",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||||
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
|
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
|
||||||
|
|
||||||
Downloads 15 years of data up to today and caches per symbol. On
|
Downloads 5 years of data up to today and caches per symbol. On
|
||||||
subsequent calls the cache is reused. Rows after curr_date are
|
subsequent calls the cache is reused. Rows after curr_date are
|
||||||
filtered out so backtests never see future prices.
|
filtered out so backtests never see future prices.
|
||||||
"""
|
"""
|
||||||
# Reject ticker values that would escape the cache directory when
|
# Resolve broker/forex symbols (XAUUSD+ -> GC=F) to Yahoo's convention,
|
||||||
|
# then reject values that would escape the cache directory when
|
||||||
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
||||||
safe_symbol = safe_ticker_component(symbol)
|
canonical = normalize_symbol(symbol)
|
||||||
|
safe_symbol = safe_ticker_component(canonical)
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
curr_date_dt = pd.to_datetime(curr_date)
|
||||||
|
|
||||||
# Cache uses a fixed window (15y to today) so one file per symbol
|
# Cache uses a fixed window (5y to today) so one file per symbol.
|
||||||
today_date = pd.Timestamp.today()
|
today_date = pd.Timestamp.today()
|
||||||
start_date = today_date - pd.DateOffset(years=5)
|
start_date = today_date - pd.DateOffset(years=5)
|
||||||
start_str = start_date.strftime("%Y-%m-%d")
|
start_str = start_date.strftime("%Y-%m-%d")
|
||||||
end_str = today_date.strftime("%Y-%m-%d")
|
# yfinance ``end`` is EXCLUSIVE; request tomorrow so today's row is included
|
||||||
|
# when curr_date is the current day (#986). Look-ahead is still prevented by
|
||||||
|
# the curr_date filter below.
|
||||||
|
end_str = (today_date + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||||
data_file = os.path.join(
|
data_file = os.path.join(
|
||||||
@@ -71,25 +153,42 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
|||||||
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A cached file may be empty if a prior fetch failed (unknown symbol,
|
||||||
|
# transient rate limit). Treat an empty/columnless cache as a miss and
|
||||||
|
# re-fetch rather than serving the poisoned file forever.
|
||||||
|
data = None
|
||||||
if os.path.exists(data_file):
|
if os.path.exists(data_file):
|
||||||
data = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8")
|
cached = pd.read_csv(data_file, on_bad_lines="skip", encoding="utf-8")
|
||||||
else:
|
if not cached.empty and "Close" in cached.columns:
|
||||||
data = yf_retry(lambda: yf.download(
|
data = cached
|
||||||
symbol,
|
|
||||||
|
if data is None:
|
||||||
|
downloaded = yf_retry(lambda: yf.download(
|
||||||
|
canonical,
|
||||||
start=start_str,
|
start=start_str,
|
||||||
end=end_str,
|
end=end_str,
|
||||||
multi_level_index=False,
|
multi_level_index=False,
|
||||||
progress=False,
|
progress=False,
|
||||||
auto_adjust=True,
|
auto_adjust=True,
|
||||||
))
|
))
|
||||||
data = data.reset_index()
|
downloaded = _ensure_date_column(downloaded.reset_index())
|
||||||
data.to_csv(data_file, index=False, encoding="utf-8")
|
# Only cache real data — never persist an empty frame.
|
||||||
|
if downloaded.empty or "Close" not in downloaded.columns:
|
||||||
|
raise NoMarketDataError(
|
||||||
|
symbol, canonical, "Yahoo Finance returned no rows"
|
||||||
|
)
|
||||||
|
downloaded.to_csv(data_file, index=False, encoding="utf-8")
|
||||||
|
data = downloaded
|
||||||
|
|
||||||
data = _clean_dataframe(data)
|
data = _clean_dataframe(data)
|
||||||
|
|
||||||
# Filter to curr_date to prevent look-ahead bias in backtesting
|
# Filter to curr_date to prevent look-ahead bias in backtesting
|
||||||
data = data[data["Date"] <= curr_date_dt]
|
data = data[data["Date"] <= curr_date_dt]
|
||||||
|
|
||||||
|
# Reject a stale frame (latest row far older than curr_date) rather than
|
||||||
|
# feeding year-old prices into indicators (#1021).
|
||||||
|
_assert_ohlcv_not_stale(data, curr_date, symbol, canonical)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,9 @@ network call succeeded.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import http.client
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Optional
|
|
||||||
from urllib.error import HTTPError, URLError
|
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -40,7 +38,9 @@ def fetch_stocktwits_messages(ticker: str, limit: int = 30, timeout: float = 10.
|
|||||||
try:
|
try:
|
||||||
with urlopen(req, timeout=timeout) as resp:
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
data = json.loads(resp.read())
|
data = json.loads(resp.read())
|
||||||
except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc:
|
except (OSError, http.client.HTTPException, json.JSONDecodeError) as exc:
|
||||||
|
# OSError covers URLError/TimeoutError/connection resets; HTTPException
|
||||||
|
# covers chunked-transfer errors (IncompleteRead/BadStatusLine, #1024).
|
||||||
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
|
logger.warning("StockTwits fetch failed for %s: %s", ticker, exc)
|
||||||
return f"<stocktwits unavailable: {type(exc).__name__}>"
|
return f"<stocktwits unavailable: {type(exc).__name__}>"
|
||||||
|
|
||||||
|
|||||||
138
tradingagents/dataflows/symbol_utils.py
Normal file
138
tradingagents/dataflows/symbol_utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Symbol normalization and market-data error types for vendor calls.
|
||||||
|
|
||||||
|
Yahoo Finance (the default vendor) uses specific ticker conventions that
|
||||||
|
differ from the broker / TradingView / MT5 style symbols users often type:
|
||||||
|
|
||||||
|
user types Yahoo wants why
|
||||||
|
--------------- --------------- -----------------------------------
|
||||||
|
XAUUSD, XAUUSD+ GC=F gold has no forex pair on Yahoo;
|
||||||
|
it is quoted as a COMEX future
|
||||||
|
EURUSD EURUSD=X spot forex pairs take a ``=X`` suffix
|
||||||
|
BTCUSD BTC-USD crypto pairs use a ``-`` separator
|
||||||
|
SPX500, US500 ^GSPC index CFDs map to Yahoo index symbols
|
||||||
|
|
||||||
|
Passing the raw broker symbol to Yahoo returns an empty result, which the
|
||||||
|
agents previously received as free text and could hallucinate a price
|
||||||
|
around (see issue #781). Centralizing the mapping here means every yfinance
|
||||||
|
entry point resolves symbols the same way, and new instruments are added by
|
||||||
|
appending a table row rather than editing call sites.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
# NoMarketDataError lives in the vendor-error taxonomy (errors.py); re-exported
|
||||||
|
# here for the many call sites that import it alongside normalize_symbol.
|
||||||
|
from .errors import NoMarketDataError as NoMarketDataError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ISO-4217 codes common enough to appear in retail forex pairs. A bare
|
||||||
|
# six-letter symbol whose halves are BOTH in this set is treated as a spot
|
||||||
|
# forex pair and given Yahoo's ``=X`` suffix.
|
||||||
|
_FOREX_CURRENCIES = frozenset(
|
||||||
|
{
|
||||||
|
"USD", "EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD",
|
||||||
|
"CNY", "CNH", "HKD", "SGD", "SEK", "NOK", "DKK", "PLN",
|
||||||
|
"MXN", "ZAR", "TRY", "INR", "KRW", "BRL", "RUB", "THB",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Crypto bases that brokers quote against USD without a separator.
|
||||||
|
_CRYPTO_BASES = frozenset(
|
||||||
|
{"BTC", "ETH", "SOL", "XRP", "ADA", "DOGE", "LTC", "BCH", "DOT", "AVAX", "LINK"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit aliases for instruments whose broker symbol does not map to a
|
||||||
|
# Yahoo symbol by rule. Metals/energy resolve to their front-month future;
|
||||||
|
# index CFD names resolve to the underlying Yahoo index symbol. Extend by
|
||||||
|
# adding rows — no call site changes required.
|
||||||
|
_ALIASES = {
|
||||||
|
# Precious metals (spot names -> COMEX/NYMEX futures)
|
||||||
|
"XAUUSD": "GC=F", "XAU": "GC=F", "GOLD": "GC=F",
|
||||||
|
"XAGUSD": "SI=F", "XAG": "SI=F", "SILVER": "SI=F",
|
||||||
|
"XPTUSD": "PL=F", "XPDUSD": "PA=F",
|
||||||
|
# Energy
|
||||||
|
"WTICOUSD": "CL=F", "USOIL": "CL=F", "WTI": "CL=F",
|
||||||
|
"BCOUSD": "BZ=F", "UKOIL": "BZ=F", "BRENT": "BZ=F",
|
||||||
|
"NATGAS": "NG=F", "XNGUSD": "NG=F",
|
||||||
|
"COPPER": "HG=F", "XCUUSD": "HG=F",
|
||||||
|
# Index CFDs -> Yahoo index symbols
|
||||||
|
"SPX500": "^GSPC", "US500": "^GSPC", "SPX": "^GSPC",
|
||||||
|
"NAS100": "^NDX", "US100": "^NDX", "USTEC": "^NDX",
|
||||||
|
"US30": "^DJI", "DJI30": "^DJI", "WS30": "^DJI",
|
||||||
|
"GER40": "^GDAXI", "GER30": "^GDAXI", "DE40": "^GDAXI",
|
||||||
|
"UK100": "^FTSE", "JP225": "^N225", "JPN225": "^N225",
|
||||||
|
"FRA40": "^FCHI", "EU50": "^STOXX50E", "HK50": "^HSI",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Yahoo symbols may contain letters, digits, and these structural characters.
|
||||||
|
_YAHOO_SAFE = re.compile(r"^[A-Za-z0-9._\-\^=]+$")
|
||||||
|
|
||||||
|
|
||||||
|
# Crypto quote currencies that all map to Yahoo's USD pair. Yahoo lists only
|
||||||
|
# ``<BASE>-USD`` (not the USDT/USDC stablecoin pairs), so a broker symbol quoted
|
||||||
|
# in any of these resolves to ``-USD`` (#982). Longest first so ``USDT``/``USDC``
|
||||||
|
# match before the ``USD`` substring.
|
||||||
|
_CRYPTO_QUOTES = ("USDT", "USDC", "USD")
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_crypto(s: str) -> str | None:
|
||||||
|
"""Return ``<BASE>-USD`` if ``s`` is a known crypto quoted in USD/USDT/USDC.
|
||||||
|
|
||||||
|
Accepts dashed or undashed forms: ``BTCUSD``, ``BTCUSDT``, ``BTC-USDT``,
|
||||||
|
``BTC-USDC`` all resolve to ``BTC-USD``. Returns None otherwise.
|
||||||
|
"""
|
||||||
|
compact = s.replace("-", "")
|
||||||
|
for quote in _CRYPTO_QUOTES:
|
||||||
|
if compact.endswith(quote):
|
||||||
|
base = compact[: -len(quote)]
|
||||||
|
if base in _CRYPTO_BASES:
|
||||||
|
return f"{base}-USD"
|
||||||
|
break
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_symbol(raw: str) -> str:
|
||||||
|
"""Map a user/broker symbol to its canonical Yahoo Finance symbol.
|
||||||
|
|
||||||
|
Resolution order (first match wins):
|
||||||
|
1. Explicit alias table (metals, energy, index CFDs).
|
||||||
|
2. Crypto rule: a known crypto base quoted in USD/USDT/USDC (dashed or
|
||||||
|
not) -> ``BASE-USD``.
|
||||||
|
3. Forex rule: six letters that are two ISO currency codes -> ``PAIR=X``.
|
||||||
|
4. Otherwise the upper-cased symbol is returned unchanged (plain
|
||||||
|
equities, ETFs, Yahoo-native symbols like ``GC=F`` or ``^GSPC``).
|
||||||
|
|
||||||
|
A trailing ``+`` (broker CFD marker, e.g. ``XAUUSD+``) is stripped before
|
||||||
|
matching. The function is purely syntactic — it performs no network
|
||||||
|
calls — so it is safe to apply on every request.
|
||||||
|
"""
|
||||||
|
if not isinstance(raw, str) or not raw.strip():
|
||||||
|
return raw
|
||||||
|
|
||||||
|
s = raw.strip().upper()
|
||||||
|
# Broker CFD/qualifier suffixes Yahoo never uses.
|
||||||
|
s = s.rstrip("+")
|
||||||
|
|
||||||
|
crypto = _normalize_crypto(s)
|
||||||
|
if s in _ALIASES:
|
||||||
|
canonical = _ALIASES[s]
|
||||||
|
elif crypto is not None:
|
||||||
|
canonical = crypto
|
||||||
|
elif len(s) == 6 and s[:3] in _FOREX_CURRENCIES and s[3:] in _FOREX_CURRENCIES:
|
||||||
|
canonical = f"{s}=X"
|
||||||
|
else:
|
||||||
|
canonical = s
|
||||||
|
|
||||||
|
if canonical != raw.strip().upper():
|
||||||
|
logger.info("Resolved symbol %r to Yahoo symbol %r", raw, canonical)
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
|
||||||
|
def is_yahoo_safe(symbol: str) -> bool:
|
||||||
|
"""True when ``symbol`` only contains characters Yahoo symbols use."""
|
||||||
|
return bool(symbol) and _YAHOO_SAFE.fullmatch(symbol) is not None
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import json
|
from datetime import date, datetime, timedelta
|
||||||
import pandas as pd
|
|
||||||
from datetime import date, timedelta, datetime
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
# Tickers can contain letters, digits, dot, dash, underscore, and caret
|
# Tickers can contain letters, digits, dot, dash, underscore, caret
|
||||||
# (for index symbols like ^GSPC). Anything else is rejected so the value
|
# (index symbols like ^GSPC), equals (futures like GC=F), and plus
|
||||||
# never escapes a containing directory when interpolated into a path.
|
# (forex/CFD symbols like XAUUSD+). None of these enable directory
|
||||||
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
|
# traversal, so the value never escapes a containing directory when
|
||||||
|
# interpolated into a path. Anything else is rejected.
|
||||||
|
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^=+]+$")
|
||||||
|
|
||||||
|
|
||||||
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
|
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
from typing import Annotated
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from typing import Annotated
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
import os
|
from dateutil.relativedelta import relativedelta
|
||||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
|
|
||||||
|
from .stockstats_utils import (
|
||||||
|
StockstatsUtils,
|
||||||
|
_assert_ohlcv_not_stale,
|
||||||
|
filter_financials_by_date,
|
||||||
|
load_ohlcv,
|
||||||
|
yf_retry,
|
||||||
|
)
|
||||||
|
from .symbol_utils import NoMarketDataError, normalize_symbol
|
||||||
|
|
||||||
|
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
@@ -13,24 +22,35 @@ def get_YFin_data_online(
|
|||||||
):
|
):
|
||||||
|
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
datetime.strptime(end_date, "%Y-%m-%d")
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
|
||||||
# Create ticker object
|
# Resolve broker/forex symbols to Yahoo's convention (XAUUSD+ -> GC=F).
|
||||||
ticker = yf.Ticker(symbol.upper())
|
canonical = normalize_symbol(symbol)
|
||||||
|
ticker = yf.Ticker(canonical)
|
||||||
|
|
||||||
# Fetch historical data for the specified date range
|
# yfinance treats ``end`` as EXCLUSIVE, so it would drop the requested
|
||||||
data = yf_retry(lambda: ticker.history(start=start_date, end=end_date))
|
# end_date row (and the current day when end_date is today). Request one day
|
||||||
|
# past end_date so the requested range is actually inclusive (#986/#987).
|
||||||
|
end_inclusive = (end_dt + relativedelta(days=1)).strftime("%Y-%m-%d")
|
||||||
|
data = yf_retry(lambda: ticker.history(start=start_date, end=end_inclusive))
|
||||||
|
|
||||||
# Check if data is empty
|
# Empty result means the symbol is unknown/delisted. Raise a typed error
|
||||||
|
# instead of returning prose: the routing layer turns it into a single
|
||||||
|
# unambiguous "no data" signal so the agent never fabricates a price.
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return (
|
raise NoMarketDataError(
|
||||||
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
|
symbol, canonical, f"no rows between {start_date} and {end_date}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove timezone info from index for cleaner output
|
# Remove timezone info from index for cleaner output
|
||||||
if data.index.tz is not None:
|
if data.index.tz is not None:
|
||||||
data.index = data.index.tz_localize(None)
|
data.index = data.index.tz_localize(None)
|
||||||
|
|
||||||
|
# Reject a stale frame (e.g. a year-old partial response) before it is
|
||||||
|
# formatted into the report. Raises NoMarketDataError, which the router
|
||||||
|
# turns into one clear unavailable signal (#1021).
|
||||||
|
_assert_ohlcv_not_stale(data, end_date, symbol, canonical)
|
||||||
|
|
||||||
# Round numerical values to 2 decimal places for cleaner display
|
# Round numerical values to 2 decimal places for cleaner display
|
||||||
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
||||||
for col in numeric_columns:
|
for col in numeric_columns:
|
||||||
@@ -40,8 +60,10 @@ def get_YFin_data_online(
|
|||||||
# Convert DataFrame to CSV string
|
# Convert DataFrame to CSV string
|
||||||
csv_string = data.to_csv()
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information; note the resolved symbol when it differs so the
|
||||||
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
# agent (and user) can see which instrument was actually priced.
|
||||||
|
label = canonical if canonical == symbol.upper() else f"{canonical} (from {symbol})"
|
||||||
|
header = f"# Stock data for {label} from {start_date} to {end_date}\n"
|
||||||
header += f"# Total records: {len(data)}\n"
|
header += f"# Total records: {len(data)}\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
@@ -141,28 +163,30 @@ def get_stock_stats_indicators_window(
|
|||||||
# Optimized: Get stock data once and calculate indicators for all dates
|
# Optimized: Get stock data once and calculate indicators for all dates
|
||||||
try:
|
try:
|
||||||
indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
|
indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date)
|
||||||
|
|
||||||
# Generate the date range we need
|
# Generate the date range we need
|
||||||
current_dt = curr_date_dt
|
current_dt = curr_date_dt
|
||||||
date_values = []
|
date_values = []
|
||||||
|
|
||||||
while current_dt >= before:
|
while current_dt >= before:
|
||||||
date_str = current_dt.strftime('%Y-%m-%d')
|
date_str = current_dt.strftime('%Y-%m-%d')
|
||||||
|
|
||||||
# Look up the indicator value for this date
|
# Look up the indicator value for this date
|
||||||
if date_str in indicator_data:
|
if date_str in indicator_data:
|
||||||
indicator_value = indicator_data[date_str]
|
indicator_value = indicator_data[date_str]
|
||||||
else:
|
else:
|
||||||
indicator_value = "N/A: Not a trading day (weekend or holiday)"
|
indicator_value = "N/A: Not a trading day (weekend or holiday)"
|
||||||
|
|
||||||
date_values.append((date_str, indicator_value))
|
date_values.append((date_str, indicator_value))
|
||||||
current_dt = current_dt - relativedelta(days=1)
|
current_dt = current_dt - relativedelta(days=1)
|
||||||
|
|
||||||
# Build the result string
|
# Build the result string
|
||||||
ind_string = ""
|
ind_string = ""
|
||||||
for date_str, value in date_values:
|
for date_str, value in date_values:
|
||||||
ind_string += f"{date_str}: {value}\n"
|
ind_string += f"{date_str}: {value}\n"
|
||||||
|
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise # Unknown/delisted symbol — let the router emit the sentinel
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting bulk stockstats data: {e}")
|
print(f"Error getting bulk stockstats data: {e}")
|
||||||
# Fallback to original implementation if bulk method fails
|
# Fallback to original implementation if bulk method fails
|
||||||
@@ -200,22 +224,22 @@ def _get_stock_stats_bulk(
|
|||||||
data = load_ohlcv(symbol, curr_date)
|
data = load_ohlcv(symbol, curr_date)
|
||||||
df = wrap(data)
|
df = wrap(data)
|
||||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# Calculate the indicator for all rows at once
|
# Calculate the indicator for all rows at once
|
||||||
df[indicator] # This triggers stockstats to calculate the indicator
|
df[indicator] # This triggers stockstats to calculate the indicator
|
||||||
|
|
||||||
# Create a dictionary mapping date strings to indicator values
|
# Create a dictionary mapping date strings to indicator values
|
||||||
result_dict = {}
|
result_dict = {}
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
date_str = row["Date"]
|
date_str = row["Date"]
|
||||||
indicator_value = row[indicator]
|
indicator_value = row[indicator]
|
||||||
|
|
||||||
# Handle NaN/None values
|
# Handle NaN/None values
|
||||||
if pd.isna(indicator_value):
|
if pd.isna(indicator_value):
|
||||||
result_dict[date_str] = "N/A"
|
result_dict[date_str] = "N/A"
|
||||||
else:
|
else:
|
||||||
result_dict[date_str] = str(indicator_value)
|
result_dict[date_str] = str(indicator_value)
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -236,6 +260,8 @@ def get_stockstats_indicator(
|
|||||||
indicator,
|
indicator,
|
||||||
curr_date,
|
curr_date,
|
||||||
)
|
)
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise # Unknown/delisted symbol — let the router emit the sentinel
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
||||||
@@ -250,12 +276,13 @@ def get_fundamentals(
|
|||||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||||
):
|
):
|
||||||
"""Get company fundamentals overview from yfinance."""
|
"""Get company fundamentals overview from yfinance."""
|
||||||
|
canonical = normalize_symbol(ticker)
|
||||||
try:
|
try:
|
||||||
ticker_obj = yf.Ticker(ticker.upper())
|
ticker_obj = yf.Ticker(canonical)
|
||||||
info = yf_retry(lambda: ticker_obj.info)
|
info = yf_retry(lambda: ticker_obj.info)
|
||||||
|
|
||||||
if not info:
|
if not info:
|
||||||
return f"No fundamentals data found for symbol '{ticker}'"
|
raise NoMarketDataError(ticker, canonical, "no fundamentals returned")
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
("Name", info.get("longName")),
|
("Name", info.get("longName")),
|
||||||
@@ -293,11 +320,20 @@ def get_fundamentals(
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
lines.append(f"{label}: {value}")
|
lines.append(f"{label}: {value}")
|
||||||
|
|
||||||
header = f"# Company Fundamentals for {ticker.upper()}\n"
|
# yfinance returns a stub dict (e.g. {"trailingPegRatio": None}) for
|
||||||
|
# unknown symbols, so `info` is truthy but every field is empty. Treat
|
||||||
|
# "no usable fields" as no data rather than emitting a bare header the
|
||||||
|
# agent might fabricate around.
|
||||||
|
if not lines:
|
||||||
|
raise NoMarketDataError(ticker, canonical, "no fundamental fields returned")
|
||||||
|
|
||||||
|
header = f"# Company Fundamentals for {canonical}\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + "\n".join(lines)
|
return header + "\n".join(lines)
|
||||||
|
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
|
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
|
||||||
|
|
||||||
@@ -308,8 +344,9 @@ def get_balance_sheet(
|
|||||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get balance sheet data from yfinance."""
|
"""Get balance sheet data from yfinance."""
|
||||||
|
canonical = normalize_symbol(ticker)
|
||||||
try:
|
try:
|
||||||
ticker_obj = yf.Ticker(ticker.upper())
|
ticker_obj = yf.Ticker(canonical)
|
||||||
|
|
||||||
if freq.lower() == "quarterly":
|
if freq.lower() == "quarterly":
|
||||||
data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet)
|
data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet)
|
||||||
@@ -319,17 +356,19 @@ def get_balance_sheet(
|
|||||||
data = filter_financials_by_date(data, curr_date)
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No balance sheet data found for symbol '{ticker}'"
|
raise NoMarketDataError(ticker, canonical, "no balance sheet data")
|
||||||
|
|
||||||
# Convert to CSV string for consistency with other functions
|
# Convert to CSV string for consistency with other functions
|
||||||
csv_string = data.to_csv()
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information
|
||||||
header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n"
|
header = f"# Balance Sheet data for {canonical} ({freq})\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + csv_string
|
return header + csv_string
|
||||||
|
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
|
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
|
||||||
|
|
||||||
@@ -340,8 +379,9 @@ def get_cashflow(
|
|||||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get cash flow data from yfinance."""
|
"""Get cash flow data from yfinance."""
|
||||||
|
canonical = normalize_symbol(ticker)
|
||||||
try:
|
try:
|
||||||
ticker_obj = yf.Ticker(ticker.upper())
|
ticker_obj = yf.Ticker(canonical)
|
||||||
|
|
||||||
if freq.lower() == "quarterly":
|
if freq.lower() == "quarterly":
|
||||||
data = yf_retry(lambda: ticker_obj.quarterly_cashflow)
|
data = yf_retry(lambda: ticker_obj.quarterly_cashflow)
|
||||||
@@ -351,17 +391,19 @@ def get_cashflow(
|
|||||||
data = filter_financials_by_date(data, curr_date)
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No cash flow data found for symbol '{ticker}'"
|
raise NoMarketDataError(ticker, canonical, "no cash flow data")
|
||||||
|
|
||||||
# Convert to CSV string for consistency with other functions
|
# Convert to CSV string for consistency with other functions
|
||||||
csv_string = data.to_csv()
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information
|
||||||
header = f"# Cash Flow data for {ticker.upper()} ({freq})\n"
|
header = f"# Cash Flow data for {canonical} ({freq})\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + csv_string
|
return header + csv_string
|
||||||
|
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving cash flow for {ticker}: {str(e)}"
|
return f"Error retrieving cash flow for {ticker}: {str(e)}"
|
||||||
|
|
||||||
@@ -372,8 +414,9 @@ def get_income_statement(
|
|||||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||||
):
|
):
|
||||||
"""Get income statement data from yfinance."""
|
"""Get income statement data from yfinance."""
|
||||||
|
canonical = normalize_symbol(ticker)
|
||||||
try:
|
try:
|
||||||
ticker_obj = yf.Ticker(ticker.upper())
|
ticker_obj = yf.Ticker(canonical)
|
||||||
|
|
||||||
if freq.lower() == "quarterly":
|
if freq.lower() == "quarterly":
|
||||||
data = yf_retry(lambda: ticker_obj.quarterly_income_stmt)
|
data = yf_retry(lambda: ticker_obj.quarterly_income_stmt)
|
||||||
@@ -383,17 +426,19 @@ def get_income_statement(
|
|||||||
data = filter_financials_by_date(data, curr_date)
|
data = filter_financials_by_date(data, curr_date)
|
||||||
|
|
||||||
if data.empty:
|
if data.empty:
|
||||||
return f"No income statement data found for symbol '{ticker}'"
|
raise NoMarketDataError(ticker, canonical, "no income statement data")
|
||||||
|
|
||||||
# Convert to CSV string for consistency with other functions
|
# Convert to CSV string for consistency with other functions
|
||||||
csv_string = data.to_csv()
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information
|
||||||
header = f"# Income Statement data for {ticker.upper()} ({freq})\n"
|
header = f"# Income Statement data for {canonical} ({freq})\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + csv_string
|
return header + csv_string
|
||||||
|
|
||||||
|
except NoMarketDataError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving income statement for {ticker}: {str(e)}"
|
return f"Error retrieving income statement for {ticker}: {str(e)}"
|
||||||
|
|
||||||
@@ -402,21 +447,24 @@ def get_insider_transactions(
|
|||||||
ticker: Annotated[str, "ticker symbol of the company"]
|
ticker: Annotated[str, "ticker symbol of the company"]
|
||||||
):
|
):
|
||||||
"""Get insider transactions data from yfinance."""
|
"""Get insider transactions data from yfinance."""
|
||||||
|
canonical = normalize_symbol(ticker)
|
||||||
try:
|
try:
|
||||||
ticker_obj = yf.Ticker(ticker.upper())
|
ticker_obj = yf.Ticker(canonical)
|
||||||
data = yf_retry(lambda: ticker_obj.insider_transactions)
|
data = yf_retry(lambda: ticker_obj.insider_transactions)
|
||||||
|
|
||||||
|
# Empty is normal here (many valid symbols have no insider filings),
|
||||||
|
# so report it plainly rather than treating the symbol as invalid.
|
||||||
if data is None or data.empty:
|
if data is None or data.empty:
|
||||||
return f"No insider transactions data found for symbol '{ticker}'"
|
return f"No insider transactions reported for symbol '{canonical}'"
|
||||||
|
|
||||||
# Convert to CSV string for consistency with other functions
|
# Convert to CSV string for consistency with other functions
|
||||||
csv_string = data.to_csv()
|
csv_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information
|
||||||
header = f"# Insider Transactions data for {ticker.upper()}\n"
|
header = f"# Insider Transactions data for {canonical}\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + csv_string
|
return header + csv_string
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user