mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
merge upstream main into crypto-analysis-mvp
This commit is contained in:
117
cli/main.py
117
cli/main.py
@@ -6,8 +6,9 @@ from functools import wraps
|
||||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
load_dotenv(".env.enterprise", override=False)
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
@@ -79,7 +80,7 @@ class MessageBuffer:
|
||||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids = set()
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
@@ -114,7 +115,7 @@ class MessageBuffer:
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
@@ -462,7 +463,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
||||
def get_user_selections():
|
||||
"""Get all user selections before starting the analysis display."""
|
||||
# Display ASCII art welcome message
|
||||
with open("./cli/static/welcome.txt", "r", encoding="utf-8") as f:
|
||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
# Create welcome box content
|
||||
@@ -501,7 +502,9 @@ def get_user_selections():
|
||||
# Step 1: Ticker symbol
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
|
||||
"Step 1: Ticker Symbol",
|
||||
"Enter the exact ticker symbol to analyze, including exchange suffix when needed (examples: SPY, CNC.TO, 7203.T, 0700.HK)",
|
||||
"SPY",
|
||||
)
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
@@ -521,10 +524,19 @@ def get_user_selections():
|
||||
)
|
||||
analysis_date = get_analysis_date()
|
||||
|
||||
# Step 3: Select analysts
|
||||
# Step 3: Output language
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
"Step 3: Output Language",
|
||||
"Select the language for analyst reports and final decision"
|
||||
)
|
||||
)
|
||||
output_language = ask_output_language()
|
||||
|
||||
# Step 4: Select analysts
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts(asset_type)
|
||||
@@ -532,40 +544,41 @@ def get_user_selections():
|
||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||
)
|
||||
|
||||
# Step 4: Research depth
|
||||
# Step 5: Research depth
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Research Depth", "Select your research depth level"
|
||||
"Step 5: Research Depth", "Select your research depth level"
|
||||
)
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
# Step 5: OpenAI backend
|
||||
# Step 6: LLM Provider
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||
"Step 6: LLM Provider", "Select your LLM provider"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
# Step 6: Thinking agents
|
||||
|
||||
# Step 7: Thinking agents
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
||||
"Step 7: Thinking Agents", "Select your thinking agents for analysis"
|
||||
)
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||
|
||||
# Step 7: Provider-specific thinking configuration
|
||||
# Step 8: Provider-specific thinking configuration
|
||||
thinking_level = None
|
||||
reasoning_effort = None
|
||||
anthropic_effort = None
|
||||
|
||||
provider_lower = selected_llm_provider.lower()
|
||||
if provider_lower == "google":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 7: Thinking Mode",
|
||||
"Step 8: Thinking Mode",
|
||||
"Configure Gemini thinking mode"
|
||||
)
|
||||
)
|
||||
@@ -573,11 +586,19 @@ def get_user_selections():
|
||||
elif provider_lower == "openai":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 7: Reasoning Effort",
|
||||
"Step 8: Reasoning Effort",
|
||||
"Configure OpenAI reasoning effort level"
|
||||
)
|
||||
)
|
||||
reasoning_effort = ask_openai_reasoning_effort()
|
||||
elif provider_lower == "anthropic":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 8: Effort Level",
|
||||
"Configure Claude effort level"
|
||||
)
|
||||
)
|
||||
anthropic_effort = ask_anthropic_effort()
|
||||
|
||||
return {
|
||||
"ticker": selected_ticker,
|
||||
@@ -591,6 +612,8 @@ def get_user_selections():
|
||||
"deep_thinker": selected_deep_thinker,
|
||||
"google_thinking_level": thinking_level,
|
||||
"openai_reasoning_effort": reasoning_effort,
|
||||
"anthropic_effort": anthropic_effort,
|
||||
"output_language": output_language,
|
||||
}
|
||||
|
||||
|
||||
@@ -793,9 +816,11 @@ ANALYST_REPORT_MAP = {
|
||||
|
||||
|
||||
def update_analyst_statuses(message_buffer, chunk):
|
||||
"""Update all analyst statuses based on current report state.
|
||||
"""Update analyst statuses based on accumulated report state.
|
||||
|
||||
Logic:
|
||||
- Store new report content from the current chunk if present
|
||||
- Check accumulated report_sections (not just current chunk) for status
|
||||
- Analysts with reports = completed
|
||||
- First analyst without report = in_progress
|
||||
- Remaining analysts without reports = pending
|
||||
@@ -810,11 +835,16 @@ def update_analyst_statuses(message_buffer, chunk):
|
||||
|
||||
agent_name = ANALYST_AGENT_NAMES[analyst_key]
|
||||
report_key = ANALYST_REPORT_MAP[analyst_key]
|
||||
has_report = bool(chunk.get(report_key))
|
||||
|
||||
# Capture new report content from current chunk
|
||||
if chunk.get(report_key):
|
||||
message_buffer.update_report_section(report_key, chunk[report_key])
|
||||
|
||||
# Determine status from accumulated sections, not just current chunk
|
||||
has_report = bool(message_buffer.report_sections.get(report_key))
|
||||
|
||||
if has_report:
|
||||
message_buffer.update_agent_status(agent_name, "completed")
|
||||
message_buffer.update_report_section(report_key, chunk[report_key])
|
||||
elif not found_active:
|
||||
message_buffer.update_agent_status(agent_name, "in_progress")
|
||||
found_active = True
|
||||
@@ -916,6 +946,8 @@ def run_analysis():
|
||||
# Provider-specific thinking configuration
|
||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||
config["output_language"] = selections.get("output_language", "English")
|
||||
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
@@ -953,7 +985,7 @@ def run_analysis():
|
||||
func(*args, **kwargs)
|
||||
timestamp, message_type, content = obj.messages[-1]
|
||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||
return wrapper
|
||||
|
||||
@@ -964,7 +996,7 @@ def run_analysis():
|
||||
func(*args, **kwargs)
|
||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||
return wrapper
|
||||
|
||||
@@ -977,8 +1009,9 @@ def run_analysis():
|
||||
content = obj.report_sections[section_name]
|
||||
if content:
|
||||
file_name = f"{section_name}.md"
|
||||
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
text = "\n".join(str(item) for item in content) if isinstance(content, list) else content
|
||||
with open(report_dir / file_name, "w") as f:
|
||||
f.write(text)
|
||||
return wrapper
|
||||
|
||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||
@@ -1028,28 +1061,24 @@ def run_analysis():
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process messages if present (skip duplicates via message ID)
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
msg_id = getattr(last_message, "id", None)
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
|
||||
if msg_id != message_buffer._last_message_id:
|
||||
message_buffer._last_message_id = msg_id
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Add message to buffer
|
||||
msg_type, content = classify_message_type(last_message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Handle tool calls
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(message_buffer, chunk)
|
||||
|
||||
Reference in New Issue
Block a user