chore(lint): make the repository ruff-clean under the strict select

Clear the deferred full-repo lint backlog so the whole tree passes the strict
ruff select (E,W,F,I,B,UP,C4,SIM). Mechanical fixes dominate: import sorting,
pep585/604 annotations, dropped dead imports, and whitespace. The few semantic
changes are behavior-preserving: declare __all__ on the agent_utils and
alpha_vantage re-export hubs; expand 'from x import *' to explicit names; use
immutable tuple defaults instead of mutable list defaults; contextlib.suppress
for try/except/pass; and narrow an over-broad assertRaises.
This commit is contained in:
Yijia-Xiao
2026-06-14 16:38:36 +00:00
parent cbc5f67d42
commit e3bc872982
59 changed files with 315 additions and 293 deletions

View File

@@ -1,4 +1,5 @@
import getpass
import requests
from rich.console import Console
from rich.panel import Panel

View File

@@ -1,38 +1,53 @@
from typing import Optional
import os
import datetime
import typer
import questionary
from pathlib import Path
from functools import wraps
from rich.console import Console
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.table import Table
from collections import deque
import os
import time
from rich.tree import Tree
from collections import deque
from functools import wraps
from pathlib import Path
import typer
from rich import box
from rich.align import Align
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.rule import Rule
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from tradingagents.graph.trading_graph import TradingAgentsGraph
from cli.announcements import display_announcements, fetch_announcements
from cli.stats_handler import StatsCallbackHandler
from cli.utils import (
ask_anthropic_effort,
ask_gemini_thinking_config,
ask_glm_region,
ask_minimax_region,
ask_openai_reasoning_effort,
ask_output_language,
ask_qwen_region,
confirm_ollama_endpoint,
detect_asset_type,
ensure_api_key,
get_ticker,
prompt_openai_compatible_url,
resolve_backend_url,
select_analysts,
select_deep_thinking_agent,
select_llm_provider,
select_research_depth,
select_shallow_thinking_agent,
)
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.analyst_execution import (
AnalystWallTimeTracker,
build_analyst_execution_plan,
get_initial_analyst_node,
sync_analyst_tracker_from_chunk,
)
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.announcements import fetch_announcements, display_announcements
from cli.stats_handler import StatsCallbackHandler
from tradingagents.graph.trading_graph import TradingAgentsGraph
console = Console()
@@ -466,7 +481,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
def get_user_selections():
"""Get all user selections before starting the analysis display."""
# Display ASCII art welcome message
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
with open(Path(__file__).parent / "static" / "welcome.txt", encoding="utf-8") as f:
welcome_ascii = f.read()
# Create welcome box content
@@ -922,9 +937,12 @@ def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
message_buffer.update_agent_status(agent_name, "pending")
# When all analysts complete, transition research team to in_progress
if not found_active and selected:
if message_buffer.agent_status.get("Bull Researcher") == "pending":
message_buffer.update_agent_status("Bull Researcher", "in_progress")
if (
not found_active
and selected
and message_buffer.agent_status.get("Bull Researcher") == "pending"
):
message_buffer.update_agent_status("Bull Researcher", "in_progress")
def extract_content_string(content):
"""Extract string content from various message formats.
@@ -1097,7 +1115,7 @@ def run_analysis(checkpoint: bool = False):
# Now start the display layout
layout = create_layout()
with Live(layout, refresh_per_second=4) as live:
with Live(layout, refresh_per_second=4):
# Initial display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
@@ -1232,16 +1250,15 @@ def run_analysis(checkpoint: bool = False):
message_buffer.update_report_section(
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
)
if judge:
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
)
message_buffer.update_agent_status("Aggressive Analyst", "completed")
message_buffer.update_agent_status("Conservative Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
if judge and message_buffer.agent_status.get("Portfolio Manager") != "completed":
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
)
message_buffer.update_agent_status("Aggressive Analyst", "completed")
message_buffer.update_agent_status("Conservative Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
# Update the display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
@@ -1253,7 +1270,6 @@ def run_analysis(checkpoint: bool = False):
final_state = {}
for chunk in trace:
final_state.update(chunk)
decision = graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed
for agent in message_buffer.agent_status:
@@ -1265,7 +1281,7 @@ def run_analysis(checkpoint: bool = False):
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
# Update final report sections
for section in message_buffer.report_sections.keys():
for section in message_buffer.report_sections:
if section in final_state:
message_buffer.update_report_section(section, final_state[section])

View File

@@ -1,6 +1,4 @@
from enum import Enum
from typing import List, Optional, Dict
from pydantic import BaseModel
class AnalystType(str, Enum):

View File

@@ -1,9 +1,9 @@
import threading
from typing import Any, Dict, List, Union
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages import AIMessage
from langchain_core.outputs import LLMResult
class StatsCallbackHandler(BaseCallbackHandler):
@@ -19,8 +19,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when an LLM starts."""
@@ -29,8 +29,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[Any]],
serialized: dict[str, Any],
messages: list[list[Any]],
**kwargs: Any,
) -> None:
"""Increment LLM call counter when a chat model starts."""
@@ -57,7 +57,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
@@ -65,7 +65,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
with self._lock:
self.tool_calls += 1
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""Return current statistics."""
with self._lock:
return {

View File

@@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import List, Optional, Tuple, Dict
import questionary
from dotenv import find_dotenv, set_key
@@ -89,8 +88,8 @@ def detect_asset_type(ticker: str) -> AssetType:
def filter_analysts_for_asset_type(
analysts: List[AnalystType], asset_type: AssetType
) -> List[AnalystType]:
analysts: list[AnalystType], asset_type: AssetType
) -> list[AnalystType]:
if asset_type != AssetType.CRYPTO:
return analysts
return [
@@ -133,7 +132,7 @@ def get_analysis_date() -> str:
return date.strip()
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> List[AnalystType]:
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> list[AnalystType]:
"""Select analysts using an interactive checkbox."""
available_analysts = filter_analysts_for_asset_type(
[value for _, value in ANALYST_ORDER],
@@ -197,7 +196,7 @@ def select_research_depth() -> int:
return choice
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
def _fetch_openrouter_models() -> list[tuple[str, str]]:
"""Fetch available models from the OpenRouter API."""
import requests
try:
@@ -556,7 +555,7 @@ def confirm_ollama_endpoint(url: str) -> None:
)
def ensure_api_key(provider: str) -> Optional[str]:
def ensure_api_key(provider: str) -> str | None:
"""Make sure the API key for `provider` is available in the environment.
If the env var is already set, returns its value untouched. Otherwise

