mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-01 14:33:10 +03:00
feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion)
This commit is contained in:
255
cli/main.py
255
cli/main.py
@@ -15,7 +15,6 @@ from rich.columns import Columns
|
||||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.live import Live
|
||||
from rich.table import Table
|
||||
from collections import deque
|
||||
import time
|
||||
@@ -29,6 +28,7 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.announcements import fetch_announcements, display_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -41,40 +41,99 @@ app = typer.Typer(
|
||||
|
||||
# Create a deque to store recent messages with a maximum length
|
||||
class MessageBuffer:
|
||||
# Fixed teams that always run (not user-selectable)
|
||||
FIXED_AGENTS = {
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Analyst name mapping
|
||||
ANALYST_MAPPING = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
|
||||
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
|
||||
# analyst_key: which analyst selection controls this section (None = always included)
|
||||
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||
REPORT_SECTIONS = {
|
||||
"market_report": ("market", "Market Analyst"),
|
||||
"sentiment_report": ("social", "Social Analyst"),
|
||||
"news_report": ("news", "News Analyst"),
|
||||
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||
"investment_plan": (None, "Research Manager"),
|
||||
"trader_investment_plan": (None, "Trader"),
|
||||
"final_trade_decision": (None, "Portfolio Manager"),
|
||||
}
|
||||
|
||||
def __init__(self, max_length=100):
|
||||
self.messages = deque(maxlen=max_length)
|
||||
self.tool_calls = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None # Store the complete final report
|
||||
self.agent_status = {
|
||||
# Analyst Team
|
||||
"Market Analyst": "pending",
|
||||
"Social Analyst": "pending",
|
||||
"News Analyst": "pending",
|
||||
"Fundamentals Analyst": "pending",
|
||||
# Research Team
|
||||
"Bull Researcher": "pending",
|
||||
"Bear Researcher": "pending",
|
||||
"Research Manager": "pending",
|
||||
# Trading Team
|
||||
"Trader": "pending",
|
||||
# Risk Management Team
|
||||
"Aggressive Analyst": "pending",
|
||||
"Neutral Analyst": "pending",
|
||||
"Conservative Analyst": "pending",
|
||||
# Portfolio Management Team
|
||||
"Portfolio Manager": "pending",
|
||||
}
|
||||
self.agent_status = {}
|
||||
self.current_agent = None
|
||||
self.report_sections = {
|
||||
"market_report": None,
|
||||
"sentiment_report": None,
|
||||
"news_report": None,
|
||||
"fundamentals_report": None,
|
||||
"investment_plan": None,
|
||||
"trader_investment_plan": None,
|
||||
"final_trade_decision": None,
|
||||
}
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
|
||||
"""
|
||||
self.selected_analysts = [a.lower() for a in selected_analysts]
|
||||
|
||||
# Build agent_status dynamically
|
||||
self.agent_status = {}
|
||||
|
||||
# Add selected analysts
|
||||
for analyst_key in self.selected_analysts:
|
||||
if analyst_key in self.ANALYST_MAPPING:
|
||||
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
|
||||
|
||||
# Add fixed teams
|
||||
for team_agents in self.FIXED_AGENTS.values():
|
||||
for agent in team_agents:
|
||||
self.agent_status[agent] = "pending"
|
||||
|
||||
# Build report_sections dynamically
|
||||
self.report_sections = {}
|
||||
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
|
||||
if analyst_key is None or analyst_key in self.selected_analysts:
|
||||
self.report_sections[section] = None
|
||||
|
||||
# Reset other state
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
||||
A report is considered complete when:
|
||||
1. The report section has content (not None), AND
|
||||
2. The agent responsible for finalizing that report has status "completed"
|
||||
|
||||
This prevents interim updates (like debate rounds) from counting as completed.
|
||||
"""
|
||||
count = 0
|
||||
for section in self.report_sections:
|
||||
if section not in self.REPORT_SECTIONS:
|
||||
continue
|
||||
_, finalizing_agent = self.REPORT_SECTIONS[section]
|
||||
# Report is complete if it has content AND its finalizing agent is done
|
||||
has_content = self.report_sections.get(section) is not None
|
||||
agent_done = self.agent_status.get(finalizing_agent) == "completed"
|
||||
if has_content and agent_done:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def add_message(self, message_type, content):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
@@ -126,46 +185,39 @@ class MessageBuffer:
|
||||
def _update_final_report(self):
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports
|
||||
if any(
|
||||
self.report_sections[section]
|
||||
for section in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]
|
||||
):
|
||||
# Analyst Team Reports - use .get() to handle missing sections
|
||||
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
|
||||
if any(self.report_sections.get(section) for section in analyst_sections):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections["market_report"]:
|
||||
if self.report_sections.get("market_report"):
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
if self.report_sections.get("sentiment_report"):
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
if self.report_sections.get("news_report"):
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
if self.report_sections.get("fundamentals_report"):
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections["investment_plan"]:
|
||||
if self.report_sections.get("investment_plan"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections["trader_investment_plan"]:
|
||||
if self.report_sections.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections["final_trade_decision"]:
|
||||
if self.report_sections.get("final_trade_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
@@ -191,7 +243,14 @@ def create_layout():
|
||||
return layout
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None):
|
||||
def format_tokens(n):
|
||||
"""Format token count for display."""
|
||||
if n >= 1000:
|
||||
return f"{n/1000:.1f}k"
|
||||
return str(n)
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
|
||||
# Header with welcome message
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
@@ -218,8 +277,8 @@ def update_display(layout, spinner_text=None):
|
||||
progress_table.add_column("Agent", style="green", justify="center", width=20)
|
||||
progress_table.add_column("Status", style="yellow", justify="center", width=20)
|
||||
|
||||
# Group agents by team
|
||||
teams = {
|
||||
# Group agents by team - filter to only include agents in agent_status
|
||||
all_teams = {
|
||||
"Analyst Team": [
|
||||
"Market Analyst",
|
||||
"Social Analyst",
|
||||
@@ -232,10 +291,17 @@ def update_display(layout, spinner_text=None):
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Filter teams to only include agents that are in agent_status
|
||||
teams = {}
|
||||
for team, agents in all_teams.items():
|
||||
active_agents = [a for a in agents if a in message_buffer.agent_status]
|
||||
if active_agents:
|
||||
teams[team] = active_agents
|
||||
|
||||
for team, agents in teams.items():
|
||||
# Add first agent with team name
|
||||
first_agent = agents[0]
|
||||
status = message_buffer.agent_status[first_agent]
|
||||
status = message_buffer.agent_status.get(first_agent, "pending")
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
@@ -252,7 +318,7 @@ def update_display(layout, spinner_text=None):
|
||||
|
||||
# Add remaining agents in team
|
||||
for agent in agents[1:]:
|
||||
status = message_buffer.agent_status[agent]
|
||||
status = message_buffer.agent_status.get(agent, "pending")
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
@@ -379,19 +445,43 @@ def update_display(layout, spinner_text=None):
|
||||
)
|
||||
|
||||
# Footer with statistics
|
||||
tool_calls_count = len(message_buffer.tool_calls)
|
||||
llm_calls_count = sum(
|
||||
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
|
||||
)
|
||||
reports_count = sum(
|
||||
1 for content in message_buffer.report_sections.values() if content is not None
|
||||
# Agent progress - derived from agent_status dict
|
||||
agents_completed = sum(
|
||||
1 for status in message_buffer.agent_status.values() if status == "completed"
|
||||
)
|
||||
agents_total = len(message_buffer.agent_status)
|
||||
|
||||
# Report progress - based on agent completion (not just content existence)
|
||||
reports_completed = message_buffer.get_completed_reports_count()
|
||||
reports_total = len(message_buffer.report_sections)
|
||||
|
||||
# Build stats parts
|
||||
stats_parts = [f"Agents: {agents_completed}/{agents_total}"]
|
||||
|
||||
# LLM and tool stats from callback handler
|
||||
if stats_handler:
|
||||
stats = stats_handler.get_stats()
|
||||
stats_parts.append(f"LLM: {stats['llm_calls']}")
|
||||
stats_parts.append(f"Tools: {stats['tool_calls']}")
|
||||
|
||||
# Token display with graceful fallback
|
||||
if stats["tokens_in"] > 0 or stats["tokens_out"] > 0:
|
||||
tokens_str = f"Tokens: {format_tokens(stats['tokens_in'])}\u2191 {format_tokens(stats['tokens_out'])}\u2193"
|
||||
else:
|
||||
tokens_str = "Tokens: --"
|
||||
stats_parts.append(tokens_str)
|
||||
|
||||
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
|
||||
|
||||
# Elapsed time
|
||||
if start_time:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed_str = f"\u23f1 {int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
|
||||
stats_parts.append(elapsed_str)
|
||||
|
||||
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
||||
stats_table.add_column("Stats", justify="center")
|
||||
stats_table.add_row(
|
||||
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
|
||||
)
|
||||
stats_table.add_row(" | ".join(stats_parts))
|
||||
|
||||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||
|
||||
@@ -803,11 +893,24 @@ def run_analysis():
|
||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
|
||||
# Initialize the graph
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
|
||||
# Initialize the graph with callbacks bound to LLMs
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||
[analyst.value for analyst in selections["analysts"]],
|
||||
config=config,
|
||||
debug=True,
|
||||
callbacks=[stats_handler],
|
||||
)
|
||||
|
||||
# Initialize message buffer with selected analysts
|
||||
selected_analyst_keys = [analyst.value for analyst in selections["analysts"]]
|
||||
message_buffer.init_for_analysis(selected_analyst_keys)
|
||||
|
||||
# Track start time for elapsed display
|
||||
start_time = time.time()
|
||||
|
||||
# Create result directory
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -860,7 +963,7 @@ def run_analysis():
|
||||
|
||||
with Live(layout, refresh_per_second=4) as live:
|
||||
# Initial display
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Add initial messages
|
||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||
@@ -871,34 +974,26 @@ def run_analysis():
|
||||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||
)
|
||||
update_display(layout)
|
||||
|
||||
# Reset agent statuses
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "pending")
|
||||
|
||||
# Reset report sections
|
||||
for section in message_buffer.report_sections:
|
||||
message_buffer.report_sections[section] = None
|
||||
message_buffer.current_report = None
|
||||
message_buffer.final_report = None
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Update agent status to in_progress for the first analyst
|
||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Create spinner text
|
||||
spinner_text = (
|
||||
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
||||
)
|
||||
update_display(layout, spinner_text)
|
||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Initialize state and get graph args
|
||||
# Initialize state and get graph args with callbacks
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
# Pass callbacks to graph config for tool execution tracking
|
||||
# (LLM tracking is handled separately via LLM constructor)
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
@@ -1112,7 +1207,7 @@ def run_analysis():
|
||||
)
|
||||
|
||||
# Update the display
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
@@ -1136,7 +1231,7 @@ def run_analysis():
|
||||
# Display the complete final report
|
||||
display_complete_report(final_state)
|
||||
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
Reference in New Issue
Block a user