Merge remote-tracking branch 'upstream/main' into analyst-phase1-observability

# Conflicts:
#	tradingagents/default_config.py
#	tradingagents/graph/setup.py
This commit is contained in:
CadeYu
2026-05-11 16:44:00 +08:00
71 changed files with 5225 additions and 760 deletions

View File

@@ -1,13 +1,10 @@
from typing import Optional
import datetime
import typer
import questionary
from pathlib import Path
from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
@@ -58,7 +55,7 @@ class MessageBuffer:
# Analyst name mapping
ANALYST_MAPPING = {
"market": "Market Analyst",
"social": "Social Analyst",
"social": "Sentiment Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
@@ -68,7 +65,7 @@ class MessageBuffer:
# finalizing_agent: which agent must be "completed" for this report to count as done
REPORT_SECTIONS = {
"market_report": ("market", "Market Analyst"),
"sentiment_report": ("social", "Social Analyst"),
"sentiment_report": ("social", "Sentiment Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"investment_plan": (None, "Research Manager"),
@@ -85,7 +82,7 @@ class MessageBuffer:
self.current_agent = None
self.report_sections = {}
self.selected_analysts = []
self._last_message_id = None
self._processed_message_ids = set()
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
@@ -120,7 +117,7 @@ class MessageBuffer:
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
self._last_message_id = None
self._processed_message_ids.clear()
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
@@ -289,7 +286,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
all_teams = {
"Analyst Team": [
"Market Analyst",
"Social Analyst",
"Sentiment Analyst",
"News Analyst",
"Fundamentals Analyst",
],
@@ -468,7 +465,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
def get_user_selections():
"""Get all user selections before starting the analysis display."""
# Display ASCII art welcome message
with open(Path(__file__).parent / "static" / "welcome.txt", "r") as f:
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
welcome_ascii = f.read()
# Create welcome box content
@@ -561,6 +558,21 @@ def get_user_selections():
)
selected_llm_provider, backend_url = select_llm_provider()
# Providers with regional endpoints prompt for the region as a secondary
# step so the main dropdown stays clean (mainland China and international
# accounts cannot share API keys).
if selected_llm_provider == "qwen":
selected_llm_provider, backend_url = ask_qwen_region()
elif selected_llm_provider == "minimax":
selected_llm_provider, backend_url = ask_minimax_region()
elif selected_llm_provider == "glm":
selected_llm_provider, backend_url = ask_glm_region()
# Confirm the provider's API key is present; prompt the user to paste
# one and persist it to .env if it's missing, so the analysis run
# doesn't fail later at the first API call.
ensure_api_key(selected_llm_provider)
# Step 7: Thinking agents
console.print(
create_question_box(
@@ -618,8 +630,26 @@ def get_user_selections():
def get_ticker():
"""Get ticker symbol from user input."""
return typer.prompt("", default="SPY")
"""Get ticker symbol from user input, preserving exchange suffixes."""
# typer.prompt strips trailing dot-suffixes on some shells (e.g. 000404.SH
# collapses to 000404). questionary.text reads the raw line.
ticker = questionary.text(
"",
validate=lambda value: (
not value.strip()
or (
all(ch.isalnum() or ch in "._-^" for ch in value.strip())
and len(value.strip()) <= 32
)
)
or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK.",
).ask()
if ticker is None:
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
raise typer.Exit(1)
return (ticker.strip() or "SPY").upper()
def get_analysis_date():
@@ -651,19 +681,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
analyst_parts = []
if final_state.get("market_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "market.md").write_text(final_state["market_report"])
(analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8")
analyst_parts.append(("Market Analyst", final_state["market_report"]))
if final_state.get("sentiment_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"])
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
analyst_parts.append(("Sentiment Analyst", final_state["sentiment_report"]))
if final_state.get("news_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "news.md").write_text(final_state["news_report"])
(analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
analyst_parts.append(("News Analyst", final_state["news_report"]))
if final_state.get("fundamentals_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8")
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
if analyst_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
@@ -676,15 +706,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
research_parts = []
if debate.get("bull_history"):
research_dir.mkdir(exist_ok=True)
(research_dir / "bull.md").write_text(debate["bull_history"])
(research_dir / "bull.md").write_text(debate["bull_history"], encoding="utf-8")
research_parts.append(("Bull Researcher", debate["bull_history"]))
if debate.get("bear_history"):
research_dir.mkdir(exist_ok=True)
(research_dir / "bear.md").write_text(debate["bear_history"])
(research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8")
research_parts.append(("Bear Researcher", debate["bear_history"]))
if debate.get("judge_decision"):
research_dir.mkdir(exist_ok=True)
(research_dir / "manager.md").write_text(debate["judge_decision"])
(research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8")
research_parts.append(("Research Manager", debate["judge_decision"]))
if research_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
@@ -694,7 +724,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
if final_state.get("trader_investment_plan"):
trading_dir = save_path / "3_trading"
trading_dir.mkdir(exist_ok=True)
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8")
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
# 4. Risk Management
@@ -704,15 +734,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
risk_parts = []
if risk.get("aggressive_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"])
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8")
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
if risk.get("conservative_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "conservative.md").write_text(risk["conservative_history"])
(risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8")
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
if risk.get("neutral_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
(risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8")
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
if risk_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
@@ -722,12 +752,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
if risk.get("judge_decision"):
portfolio_dir = save_path / "5_portfolio"
portfolio_dir.mkdir(exist_ok=True)
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
(portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8")
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
# Write consolidated report
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
return save_path / "complete_report.md"
@@ -741,7 +771,7 @@ def display_complete_report(final_state):
if final_state.get("market_report"):
analysts.append(("Market Analyst", final_state["market_report"]))
if final_state.get("sentiment_report"):
analysts.append(("Social Analyst", final_state["sentiment_report"]))
analysts.append(("Sentiment Analyst", final_state["sentiment_report"]))
if final_state.get("news_report"):
analysts.append(("News Analyst", final_state["news_report"]))
if final_state.get("fundamentals_report"):
@@ -803,7 +833,7 @@ def update_research_team_status(status):
ANALYST_ORDER = ["market", "social", "news", "fundamentals"]
ANALYST_AGENT_NAMES = {
"market": "Market Analyst",
"social": "Social Analyst",
"social": "Sentiment Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
}
@@ -934,7 +964,7 @@ def format_tool_args(args, max_length=80) -> str:
return result[:max_length - 3] + "..."
return result
def run_analysis():
def run_analysis(checkpoint: bool = False):
# First get all user selections
selections = get_user_selections()
@@ -951,6 +981,7 @@ def run_analysis():
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
config["anthropic_effort"] = selections.get("anthropic_effort")
config["output_language"] = selections.get("output_language", "English")
config["checkpoint_enabled"] = checkpoint
# Create stats callback handler for tracking LLM/tool calls
stats_handler = StatsCallbackHandler()
@@ -993,7 +1024,7 @@ def run_analysis():
func(*args, **kwargs)
timestamp, message_type, content = obj.messages[-1]
content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f:
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
@@ -1004,7 +1035,7 @@ def run_analysis():
func(*args, **kwargs)
timestamp, tool_name, args = obj.tool_calls[-1]
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f:
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper
@@ -1018,7 +1049,7 @@ def run_analysis():
if content:
file_name = f"{section_name}.md"
text = "\n".join(str(item) for item in content) if isinstance(content, list) else content
with open(report_dir / file_name, "w") as f:
with open(report_dir / file_name, "w", encoding="utf-8") as f:
f.write(text)
return wrapper
@@ -1067,28 +1098,24 @@ def run_analysis():
# Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
# Process messages if present (skip duplicates via message ID)
if len(chunk["messages"]) > 0:
last_message = chunk["messages"][-1]
msg_id = getattr(last_message, "id", None)
# Process all messages in chunk, deduplicating by message ID
for message in chunk.get("messages", []):
msg_id = getattr(message, "id", None)
if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
if msg_id != message_buffer._last_message_id:
message_buffer._last_message_id = msg_id
msg_type, content = classify_message_type(message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
# Add message to buffer
msg_type, content = classify_message_type(last_message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
# Handle tool calls
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(
@@ -1173,8 +1200,11 @@ def run_analysis():
trace.append(chunk)
# Get final state and decision
final_state = trace[-1]
# Streamed chunks are per-node deltas, not full state. Merge them
# so every report field populated across the run is present.
final_state = {}
for chunk in trace:
final_state.update(chunk)
decision = graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed
@@ -1221,8 +1251,23 @@ def run_analysis():
@app.command()
def analyze():
run_analysis()
def analyze(
checkpoint: bool = typer.Option(
False,
"--checkpoint",
help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
),
clear_checkpoints: bool = typer.Option(
False,
"--clear-checkpoints",
help="Delete all saved checkpoints before running (force fresh start).",
),
):
if clear_checkpoints:
from tradingagents.graph.checkpointer import clear_all_checkpoints
n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
run_analysis(checkpoint=checkpoint)
if __name__ == "__main__":