mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
Merge remote-tracking branch 'upstream/main' into analyst-phase1-observability
# Conflicts: # tradingagents/default_config.py # tradingagents/graph/setup.py
This commit is contained in:
157
cli/main.py
157
cli/main.py
@@ -1,13 +1,10 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import typer
|
||||
import questionary
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
@@ -58,7 +55,7 @@ class MessageBuffer:
|
||||
# Analyst name mapping
|
||||
ANALYST_MAPPING = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"social": "Sentiment Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
@@ -68,7 +65,7 @@ class MessageBuffer:
|
||||
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||
REPORT_SECTIONS = {
|
||||
"market_report": ("market", "Market Analyst"),
|
||||
"sentiment_report": ("social", "Social Analyst"),
|
||||
"sentiment_report": ("social", "Sentiment Analyst"),
|
||||
"news_report": ("news", "News Analyst"),
|
||||
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||
"investment_plan": (None, "Research Manager"),
|
||||
@@ -85,7 +82,7 @@ class MessageBuffer:
|
||||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids = set()
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
@@ -120,7 +117,7 @@ class MessageBuffer:
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
@@ -289,7 +286,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
||||
all_teams = {
|
||||
"Analyst Team": [
|
||||
"Market Analyst",
|
||||
"Social Analyst",
|
||||
"Sentiment Analyst",
|
||||
"News Analyst",
|
||||
"Fundamentals Analyst",
|
||||
],
|
||||
@@ -468,7 +465,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
||||
def get_user_selections():
|
||||
"""Get all user selections before starting the analysis display."""
|
||||
# Display ASCII art welcome message
|
||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r") as f:
|
||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
# Create welcome box content
|
||||
@@ -561,6 +558,21 @@ def get_user_selections():
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
# Providers with regional endpoints prompt for the region as a secondary
|
||||
# step so the main dropdown stays clean (mainland China and international
|
||||
# accounts cannot share API keys).
|
||||
if selected_llm_provider == "qwen":
|
||||
selected_llm_provider, backend_url = ask_qwen_region()
|
||||
elif selected_llm_provider == "minimax":
|
||||
selected_llm_provider, backend_url = ask_minimax_region()
|
||||
elif selected_llm_provider == "glm":
|
||||
selected_llm_provider, backend_url = ask_glm_region()
|
||||
|
||||
# Confirm the provider's API key is present; prompt the user to paste
|
||||
# one and persist it to .env if it's missing, so the analysis run
|
||||
# doesn't fail later at the first API call.
|
||||
ensure_api_key(selected_llm_provider)
|
||||
|
||||
# Step 7: Thinking agents
|
||||
console.print(
|
||||
create_question_box(
|
||||
@@ -618,8 +630,26 @@ def get_user_selections():
|
||||
|
||||
|
||||
def get_ticker():
|
||||
"""Get ticker symbol from user input."""
|
||||
return typer.prompt("", default="SPY")
|
||||
"""Get ticker symbol from user input, preserving exchange suffixes."""
|
||||
# typer.prompt strips trailing dot-suffixes on some shells (e.g. 000404.SH
|
||||
# collapses to 000404). questionary.text reads the raw line.
|
||||
ticker = questionary.text(
|
||||
"",
|
||||
validate=lambda value: (
|
||||
not value.strip()
|
||||
or (
|
||||
all(ch.isalnum() or ch in "._-^" for ch in value.strip())
|
||||
and len(value.strip()) <= 32
|
||||
)
|
||||
)
|
||||
or "Please enter a valid ticker symbol, e.g. AAPL, 000404.SZ, 0700.HK.",
|
||||
).ask()
|
||||
|
||||
if ticker is None:
|
||||
console.print("\n[red]No ticker symbol provided. Exiting...[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return (ticker.strip() or "SPY").upper()
|
||||
|
||||
|
||||
def get_analysis_date():
|
||||
@@ -651,19 +681,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
analyst_parts = []
|
||||
if final_state.get("market_report"):
|
||||
analysts_dir.mkdir(exist_ok=True)
|
||||
(analysts_dir / "market.md").write_text(final_state["market_report"])
|
||||
(analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8")
|
||||
analyst_parts.append(("Market Analyst", final_state["market_report"]))
|
||||
if final_state.get("sentiment_report"):
|
||||
analysts_dir.mkdir(exist_ok=True)
|
||||
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"])
|
||||
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
|
||||
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
|
||||
analyst_parts.append(("Sentiment Analyst", final_state["sentiment_report"]))
|
||||
if final_state.get("news_report"):
|
||||
analysts_dir.mkdir(exist_ok=True)
|
||||
(analysts_dir / "news.md").write_text(final_state["news_report"])
|
||||
(analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
|
||||
analyst_parts.append(("News Analyst", final_state["news_report"]))
|
||||
if final_state.get("fundamentals_report"):
|
||||
analysts_dir.mkdir(exist_ok=True)
|
||||
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
|
||||
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8")
|
||||
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
|
||||
if analyst_parts:
|
||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
|
||||
@@ -676,15 +706,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
research_parts = []
|
||||
if debate.get("bull_history"):
|
||||
research_dir.mkdir(exist_ok=True)
|
||||
(research_dir / "bull.md").write_text(debate["bull_history"])
|
||||
(research_dir / "bull.md").write_text(debate["bull_history"], encoding="utf-8")
|
||||
research_parts.append(("Bull Researcher", debate["bull_history"]))
|
||||
if debate.get("bear_history"):
|
||||
research_dir.mkdir(exist_ok=True)
|
||||
(research_dir / "bear.md").write_text(debate["bear_history"])
|
||||
(research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8")
|
||||
research_parts.append(("Bear Researcher", debate["bear_history"]))
|
||||
if debate.get("judge_decision"):
|
||||
research_dir.mkdir(exist_ok=True)
|
||||
(research_dir / "manager.md").write_text(debate["judge_decision"])
|
||||
(research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8")
|
||||
research_parts.append(("Research Manager", debate["judge_decision"]))
|
||||
if research_parts:
|
||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
||||
@@ -694,7 +724,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
if final_state.get("trader_investment_plan"):
|
||||
trading_dir = save_path / "3_trading"
|
||||
trading_dir.mkdir(exist_ok=True)
|
||||
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
|
||||
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8")
|
||||
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
|
||||
|
||||
# 4. Risk Management
|
||||
@@ -704,15 +734,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
risk_parts = []
|
||||
if risk.get("aggressive_history"):
|
||||
risk_dir.mkdir(exist_ok=True)
|
||||
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"])
|
||||
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8")
|
||||
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
|
||||
if risk.get("conservative_history"):
|
||||
risk_dir.mkdir(exist_ok=True)
|
||||
(risk_dir / "conservative.md").write_text(risk["conservative_history"])
|
||||
(risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8")
|
||||
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
|
||||
if risk.get("neutral_history"):
|
||||
risk_dir.mkdir(exist_ok=True)
|
||||
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
|
||||
(risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8")
|
||||
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
||||
if risk_parts:
|
||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
||||
@@ -722,12 +752,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||
if risk.get("judge_decision"):
|
||||
portfolio_dir = save_path / "5_portfolio"
|
||||
portfolio_dir.mkdir(exist_ok=True)
|
||||
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
|
||||
(portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8")
|
||||
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
|
||||
|
||||
# Write consolidated report
|
||||
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
|
||||
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
|
||||
return save_path / "complete_report.md"
|
||||
|
||||
|
||||
@@ -741,7 +771,7 @@ def display_complete_report(final_state):
|
||||
if final_state.get("market_report"):
|
||||
analysts.append(("Market Analyst", final_state["market_report"]))
|
||||
if final_state.get("sentiment_report"):
|
||||
analysts.append(("Social Analyst", final_state["sentiment_report"]))
|
||||
analysts.append(("Sentiment Analyst", final_state["sentiment_report"]))
|
||||
if final_state.get("news_report"):
|
||||
analysts.append(("News Analyst", final_state["news_report"]))
|
||||
if final_state.get("fundamentals_report"):
|
||||
@@ -803,7 +833,7 @@ def update_research_team_status(status):
|
||||
ANALYST_ORDER = ["market", "social", "news", "fundamentals"]
|
||||
ANALYST_AGENT_NAMES = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"social": "Sentiment Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
@@ -934,7 +964,7 @@ def format_tool_args(args, max_length=80) -> str:
|
||||
return result[:max_length - 3] + "..."
|
||||
return result
|
||||
|
||||
def run_analysis():
|
||||
def run_analysis(checkpoint: bool = False):
|
||||
# First get all user selections
|
||||
selections = get_user_selections()
|
||||
|
||||
@@ -951,6 +981,7 @@ def run_analysis():
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||
config["output_language"] = selections.get("output_language", "English")
|
||||
config["checkpoint_enabled"] = checkpoint
|
||||
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
@@ -993,7 +1024,7 @@ def run_analysis():
|
||||
func(*args, **kwargs)
|
||||
timestamp, message_type, content = obj.messages[-1]
|
||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
||||
with open(log_file, "a") as f:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||
return wrapper
|
||||
|
||||
@@ -1004,7 +1035,7 @@ def run_analysis():
|
||||
func(*args, **kwargs)
|
||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||
with open(log_file, "a") as f:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||
return wrapper
|
||||
|
||||
@@ -1018,7 +1049,7 @@ def run_analysis():
|
||||
if content:
|
||||
file_name = f"{section_name}.md"
|
||||
text = "\n".join(str(item) for item in content) if isinstance(content, list) else content
|
||||
with open(report_dir / file_name, "w") as f:
|
||||
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
return wrapper
|
||||
|
||||
@@ -1067,28 +1098,24 @@ def run_analysis():
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process messages if present (skip duplicates via message ID)
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
msg_id = getattr(last_message, "id", None)
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
|
||||
if msg_id != message_buffer._last_message_id:
|
||||
message_buffer._last_message_id = msg_id
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Add message to buffer
|
||||
msg_type, content = classify_message_type(last_message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Handle tool calls
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(
|
||||
@@ -1173,8 +1200,11 @@ def run_analysis():
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
# Get final state and decision
|
||||
final_state = trace[-1]
|
||||
# Streamed chunks are per-node deltas, not full state. Merge them
|
||||
# so every report field populated across the run is present.
|
||||
final_state = {}
|
||||
for chunk in trace:
|
||||
final_state.update(chunk)
|
||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
# Update all agent statuses to completed
|
||||
@@ -1221,8 +1251,23 @@ def run_analysis():
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze():
|
||||
run_analysis()
|
||||
def analyze(
|
||||
checkpoint: bool = typer.Option(
|
||||
False,
|
||||
"--checkpoint",
|
||||
help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
|
||||
),
|
||||
clear_checkpoints: bool = typer.Option(
|
||||
False,
|
||||
"--clear-checkpoints",
|
||||
help="Delete all saved checkpoints before running (force fresh start).",
|
||||
),
|
||||
):
|
||||
if clear_checkpoints:
|
||||
from tradingagents.graph.checkpointer import clear_all_checkpoints
|
||||
n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
|
||||
console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
|
||||
run_analysis(checkpoint=checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,6 +5,8 @@ from pydantic import BaseModel
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
# Wire value stays "social" for saved-config and string-keyed-caller
|
||||
# back-compat; the user-facing label is "Sentiment Analyst".
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
||||
273
cli/utils.py
273
cli/utils.py
@@ -1,9 +1,13 @@
|
||||
import questionary
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
import questionary
|
||||
from dotenv import find_dotenv, set_key
|
||||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||
|
||||
console = Console()
|
||||
@@ -12,7 +16,7 @@ TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
||||
|
||||
ANALYST_ORDER = [
|
||||
("Market Analyst", AnalystType.MARKET),
|
||||
("Social Media Analyst", AnalystType.SOCIAL),
|
||||
("Sentiment Analyst", AnalystType.SOCIAL),
|
||||
("News Analyst", AnalystType.NEWS),
|
||||
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
||||
]
|
||||
@@ -134,14 +138,70 @@ def select_research_depth() -> int:
|
||||
return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
||||
"""Fetch available models from the OpenRouter API."""
|
||||
import requests
|
||||
try:
|
||||
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
|
||||
resp.raise_for_status()
|
||||
models = resp.json().get("data", [])
|
||||
return [(m.get("name") or m["id"], m["id"]) for m in models]
|
||||
except Exception as e:
|
||||
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
|
||||
return []
|
||||
|
||||
|
||||
def select_openrouter_model() -> str:
|
||||
"""Select an OpenRouter model from the newest available, or enter a custom ID."""
|
||||
models = _fetch_openrouter_models()
|
||||
|
||||
choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
|
||||
choices.append(questionary.Choice("Custom model ID", value="custom"))
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Quick-Thinking LLM Engine]:",
|
||||
"Select OpenRouter Model (latest available):",
|
||||
choices=choices,
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style([
|
||||
("selected", "fg:magenta noinherit"),
|
||||
("highlighted", "fg:magenta noinherit"),
|
||||
("pointer", "fg:magenta noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
if choice is None or choice == "custom":
|
||||
return questionary.text(
|
||||
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||
).ask().strip()
|
||||
|
||||
return choice
|
||||
|
||||
|
||||
def _prompt_custom_model_id() -> str:
|
||||
"""Prompt user to type a custom model ID."""
|
||||
return questionary.text(
|
||||
"Enter model ID:",
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||
).ask().strip()
|
||||
|
||||
|
||||
def _select_model(provider: str, mode: str) -> str:
|
||||
"""Select a model for the given provider and mode (quick/deep)."""
|
||||
if provider.lower() == "openrouter":
|
||||
return select_openrouter_model()
|
||||
|
||||
if provider.lower() == "azure":
|
||||
return questionary.text(
|
||||
f"Enter Azure deployment name ({mode}-thinking):",
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
||||
).ask().strip()
|
||||
|
||||
choice = questionary.select(
|
||||
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in get_model_options(provider, "quick")
|
||||
for display, value in get_model_options(provider, mode)
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
@@ -154,56 +214,46 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print(
|
||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
||||
)
|
||||
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
if choice == "custom":
|
||||
return _prompt_custom_model_id()
|
||||
|
||||
return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
return _select_model(provider, "quick")
|
||||
|
||||
|
||||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
return _select_model(provider, "deep")
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in get_model_options(provider, "deep")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:magenta noinherit"),
|
||||
("highlighted", "fg:magenta noinherit"),
|
||||
("pointer", "fg:magenta noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
return choice
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
BASE_URLS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "https://api.x.ai/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
def select_llm_provider() -> tuple[str, str | None]:
|
||||
"""Select the LLM provider and its API endpoint."""
|
||||
# (display_name, provider_key, base_url)
|
||||
PROVIDERS = [
|
||||
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||
("Google", "google", None),
|
||||
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "xai", "https://api.x.ai/v1"),
|
||||
("DeepSeek", "deepseek", "https://api.deepseek.com"),
|
||||
("Qwen", "qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
|
||||
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Azure OpenAI", "azure", None),
|
||||
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
||||
|
||||
choice = questionary.select(
|
||||
"Select your LLM Provider:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=(display, value))
|
||||
for display, value in BASE_URLS
|
||||
questionary.Choice(display, value=(provider_key, url))
|
||||
for display, provider_key, url in PROVIDERS
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
@@ -216,13 +266,11 @@ def select_llm_provider() -> tuple[str, str]:
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
display_name, url = choice
|
||||
print(f"You selected: {display_name}\tURL: {url}")
|
||||
|
||||
return display_name, url
|
||||
provider, url = choice
|
||||
return provider, url
|
||||
|
||||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
@@ -246,7 +294,9 @@ def ask_openai_reasoning_effort() -> str:
|
||||
def ask_anthropic_effort() -> str | None:
|
||||
"""Ask for Anthropic effort level.
|
||||
|
||||
Controls token usage and response thoroughness on Claude 4.5+ and 4.6 models.
|
||||
Controls token usage and response thoroughness on Claude 4.5 / 4.6 / 4.7
|
||||
models. The API also accepts "max"; we expose low/medium/high as the
|
||||
common selection range.
|
||||
"""
|
||||
return questionary.select(
|
||||
"Select Effort Level:",
|
||||
@@ -283,6 +333,129 @@ def ask_gemini_thinking_config() -> str | None:
|
||||
).ask()
|
||||
|
||||
|
||||
def ask_glm_region() -> tuple[str, str]:
|
||||
"""Ask which GLM platform (Z.AI international vs BigModel China) to use.
|
||||
|
||||
Zhipu serves the same GLM models under two brands with separate
|
||||
accounts; keys aren't interchangeable. Returns (provider_key, backend_url).
|
||||
"""
|
||||
return questionary.select(
|
||||
"Select GLM platform:",
|
||||
choices=[
|
||||
questionary.Choice(
|
||||
"Z.AI — api.z.ai (international, uses ZHIPU_API_KEY)",
|
||||
value=("glm", "https://api.z.ai/api/paas/v4/"),
|
||||
),
|
||||
questionary.Choice(
|
||||
"BigModel — open.bigmodel.cn (China, uses ZHIPU_CN_API_KEY)",
|
||||
value=("glm-cn", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||
),
|
||||
],
|
||||
style=questionary.Style([
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def ask_qwen_region() -> tuple[str, str]:
|
||||
"""Ask which Qwen region (international vs China) to use.
|
||||
|
||||
Alibaba DashScope exposes two endpoints with separate accounts —
|
||||
a key from one region does NOT authenticate against the other
|
||||
(fixes #758). Returns (provider_key, backend_url).
|
||||
"""
|
||||
return questionary.select(
|
||||
"Select Qwen region:",
|
||||
choices=[
|
||||
questionary.Choice(
|
||||
"International — dashscope-intl.aliyuncs.com (uses DASHSCOPE_API_KEY)",
|
||||
value=("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
|
||||
),
|
||||
questionary.Choice(
|
||||
"China — dashscope.aliyuncs.com (uses DASHSCOPE_CN_API_KEY)",
|
||||
value=("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
),
|
||||
],
|
||||
style=questionary.Style([
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def ask_minimax_region() -> tuple[str, str]:
|
||||
"""Ask which MiniMax region (global vs China) to use.
|
||||
|
||||
MiniMax exposes two endpoints with separate accounts — a key from
|
||||
one region does NOT authenticate against the other. Returns
|
||||
(provider_key, backend_url).
|
||||
"""
|
||||
return questionary.select(
|
||||
"Select MiniMax region:",
|
||||
choices=[
|
||||
questionary.Choice(
|
||||
"Global — api.minimax.io (uses MINIMAX_API_KEY)",
|
||||
value=("minimax", "https://api.minimax.io/v1"),
|
||||
),
|
||||
questionary.Choice(
|
||||
"China — api.minimaxi.com (uses MINIMAX_CN_API_KEY)",
|
||||
value=("minimax-cn", "https://api.minimaxi.com/v1"),
|
||||
),
|
||||
],
|
||||
style=questionary.Style([
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def ensure_api_key(provider: str) -> Optional[str]:
|
||||
"""Make sure the API key for `provider` is available in the environment.
|
||||
|
||||
If the env var is already set, returns its value untouched. Otherwise
|
||||
interactively prompts the user, persists the value to the project's
|
||||
.env file via python-dotenv's set_key (creating .env if needed), and
|
||||
exports it into os.environ so the current process picks it up.
|
||||
|
||||
Returns None for providers that do not require a key (e.g. ollama)
|
||||
and for providers not found in the canonical mapping.
|
||||
"""
|
||||
env_var = get_api_key_env(provider)
|
||||
if env_var is None:
|
||||
return None # ollama / unknown — no key check possible
|
||||
|
||||
existing = os.environ.get(env_var)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
console.print(
|
||||
f"\n[yellow]{env_var} is not set in your environment.[/yellow]"
|
||||
)
|
||||
key = questionary.password(
|
||||
f"Paste your {env_var} (will be saved to .env):",
|
||||
style=questionary.Style([
|
||||
("text", "fg:cyan"),
|
||||
("highlighted", "noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
if not key:
|
||||
console.print(
|
||||
f"[red]Skipped. API calls will fail until {env_var} is set.[/red]"
|
||||
)
|
||||
return None
|
||||
|
||||
env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env")
|
||||
Path(env_path).touch(exist_ok=True)
|
||||
set_key(env_path, env_var, key)
|
||||
os.environ[env_var] = key
|
||||
console.print(f"[green]Saved {env_var} to {env_path}[/green]")
|
||||
return key
|
||||
|
||||
|
||||
def ask_output_language() -> str:
|
||||
"""Ask for report output language."""
|
||||
choice = questionary.select(
|
||||
|
||||
Reference in New Issue
Block a user