View File

@@ -1,5 +1,5 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),

View File

@@ -21,7 +21,6 @@ added, plus the heuristic SignalProcessor.
from __future__ import annotations
import argparse
import os
import sys
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
@@ -30,7 +29,6 @@ from tradingagents.agents.trader.trader import create_trader
from tradingagents.graph.signal_processing import SignalProcessor
from tradingagents.llm_clients import create_llm_client
PROVIDER_DEFAULTS = {
"openai": ("gpt-5.4-mini", None),
"google": ("gemini-2.5-flash", None),

View File

@@ -1,5 +1,8 @@
import time
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
from tradingagents.dataflows.y_finance import (
get_stock_stats_indicators_window,
)
print("Testing optimized implementation with 30-day lookback:")
start_time = time.time()

View File

@@ -3,14 +3,12 @@
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
# ---- Mapping coverage -----------------------------------------------------
@@ -71,6 +69,7 @@ def test_case_insensitive_lookup():
def cli_utils(monkeypatch):
"""Import cli.utils with a fresh environment so module-level state is consistent."""
import importlib
import cli.utils as cli_utils_module
return importlib.reload(cli_utils_module)

View File

@@ -1,9 +1,10 @@
"""Unit tests for the LLM capability table."""
from dataclasses import FrozenInstanceError
import pytest
from tradingagents.llm_clients.capabilities import (
ModelCapabilities,
get_capabilities,
)
@@ -119,5 +120,5 @@ class TestDefault:
def test_capabilities_dataclass_is_frozen():
"""Capability rows are immutable so they can be safely shared."""
caps = get_capabilities("deepseek-chat")
with pytest.raises(Exception):
with pytest.raises(FrozenInstanceError):
caps.supports_tool_choice = False # type: ignore[misc]

View File

@@ -1,12 +1,9 @@
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from tradingagents.graph.checkpointer import (

View File

@@ -24,7 +24,6 @@ from tradingagents.llm_clients.openai_client import (
_input_to_messages,
)
# ---------------------------------------------------------------------------
# _input_to_messages — the helper that handles list / ChatPromptValue / other
# (Gemini bot review note: non-list inputs must also work)

View File

@@ -1,15 +1,16 @@
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from tradingagents.agents.utils.memory import TradingMemoryLog
import pandas as pd
import pytest
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.graph.propagation import Propagator
from tradingagents.graph.reflection import Reflector
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.graph.propagation import Propagator
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
_SEP = TradingMemoryLog._SEPARATOR

View File

@@ -6,7 +6,6 @@ news injected future articles), #993 (empty-after-filter returned a blank body).
"""
import time
from datetime import datetime
from unittest import mock
import pytest

View File

@@ -14,7 +14,7 @@ from unittest import mock
import pandas as pd
import pytest
from tradingagents.dataflows import stockstats_utils, interface
from tradingagents.dataflows import interface, stockstats_utils
from tradingagents.dataflows.config import set_config
from tradingagents.dataflows.symbol_utils import NoMarketDataError
@@ -33,9 +33,9 @@ class TestLoadOhlcvNoPoison(unittest.TestCase):
def test_empty_download_raises_and_does_not_cache(self):
empty = pd.DataFrame()
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty) as dl:
with self.assertRaises(NoMarketDataError):
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty), \
self.assertRaises(NoMarketDataError):
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
# Nothing should have been written to the cache.
self.assertEqual(os.listdir(self._tmp), [])

View File

@@ -18,8 +18,8 @@ def _resync_reloaded_modules():
doesn't leak across test modules.
"""
yield
import cli.utils
import cli.main
import cli.utils
importlib.reload(cli.utils)
importlib.reload(cli.main)

View File

@@ -4,7 +4,6 @@ Verifies the user-supplied base_url is required and honored, the key is optional
(keyless local default), Chat Completions (not the Responses API) is used, any
model name is accepted, and the env backend URL precedence (#978).
"""
import os
import pytest

View File

@@ -13,7 +13,6 @@ import pytest
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
from tradingagents.graph.signal_processing import SignalProcessor
# ---------------------------------------------------------------------------
# Heuristic parser
# ---------------------------------------------------------------------------

View File

@@ -27,7 +27,6 @@ from tradingagents.agents.schemas import (
)
from tradingagents.agents.trader.trader import create_trader
# ---------------------------------------------------------------------------
# Render functions
# ---------------------------------------------------------------------------

View File

@@ -6,8 +6,8 @@ import pytest
from tradingagents.dataflows.symbol_utils import (
NoMarketDataError,
normalize_symbol,
is_yahoo_safe,
normalize_symbol,
)

View File

@@ -6,7 +6,6 @@ Regressions for #988 (explicit single-vendor config still fell back to others),
were swallowed without a trace).
"""
import copy
import logging
import unittest
from unittest import mock
@@ -76,9 +75,9 @@ class VendorRoutingTests(unittest.TestCase):
# #989: primary errors + fallback no-data -> NO_DATA, but the failure
# must be visible in logs (broken primary not hidden).
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}):
with self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}), \
self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
self.assertIn("NO_DATA_AVAILABLE", result)
joined = "\n".join(cm.output)
self.assertIn("boom", joined) # the real error surfaced in logs

View File

@@ -1,3 +1,4 @@
import contextlib
import warnings
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
@@ -20,10 +21,8 @@ except ImportError:
# subclassed warning categories. To suppress a specific warning we must
# install our filter AFTER langchain-core has installed its own, so import
# it first. The package is a guaranteed transitive dep via langgraph.
try:
with contextlib.suppress(ImportError):
import langchain_core # noqa: F401
except ImportError:
pass
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
# explicit allowed_objects, which triggers a noisy pending-deprecation

View File

@@ -1,6 +1,3 @@
from .utils.agent_utils import create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
@@ -8,18 +5,16 @@ from .analysts.sentiment_analyst import (
create_sentiment_analyst,
create_social_media_analyst, # deprecated alias kept for back-compat
)
from .managers.portfolio_manager import create_portfolio_manager
from .managers.research_manager import create_research_manager
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggressive_debator import create_aggressive_debator
from .risk_mgmt.conservative_debator import create_conservative_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.portfolio_manager import create_portfolio_manager
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import create_msg_delete
__all__ = [
"AgentState",

View File

@@ -1,14 +1,13 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
get_instrument_context_from_state,
get_balance_sheet,
get_cashflow,
get_fundamentals,
get_income_statement,
get_insider_transactions,
get_instrument_context_from_state,
get_language_instruction,
)
from tradingagents.dataflows.config import get_config
def create_fundamentals_analyst(llm):

View File

@@ -1,12 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
get_instrument_context_from_state,
get_indicators,
get_instrument_context_from_state,
get_language_instruction,
get_stock_data,
get_verified_market_snapshot,
)
from tradingagents.dataflows.config import get_config
def create_market_analyst(llm):

View File

@@ -1,13 +1,13 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
get_instrument_context_from_state,
get_global_news,
get_instrument_context_from_state,
get_language_instruction,
get_macro_indicators,
get_news,
get_prediction_markets,
)
from tradingagents.dataflows.config import get_config
def create_news_analyst(llm):

View File

@@ -19,11 +19,10 @@ so that:
from __future__ import annotations
from enum import Enum
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Shared rating types
# ---------------------------------------------------------------------------
@@ -124,15 +123,15 @@ class TraderProposal(BaseModel):
"the research plan. Two to four sentences."
),
)
entry_price: Optional[float] = Field(
entry_price: float | None = Field(
default=None,
description="Optional entry price target in the instrument's quote currency.",
)
stop_loss: Optional[float] = Field(
stop_loss: float | None = Field(
default=None,
description="Optional stop-loss price in the instrument's quote currency.",
)
position_sizing: Optional[str] = Field(
position_sizing: str | None = Field(
default=None,
description="Optional sizing guidance, e.g. '5% of portfolio'.",
)
@@ -196,11 +195,11 @@ class PortfolioDecision(BaseModel):
"incorporate them; otherwise rely solely on the current analysis."
),
)
price_target: Optional[float] = Field(
price_target: float | None = Field(
default=None,
description="Optional target price in the instrument's quote currency.",
)
time_horizon: Optional[str] = Field(
time_horizon: str | None = Field(
default=None,
description="Optional recommended holding period, e.g. '3-6 months'.",
)

View File

@@ -1,6 +1,7 @@
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import MessagesState
from typing_extensions import TypedDict
# Researcher team state

View File

@@ -1,37 +1,50 @@
import functools
import logging
from typing import Any, Mapping, Optional
from collections.abc import Mapping
from typing import Any
import yfinance as yf
from langchain_core.messages import HumanMessage, RemoveMessage
# Import tools from separate utility files
from tradingagents.agents.utils.core_stock_tools import (
get_stock_data
)
from tradingagents.agents.utils.technical_indicators_tools import (
get_indicators
)
from tradingagents.agents.utils.core_stock_tools import get_stock_data
from tradingagents.agents.utils.fundamental_data_tools import (
get_fundamentals,
get_balance_sheet,
get_cashflow,
get_income_statement
get_fundamentals,
get_income_statement,
)
from tradingagents.agents.utils.macro_data_tools import get_macro_indicators
from tradingagents.agents.utils.market_data_validation_tools import get_verified_market_snapshot
from tradingagents.agents.utils.news_data_tools import (
get_news,
get_global_news,
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.macro_data_tools import (
get_macro_indicators
)
from tradingagents.agents.utils.prediction_markets_tools import (
get_prediction_markets
)
from tradingagents.agents.utils.market_data_validation_tools import (
get_verified_market_snapshot
get_news,
)
from tradingagents.agents.utils.prediction_markets_tools import get_prediction_markets
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
# Public surface: the data tools are imported here so agents and the graph
# import them from one place, plus the instrument/language helpers defined below.
__all__ = [
"get_stock_data",
"get_indicators",
"get_fundamentals",
"get_balance_sheet",
"get_cashflow",
"get_income_statement",
"get_news",
"get_global_news",
"get_insider_transactions",
"get_macro_indicators",
"get_prediction_markets",
"get_verified_market_snapshot",
"build_instrument_context",
"resolve_instrument_identity",
"get_instrument_context_from_state",
"get_language_instruction",
"create_msg_delete",
]
logger = logging.getLogger(__name__)
@@ -52,7 +65,7 @@ def get_language_instruction() -> str:
return f" Write your entire response in {lang}."
def _clean_identity_value(value: Any) -> Optional[str]:
def _clean_identity_value(value: Any) -> str | None:
"""Return a trimmed string, or None for empty / placeholder-ish values."""
if not isinstance(value, str):
return None
@@ -109,7 +122,7 @@ def resolve_instrument_identity(ticker: str) -> dict:
def build_instrument_context(
ticker: str,
asset_type: str = "stock",
identity: Optional[Mapping[str, str]] = None,
identity: Mapping[str, str] | None = None,
) -> str:
"""Describe the exact instrument so agents preserve identity and ticker.

View File

@@ -1,5 +1,7 @@
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools import tool
from tradingagents.dataflows.interface import route_to_vendor

View File

@@ -1,5 +1,7 @@
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools import tool
from tradingagents.dataflows.interface import route_to_vendor

View File

@@ -1,8 +1,7 @@
"""Append-only markdown decision log for TradingAgents."""
from typing import List, Optional
from pathlib import Path
import re
from pathlib import Path
from tradingagents.agents.utils.rating import parse_rating
@@ -51,7 +50,7 @@ class TradingMemoryLog:
# --- Read path (Phase A) ---
def load_entries(self) -> List[dict]:
def load_entries(self) -> list[dict]:
"""Parse all entries from log. Returns list of dicts."""
if not self._log_path or not self._log_path.exists():
return []
@@ -64,7 +63,7 @@ class TradingMemoryLog:
entries.append(parsed)
return entries
def get_pending_entries(self) -> List[dict]:
def get_pending_entries(self) -> list[dict]:
"""Return entries with outcome:pending (for Phase B)."""
return [e for e in self.load_entries() if e.get("pending")]
@@ -162,7 +161,7 @@ class TradingMemoryLog:
tmp_path.write_text(new_text, encoding="utf-8")
tmp_path.replace(self._log_path)
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
def batch_update_with_outcomes(self, updates: list[dict]) -> None:
"""Apply multiple outcome updates in a single read + atomic write.
Each element of updates must have keys: ticker, trade_date,
@@ -218,7 +217,7 @@ class TradingMemoryLog:
# --- Helpers ---
def _apply_rotation(self, blocks: List[str]) -> List[str]:
def _apply_rotation(self, blocks: list[str]) -> list[str]:
"""Drop oldest resolved blocks when their count exceeds max_entries.
Pending blocks are always kept (they represent unprocessed work).
@@ -247,7 +246,7 @@ class TradingMemoryLog:
return blocks
to_drop = resolved_count - self._max_entries
kept: List[str] = []
kept: list[str] = []
for block, is_resolved in decisions:
if is_resolved and to_drop > 0:
to_drop -= 1
@@ -255,7 +254,7 @@ class TradingMemoryLog:
kept.append(block)
return kept
def _parse_entry(self, raw: str) -> Optional[dict]:
def _parse_entry(self, raw: str) -> dict | None:
lines = raw.strip().splitlines()
if not lines:
return None

View File

@@ -1,7 +1,10 @@
from typing import Annotated
from langchain_core.tools import tool
from typing import Annotated, Optional
from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_news(
ticker: Annotated[str, "Ticker symbol"],
@@ -23,8 +26,8 @@ def get_news(
@tool
def get_global_news(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
look_back_days: Annotated[int | None, "Days to look back; omit to use the configured default"] = None,
limit: Annotated[int | None, "Max articles to return; omit to use the configured default"] = None,
) -> str:
"""
Retrieve global news data.

View File

@@ -12,11 +12,9 @@ Centralising it here avoids drift between those call sites.
from __future__ import annotations
import re
from typing import Tuple
# Canonical, ordered 5-tier scale (most bullish to most bearish).
RATINGS_5_TIER: Tuple[str, ...] = (
RATINGS_5_TIER: tuple[str, ...] = (
"Buy", "Overweight", "Hold", "Underweight", "Sell",
)

View File

@@ -19,7 +19,8 @@ all three agents log the same warnings when fallback fires.
from __future__ import annotations
import logging
from typing import Any, Callable, Optional, TypeVar
from collections.abc import Callable
from typing import Any, TypeVar
from pydantic import BaseModel
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Any | None:
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
Logs a warning when the binding fails so the user understands the agent
@@ -46,7 +47,7 @@ def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]
def invoke_structured_or_freetext(
structured_llm: Optional[Any],
structured_llm: Any | None,
plain_llm: Any,
prompt: Any,
render: Callable[[T], str],

View File

@@ -1,7 +1,10 @@
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools import tool
from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_indicators(
symbol: Annotated[str, "ticker symbol of the company"],

View File

@@ -1,5 +1,23 @@
# Import functions from specialized modules
from .alpha_vantage_stock import get_stock
# Aggregates the per-category Alpha Vantage implementations into one module the
# vendor router imports from; the imports below are the public surface.
from .alpha_vantage_fundamentals import (
get_balance_sheet,
get_cashflow,
get_fundamentals,
get_income_statement,
)
from .alpha_vantage_indicator import get_indicator
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
from .alpha_vantage_news import get_global_news, get_insider_transactions, get_news
from .alpha_vantage_stock import get_stock
__all__ = [
"get_balance_sheet",
"get_cashflow",
"get_fundamentals",
"get_income_statement",
"get_indicator",
"get_global_news",
"get_insider_transactions",
"get_news",
"get_stock",
]

View File

@@ -1,4 +1,5 @@
from .alpha_vantage_common import _make_api_request, AlphaVantageNotConfiguredError
from .alpha_vantage_common import AlphaVantageNotConfiguredError, _make_api_request
def get_indicator(
symbol: str,
@@ -25,6 +26,7 @@ def get_indicator(
String containing indicator values and description
"""
from datetime import datetime
from dateutil.relativedelta import relativedelta
supported_indicators = {
@@ -98,21 +100,7 @@ def get_indicator(
"series_type": series_type,
"datatype": "csv"
})
elif indicator == "macd":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv"
})
elif indicator == "macds":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv"
})
elif indicator == "macdh":
elif indicator == "macd" or indicator == "macds" or indicator == "macdh":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,

View File

@@ -1,5 +1,6 @@
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.

View File

@@ -1,5 +1,7 @@
from datetime import datetime
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
def get_stock(
symbol: str,

View File

@@ -1,10 +1,9 @@
from copy import deepcopy
from typing import Dict, Optional
import tradingagents.default_config as default_config
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
_config: dict | None = None
def initialize_config():
@@ -14,7 +13,7 @@ def initialize_config():
_config = deepcopy(default_config.DEFAULT_CONFIG)
def set_config(config: Dict):
def set_config(config: dict):
"""Update the configuration with custom values.
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
@@ -31,7 +30,7 @@ def set_config(config: Dict):
_config[key] = value
def get_config() -> Dict:
def get_config() -> dict:
"""Get the current configuration."""
if _config is None:
initialize_config()

View File

@@ -10,7 +10,7 @@ claim. Deterministic, no LLM involved.
from __future__ import annotations
from typing import Iterable, Optional
from collections.abc import Iterable
import pandas as pd
from stockstats import wrap
@@ -63,7 +63,7 @@ def build_verified_market_snapshot(
symbol: str,
curr_date: str,
look_back_days: int = 30,
indicators: Optional[Iterable[str]] = None,
indicators: Iterable[str] | None = None,
) -> str:
"""Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes."""
# `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/

View File

@@ -1,10 +1,9 @@
import os
import re
import json
import pandas as pd
from datetime import date, timedelta, datetime
from datetime import date, datetime, timedelta
from typing import Annotated
import pandas as pd
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
# Tickers can contain letters, digits, dot, dash, underscore, caret

View File

@@ -1,9 +1,9 @@
"""yfinance-based news data fetching functions."""
from typing import Optional
import contextlib
from datetime import datetime
import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .config import get_config
@@ -28,10 +28,8 @@ def _extract_article_data(article: dict) -> dict:
pub_date_str = content.get("pubDate", "")
pub_date = None
if pub_date_str:
try:
with contextlib.suppress(ValueError, AttributeError):
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
except (ValueError, AttributeError):
pass
return {
"title": title,
@@ -47,10 +45,8 @@ def _extract_article_data(article: dict) -> dict:
pub_date = None
ts = article.get("providerPublishTime")
if ts:
try:
with contextlib.suppress(ValueError, OSError, TypeError):
pub_date = datetime.fromtimestamp(ts)
except (ValueError, OSError, TypeError):
pass
return {
"title": article.get("title", "No title"),
"summary": article.get("summary", ""),
@@ -131,8 +127,8 @@ def get_news_yfinance(
def get_global_news_yfinance(
curr_date: str,
look_back_days: Optional[int] = None,
limit: Optional[int] = None,
look_back_days: int | None = None,
limit: int | None = None,
) -> str:
"""
Retrieve global/macro economic news using yfinance Search.

View File

@@ -1,11 +1,11 @@
# TradingAgents/graph/__init__.py
from .trading_graph import TradingAgentsGraph
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
from .trading_graph import TradingAgentsGraph
__all__ = [
"TradingAgentsGraph",

View File

@@ -1,6 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass
from time import monotonic
from typing import Dict, Iterable, List, Optional
@dataclass(frozen=True)
@@ -14,11 +14,11 @@ class AnalystNodeSpec:
@dataclass(frozen=True)
class AnalystExecutionPlan:
specs: List[AnalystNodeSpec]
specs: list[AnalystNodeSpec]
concurrency_limit: int
ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
ANALYST_NODE_SPECS: dict[str, AnalystNodeSpec] = {
"market": AnalystNodeSpec(
key="market",
agent_node="Market Analyst",
@@ -61,7 +61,7 @@ def build_analyst_execution_plan(
if concurrency_limit < 1:
raise ValueError("analyst concurrency limit must be >= 1")
specs: List[AnalystNodeSpec] = []
specs: list[AnalystNodeSpec] = []
for analyst_key in selected_analysts:
spec = ANALYST_NODE_SPECS.get(analyst_key)
if spec is None:
@@ -81,10 +81,10 @@ def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str:
class AnalystWallTimeTracker:
def __init__(self, plan: AnalystExecutionPlan):
self.plan = plan
self._started_at: Dict[str, float] = {}
self._wall_times: Dict[str, float] = {}
self._started_at: dict[str, float] = {}
self._wall_times: dict[str, float] = {}
def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None:
def mark_started(self, analyst_key: str, started_at: float | None = None) -> None:
if analyst_key not in ANALYST_NODE_SPECS:
raise ValueError(f"unknown analyst key: {analyst_key}")
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
@@ -92,7 +92,7 @@ class AnalystWallTimeTracker:
def mark_completed(
self,
analyst_key: str,
completed_at: Optional[float] = None,
completed_at: float | None = None,
) -> None:
if analyst_key not in ANALYST_NODE_SPECS:
raise ValueError(f"unknown analyst key: {analyst_key}")
@@ -104,7 +104,7 @@ class AnalystWallTimeTracker:
finished_at = monotonic() if completed_at is None else completed_at
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
def get_wall_times(self) -> Dict[str, float]:
def get_wall_times(self) -> dict[str, float]:
return dict(self._wall_times)
def format_summary(self) -> str:
@@ -121,8 +121,8 @@ class AnalystWallTimeTracker:
def sync_analyst_tracker_from_chunk(
tracker: AnalystWallTimeTracker,
chunk: Dict[str, str],
now: Optional[float] = None,
chunk: dict[str, str],
now: float | None = None,
) -> None:
current_time = monotonic() if now is None else now
active_found = False

View File

@@ -7,9 +7,9 @@ from __future__ import annotations
import hashlib
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver

View File

@@ -1,8 +1,8 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any, List, Optional
from typing import Any
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
@@ -22,7 +22,7 @@ class Propagator:
asset_type: str = "stock",
past_context: str = "",
instrument_context: str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Create the initial state for the agent graph.
``instrument_context`` is the deterministic ticker-identity string
@@ -68,7 +68,7 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
def get_graph_args(self, callbacks: list | None = None) -> dict[str, Any]:
"""Get arguments for the graph invocation.
Args:

View File

@@ -1,10 +1,25 @@
# TradingAgents/graph/setup.py
from typing import Any, Dict
from typing import Any
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents import (
create_aggressive_debator,
create_bear_researcher,
create_bull_researcher,
create_conservative_debator,
create_fundamentals_analyst,
create_market_analyst,
create_msg_delete,
create_neutral_debator,
create_news_analyst,
create_portfolio_manager,
create_research_manager,
create_sentiment_analyst,
create_trader,
)
from tradingagents.agents.utils.agent_states import AgentState
from .analyst_execution import build_analyst_execution_plan
@@ -18,7 +33,7 @@ class GraphSetup:
self,
quick_thinking_llm: Any,
deep_thinking_llm: Any,
tool_nodes: Dict[str, ToolNode],
tool_nodes: dict[str, ToolNode],
conditional_logic: ConditionalLogic,
analyst_concurrency_limit: int = 1,
):
@@ -30,7 +45,7 @@ class GraphSetup:
self.analyst_concurrency_limit = analyst_concurrency_limit
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
self, selected_analysts=("market", "social", "news", "fundamentals")
):
"""Set up and compile the agent workflow graph.

View File

@@ -1,66 +1,57 @@
# TradingAgents/graph/trading_graph.py
import json
import logging
import os
from pathlib import Path
import json
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple, List, Optional
from pathlib import Path
from typing import Any
import yfinance as yf
logger = logging.getLogger(__name__)
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.dataflows.utils import safe_ticker_component
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils
# Import the abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
resolve_instrument_identity,
get_stock_data,
get_indicators,
get_verified_market_snapshot,
get_fundamentals,
get_balance_sheet,
get_cashflow,
get_income_statement,
get_news,
get_insider_transactions,
get_fundamentals,
get_global_news,
get_income_statement,
get_indicators,
get_insider_transactions,
get_macro_indicators,
get_prediction_markets
get_news,
get_prediction_markets,
get_stock_data,
get_verified_market_snapshot,
resolve_instrument_identity,
)
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.dataflows.config import set_config
from tradingagents.dataflows.utils import safe_ticker_component
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.llm_clients import create_llm_client
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
logger = logging.getLogger(__name__)
class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
selected_analysts=("market", "social", "news", "fundamentals"),
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
config: dict[str, Any] = None,
callbacks: list | None = None,
):
"""Initialize the trading agents graph and components.
@@ -138,7 +129,7 @@ class TradingAgentsGraph:
self.graph = self.workflow.compile()
self._checkpointer_ctx = None
def _get_provider_kwargs(self) -> Dict[str, Any]:
def _get_provider_kwargs(self) -> dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
@@ -167,7 +158,7 @@ class TradingAgentsGraph:
return kwargs
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
def _create_tool_nodes(self) -> dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
return {
"market": ToolNode(
@@ -233,7 +224,7 @@ class TradingAgentsGraph:
def _fetch_returns(
self, ticker: str, trade_date: str, holding_days: int = 5,
benchmark: str = "SPY",
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
) -> tuple[float | None, float | None, int | None]:
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
``benchmark`` is the index used as the alpha baseline (resolved by the

View File

@@ -1,5 +1,5 @@
import re
from typing import Any, Optional
from typing import Any
from langchain_anthropic import ChatAnthropic
@@ -43,7 +43,7 @@ class NormalizedChatAnthropic(ChatAnthropic):
class AnthropicClient(BaseLLMClient):
"""Client for Anthropic Claude models."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
def __init__(self, model: str, base_url: str | None = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:

View File

@@ -11,10 +11,7 @@ prompts for it automatically instead of failing on first API call.
from __future__ import annotations
from typing import Optional
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
PROVIDER_API_KEY_ENV: dict[str, str | None] = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"google": "GOOGLE_API_KEY",
@@ -47,7 +44,7 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
}
def get_api_key_env(provider: str) -> Optional[str]:
def get_api_key_env(provider: str) -> str | None:
"""Return the env var name for `provider`'s API key, or None if not applicable.
Unknown providers also return None — callers should treat that as

View File

@@ -1,10 +1,9 @@
import os
from typing import Any, Optional
from typing import Any
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
@@ -29,7 +28,7 @@ class AzureOpenAIClient(BaseLLMClient):
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
"""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
def __init__(self, model: str, base_url: str | None = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
import warnings
from abc import ABC, abstractmethod
from typing import Any
def normalize_content(response):
@@ -25,7 +25,7 @@ def normalize_content(response):
class BaseLLMClient(ABC):
"""Abstract base class for LLM clients."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
def __init__(self, model: str, base_url: str | None = None, **kwargs):
self.model = model
self.base_url = base_url
self.kwargs = kwargs

View File

@@ -18,7 +18,6 @@ import re
from dataclasses import dataclass
from typing import Literal
StructuredMethod = Literal[
"function_calling", # uses tools; respects supports_tool_choice
"json_mode", # uses response_format={"type":"json_object"}

View File

@@ -1,11 +1,11 @@
from typing import Optional
from .base_client import BaseLLMClient
def create_llm_client(
provider: str,
model: str,
base_url: Optional[str] = None,
base_url: str | None = None,
**kwargs,
) -> BaseLLMClient:
"""Create an LLM client for the specified provider.

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -20,7 +20,7 @@ class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
class GoogleClient(BaseLLMClient):
"""Client for Google Gemini models."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
def __init__(self, model: str, base_url: str | None = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:

View File

@@ -2,14 +2,12 @@
from __future__ import annotations
from typing import Dict, List, Tuple
ModelOption = Tuple[str, str]
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
ModelOption = tuple[str, str]
ProviderModeOptions = dict[str, dict[str, list[ModelOption]]]
# Providers that serve many / frequently-changing models: offer only "Custom
# model ID" rather than a list that goes stale.
_CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
_CUSTOM_ONLY: dict[str, list[ModelOption]] = {
"quick": [("Custom model ID", "custom")],
"deep": [("Custom model ID", "custom")],
}
@@ -18,7 +16,7 @@ _CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
# Shared model list for GLM via Z.AI (international) and BigModel (China).
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
_GLM_MODELS: Dict[str, List[ModelOption]] = {
_GLM_MODELS: dict[str, list[ModelOption]] = {
"quick": [
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
@@ -44,7 +42,7 @@ _GLM_MODELS: Dict[str, List[ModelOption]] = {
# the backing model. Users who want a specific generation pick it
# explicitly; users who really want auto-latest can enter the alias via
# "Custom model ID".
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
_QWEN_MODELS: dict[str, list[ModelOption]] = {
"quick": [
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
@@ -62,7 +60,7 @@ _QWEN_MODELS: Dict[str, List[ModelOption]] = {
# Shared model list for MiniMax's global and CN endpoints (same IDs).
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
# All M2.x models share a 204,800-token context window.
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
_MINIMAX_MODELS: dict[str, list[ModelOption]] = {
"quick": [
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
@@ -198,12 +196,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
}
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
def get_model_options(provider: str, mode: str) -> list[ModelOption]:
"""Return shared model options for a provider and selection mode."""
return MODEL_OPTIONS[provider.lower()][mode]
def get_known_models() -> Dict[str, List[str]]:
def get_known_models() -> dict[str, list[str]]:
"""Build known model names from the shared CLI catalog."""
return {
provider: sorted(

View File

@@ -2,7 +2,6 @@
from .model_catalog import get_known_models
# Providers whose model names are user-defined (local servers, relays, hosted
# OpenAI-compatible endpoints serving many models), so any model string is
# accepted without warning.