diff --git a/cli/announcements.py b/cli/announcements.py index 5947cee59..3ff70159c 100644 --- a/cli/announcements.py +++ b/cli/announcements.py @@ -1,4 +1,5 @@ import getpass + import requests from rich.console import Console from rich.panel import Panel diff --git a/cli/main.py b/cli/main.py index 28da4bb8c..df3d69554 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,38 +1,53 @@ -from typing import Optional -import os import datetime -import typer -import questionary -from pathlib import Path -from functools import wraps -from rich.console import Console -from rich.panel import Panel -from rich.spinner import Spinner -from rich.live import Live -from rich.columns import Columns -from rich.markdown import Markdown -from rich.layout import Layout -from rich.text import Text -from rich.table import Table -from collections import deque +import os import time -from rich.tree import Tree +from collections import deque +from functools import wraps +from pathlib import Path + +import typer from rich import box from rich.align import Align +from rich.console import Console +from rich.layout import Layout +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel from rich.rule import Rule +from rich.spinner import Spinner +from rich.table import Table +from rich.text import Text -from tradingagents.graph.trading_graph import TradingAgentsGraph +from cli.announcements import display_announcements, fetch_announcements +from cli.stats_handler import StatsCallbackHandler +from cli.utils import ( + ask_anthropic_effort, + ask_gemini_thinking_config, + ask_glm_region, + ask_minimax_region, + ask_openai_reasoning_effort, + ask_output_language, + ask_qwen_region, + confirm_ollama_endpoint, + detect_asset_type, + ensure_api_key, + get_ticker, + prompt_openai_compatible_url, + resolve_backend_url, + select_analysts, + select_deep_thinking_agent, + select_llm_provider, + select_research_depth, + select_shallow_thinking_agent, +) +from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.graph.analyst_execution import ( AnalystWallTimeTracker, build_analyst_execution_plan, get_initial_analyst_node, sync_analyst_tracker_from_chunk, ) -from tradingagents.default_config import DEFAULT_CONFIG -from cli.models import AnalystType -from cli.utils import * -from cli.announcements import fetch_announcements, display_announcements -from cli.stats_handler import StatsCallbackHandler +from tradingagents.graph.trading_graph import TradingAgentsGraph console = Console() @@ -169,7 +184,7 @@ class MessageBuffer: if content is not None: latest_section = section latest_content = content - + if latest_section and latest_content: # Format the current section for display section_titles = { @@ -466,7 +481,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non def get_user_selections(): """Get all user selections before starting the analysis display.""" # Display ASCII art welcome message - with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f: + with open(Path(__file__).parent / "static" / "welcome.txt", encoding="utf-8") as f: welcome_ascii = f.read() # Create welcome box content @@ -922,9 +937,12 @@ def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None): message_buffer.update_agent_status(agent_name, "pending") # When all analysts complete, transition research team to in_progress - if not found_active and selected: - if message_buffer.agent_status.get("Bull Researcher") == "pending": - message_buffer.update_agent_status("Bull Researcher", "in_progress") + if ( + not found_active + and selected + and message_buffer.agent_status.get("Bull Researcher") == "pending" + ): + message_buffer.update_agent_status("Bull Researcher", "in_progress") def extract_content_string(content): """Extract string content from various message formats. @@ -1064,7 +1082,7 @@ def run_analysis(checkpoint: bool = False): with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [{message_type}] {content}\n") return wrapper - + def save_tool_call_decorator(obj, func_name): func = getattr(obj, func_name) @wraps(func) @@ -1097,7 +1115,7 @@ def run_analysis(checkpoint: bool = False): # Now start the display layout layout = create_layout() - with Live(layout, refresh_per_second=4) as live: + with Live(layout, refresh_per_second=4): # Initial display update_display(layout, stats_handler=stats_handler, start_time=start_time) @@ -1232,16 +1250,15 @@ def run_analysis(checkpoint: bool = False): message_buffer.update_report_section( "final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}" ) - if judge: - if message_buffer.agent_status.get("Portfolio Manager") != "completed": - message_buffer.update_agent_status("Portfolio Manager", "in_progress") - message_buffer.update_report_section( - "final_trade_decision", f"### Portfolio Manager Decision\n{judge}" - ) - message_buffer.update_agent_status("Aggressive Analyst", "completed") - message_buffer.update_agent_status("Conservative Analyst", "completed") - message_buffer.update_agent_status("Neutral Analyst", "completed") - message_buffer.update_agent_status("Portfolio Manager", "completed") + if judge and message_buffer.agent_status.get("Portfolio Manager") != "completed": + message_buffer.update_agent_status("Portfolio Manager", "in_progress") + message_buffer.update_report_section( + "final_trade_decision", f"### Portfolio Manager Decision\n{judge}" + ) + message_buffer.update_agent_status("Aggressive Analyst", "completed") + message_buffer.update_agent_status("Conservative Analyst", "completed") + message_buffer.update_agent_status("Neutral Analyst", "completed") + message_buffer.update_agent_status("Portfolio Manager", "completed") # Update the display update_display(layout, stats_handler=stats_handler, start_time=start_time) @@ -1253,7 +1270,6 @@ def run_analysis(checkpoint: bool = False): final_state = {} for chunk in trace: final_state.update(chunk) - decision = graph.process_signal(final_state["final_trade_decision"]) # Update all agent statuses to completed for agent in message_buffer.agent_status: @@ -1265,7 +1281,7 @@ def run_analysis(checkpoint: bool = False): message_buffer.add_message("System", analyst_wall_time_tracker.format_summary()) # Update final report sections - for section in message_buffer.report_sections.keys(): + for section in message_buffer.report_sections: if section in final_state: message_buffer.update_report_section(section, final_state[section]) diff --git a/cli/models.py b/cli/models.py index 7ab91b02d..1eef2b7e3 100644 --- a/cli/models.py +++ b/cli/models.py @@ -1,6 +1,4 @@ from enum import Enum -from typing import List, Optional, Dict -from pydantic import BaseModel class AnalystType(str, Enum): diff --git a/cli/stats_handler.py b/cli/stats_handler.py index 10734cc34..8e2e83bd3 100644 --- a/cli/stats_handler.py +++ b/cli/stats_handler.py @@ -1,9 +1,9 @@ import threading -from typing import Any, Dict, List, Union +from typing import Any from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.outputs import LLMResult from langchain_core.messages import AIMessage +from langchain_core.outputs import LLMResult class StatsCallbackHandler(BaseCallbackHandler): @@ -19,8 +19,8 @@ class StatsCallbackHandler(BaseCallbackHandler): def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], **kwargs: Any, ) -> None: """Increment LLM call counter when an LLM starts.""" @@ -29,8 +29,8 @@ class StatsCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[Any]], + serialized: dict[str, Any], + messages: list[list[Any]], **kwargs: Any, ) -> None: """Increment LLM call counter when a chat model starts.""" @@ -57,7 +57,7 @@ class StatsCallbackHandler(BaseCallbackHandler): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, **kwargs: Any, ) -> None: @@ -65,7 +65,7 @@ class StatsCallbackHandler(BaseCallbackHandler): with self._lock: self.tool_calls += 1 - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """Return current statistics.""" with self._lock: return { diff --git a/cli/utils.py b/cli/utils.py index 4583b6107..bdca55a45 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import List, Optional, Tuple, Dict import questionary from dotenv import find_dotenv, set_key @@ -89,8 +88,8 @@ def detect_asset_type(ticker: str) -> AssetType: def filter_analysts_for_asset_type( - analysts: List[AnalystType], asset_type: AssetType -) -> List[AnalystType]: + analysts: list[AnalystType], asset_type: AssetType +) -> list[AnalystType]: if asset_type != AssetType.CRYPTO: return analysts return [ @@ -133,7 +132,7 @@ def get_analysis_date() -> str: return date.strip() -def select_analysts(asset_type: AssetType = AssetType.STOCK) -> List[AnalystType]: +def select_analysts(asset_type: AssetType = AssetType.STOCK) -> list[AnalystType]: """Select analysts using an interactive checkbox.""" available_analysts = filter_analysts_for_asset_type( [value for _, value in ANALYST_ORDER], @@ -197,7 +196,7 @@ def select_research_depth() -> int: return choice -def _fetch_openrouter_models() -> List[Tuple[str, str]]: +def _fetch_openrouter_models() -> list[tuple[str, str]]: """Fetch available models from the OpenRouter API.""" import requests try: @@ -377,7 +376,7 @@ def select_llm_provider() -> tuple[str, str | None]: ] ), ).ask() - + if choice is None: console.print("\n[red]No LLM provider selected. Exiting...[/red]") exit(1) @@ -556,7 +555,7 @@ def confirm_ollama_endpoint(url: str) -> None: ) -def ensure_api_key(provider: str) -> Optional[str]: +def ensure_api_key(provider: str) -> str | None: """Make sure the API key for `provider` is available in the environment. If the env var is already set, returns its value untouched. Otherwise diff --git a/main.py b/main.py index fea2f3680..c9d3ab038 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ -from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph # DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides # (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.), diff --git a/scripts/smoke_structured_output.py b/scripts/smoke_structured_output.py index 1d3cf681c..73b7e8f47 100644 --- a/scripts/smoke_structured_output.py +++ b/scripts/smoke_structured_output.py @@ -21,7 +21,6 @@ added, plus the heuristic SignalProcessor. from __future__ import annotations import argparse -import os import sys from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager @@ -30,7 +29,6 @@ from tradingagents.agents.trader.trader import create_trader from tradingagents.graph.signal_processing import SignalProcessor from tradingagents.llm_clients import create_llm_client - PROVIDER_DEFAULTS = { "openai": ("gpt-5.4-mini", None), "google": ("gemini-2.5-flash", None), diff --git a/test.py b/test.py index b73783e1e..f0b93184c 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,8 @@ import time -from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions + +from tradingagents.dataflows.y_finance import ( + get_stock_stats_indicators_window, +) print("Testing optimized implementation with 30-day lookback:") start_time = time.time() diff --git a/tests/test_api_key_env.py b/tests/test_api_key_env.py index dde5a4886..7361ea7b5 100644 --- a/tests/test_api_key_env.py +++ b/tests/test_api_key_env.py @@ -3,14 +3,12 @@ from __future__ import annotations import os -from pathlib import Path from unittest.mock import patch import pytest from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env - # ---- Mapping coverage ----------------------------------------------------- @@ -71,6 +69,7 @@ def test_case_insensitive_lookup(): def cli_utils(monkeypatch): """Import cli.utils with a fresh environment so module-level state is consistent.""" import importlib + import cli.utils as cli_utils_module return importlib.reload(cli_utils_module) diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py index f29d90ff9..64adb9532 100644 --- a/tests/test_capabilities.py +++ b/tests/test_capabilities.py @@ -1,9 +1,10 @@ """Unit tests for the LLM capability table.""" +from dataclasses import FrozenInstanceError + import pytest from tradingagents.llm_clients.capabilities import ( - ModelCapabilities, get_capabilities, ) @@ -119,5 +120,5 @@ class TestDefault: def test_capabilities_dataclass_is_frozen(): """Capability rows are immutable so they can be safely shared.""" caps = get_capabilities("deepseek-chat") - with pytest.raises(Exception): + with pytest.raises(FrozenInstanceError): caps.supports_tool_choice = False # type: ignore[misc] diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py index 6f2692bd8..bf801b073 100644 --- a/tests/test_checkpoint_resume.py +++ b/tests/test_checkpoint_resume.py @@ -1,12 +1,9 @@ """Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" -import sqlite3 import tempfile import unittest -from pathlib import Path from typing import TypedDict -from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.graph import END, StateGraph from tradingagents.graph.checkpointer import ( diff --git a/tests/test_deepseek_reasoning.py b/tests/test_deepseek_reasoning.py index 62c1b3497..5ba719e42 100644 --- a/tests/test_deepseek_reasoning.py +++ b/tests/test_deepseek_reasoning.py @@ -24,7 +24,6 @@ from tradingagents.llm_clients.openai_client import ( _input_to_messages, ) - # --------------------------------------------------------------------------- # _input_to_messages — the helper that handles list / ChatPromptValue / other # (Gemini bot review note: non-list inputs must also work) diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py index 2d0261fac..69fdb8047 100644 --- a/tests/test_memory_log.py +++ b/tests/test_memory_log.py @@ -1,15 +1,16 @@ """Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal.""" -import pytest -import pandas as pd from unittest.mock import MagicMock, patch -from tradingagents.agents.utils.memory import TradingMemoryLog +import pandas as pd +import pytest + +from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating +from tradingagents.agents.utils.memory import TradingMemoryLog +from tradingagents.graph.propagation import Propagator from tradingagents.graph.reflection import Reflector from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.graph.propagation import Propagator -from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager _SEP = TradingMemoryLog._SEPARATOR diff --git a/tests/test_news_lookahead.py b/tests/test_news_lookahead.py index f58d54c62..97dbc1c1b 100644 --- a/tests/test_news_lookahead.py +++ b/tests/test_news_lookahead.py @@ -6,7 +6,6 @@ news injected future articles), #993 (empty-after-filter returned a blank body). """ import time from datetime import datetime -from unittest import mock import pytest diff --git a/tests/test_no_data_handling.py b/tests/test_no_data_handling.py index e0fdfb7a8..e4bc7845b 100644 --- a/tests/test_no_data_handling.py +++ b/tests/test_no_data_handling.py @@ -14,7 +14,7 @@ from unittest import mock import pandas as pd import pytest -from tradingagents.dataflows import stockstats_utils, interface +from tradingagents.dataflows import interface, stockstats_utils from tradingagents.dataflows.config import set_config from tradingagents.dataflows.symbol_utils import NoMarketDataError @@ -33,9 +33,9 @@ class TestLoadOhlcvNoPoison(unittest.TestCase): def test_empty_download_raises_and_does_not_cache(self): empty = pd.DataFrame() - with mock.patch.object(stockstats_utils.yf, "download", return_value=empty) as dl: - with self.assertRaises(NoMarketDataError): - stockstats_utils.load_ohlcv("FAKE", "2026-01-01") + with mock.patch.object(stockstats_utils.yf, "download", return_value=empty), \ + self.assertRaises(NoMarketDataError): + stockstats_utils.load_ohlcv("FAKE", "2026-01-01") # Nothing should have been written to the cache. self.assertEqual(os.listdir(self._tmp), []) diff --git a/tests/test_ollama_base_url.py b/tests/test_ollama_base_url.py index f226a976d..5dd0f85c8 100644 --- a/tests/test_ollama_base_url.py +++ b/tests/test_ollama_base_url.py @@ -18,8 +18,8 @@ def _resync_reloaded_modules(): doesn't leak across test modules. """ yield - import cli.utils import cli.main + import cli.utils importlib.reload(cli.utils) importlib.reload(cli.main) diff --git a/tests/test_openai_compatible_provider.py b/tests/test_openai_compatible_provider.py index 43d91c984..ab253e26a 100644 --- a/tests/test_openai_compatible_provider.py +++ b/tests/test_openai_compatible_provider.py @@ -4,7 +4,6 @@ Verifies the user-supplied base_url is required and honored, the key is optional (keyless local default), Chat Completions (not the Responses API) is used, any model name is accepted, and the env backend URL precedence (#978). """ -import os import pytest diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 4bbfb7475..92520a8d1 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -13,7 +13,6 @@ import pytest from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating from tradingagents.graph.signal_processing import SignalProcessor - # --------------------------------------------------------------------------- # Heuristic parser # --------------------------------------------------------------------------- diff --git a/tests/test_structured_agents.py b/tests/test_structured_agents.py index 5927f2d11..d80063206 100644 --- a/tests/test_structured_agents.py +++ b/tests/test_structured_agents.py @@ -27,7 +27,6 @@ from tradingagents.agents.schemas import ( ) from tradingagents.agents.trader.trader import create_trader - # --------------------------------------------------------------------------- # Render functions # --------------------------------------------------------------------------- diff --git a/tests/test_symbol_utils.py b/tests/test_symbol_utils.py index 7af509f75..774c0f9bf 100644 --- a/tests/test_symbol_utils.py +++ b/tests/test_symbol_utils.py @@ -6,8 +6,8 @@ import pytest from tradingagents.dataflows.symbol_utils import ( NoMarketDataError, - normalize_symbol, is_yahoo_safe, + normalize_symbol, ) diff --git a/tests/test_vendor_routing.py b/tests/test_vendor_routing.py index 57e36a134..170f07705 100644 --- a/tests/test_vendor_routing.py +++ b/tests/test_vendor_routing.py @@ -6,7 +6,6 @@ Regressions for #988 (explicit single-vendor config still fell back to others), were swallowed without a trace). """ import copy -import logging import unittest from unittest import mock @@ -76,9 +75,9 @@ class VendorRoutingTests(unittest.TestCase): # #989: primary errors + fallback no-data -> NO_DATA, but the failure # must be visible in logs (broken primary not hidden). set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}}) - with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}): - with self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm: - result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10") + with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}), \ + self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm: + result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10") self.assertIn("NO_DATA_AVAILABLE", result) joined = "\n".join(cm.output) self.assertIn("boom", joined) # the real error surfaced in logs diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py index 5f83f2a52..87e37b3cc 100644 --- a/tradingagents/__init__.py +++ b/tradingagents/__init__.py @@ -1,3 +1,4 @@ +import contextlib import warnings # Load .env files at package import so DEFAULT_CONFIG's env-var overlay @@ -20,10 +21,8 @@ except ImportError: # subclassed warning categories. To suppress a specific warning we must # install our filter AFTER langchain-core has installed its own, so import # it first. The package is a guaranteed transitive dep via langgraph. -try: +with contextlib.suppress(ImportError): import langchain_core # noqa: F401 -except ImportError: - pass # langgraph-checkpoint 4.0.3 calls Reviver() at module load without an # explicit allowed_objects, which triggers a noisy pending-deprecation diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index f88261408..b78803a04 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,6 +1,3 @@ -from .utils.agent_utils import create_msg_delete -from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState - from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst @@ -8,18 +5,16 @@ from .analysts.sentiment_analyst import ( create_sentiment_analyst, create_social_media_analyst, # deprecated alias kept for back-compat ) - +from .managers.portfolio_manager import create_portfolio_manager +from .managers.research_manager import create_research_manager from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher - from .risk_mgmt.aggressive_debator import create_aggressive_debator from .risk_mgmt.conservative_debator import create_conservative_debator from .risk_mgmt.neutral_debator import create_neutral_debator - -from .managers.research_manager import create_research_manager -from .managers.portfolio_manager import create_portfolio_manager - from .trader.trader import create_trader +from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState +from .utils.agent_utils import create_msg_delete __all__ = [ "AgentState", diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index b2ea3bcfb..0428e676b 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,14 +1,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from tradingagents.agents.utils.agent_utils import ( - get_instrument_context_from_state, get_balance_sheet, get_cashflow, get_fundamentals, get_income_statement, - get_insider_transactions, + get_instrument_context_from_state, get_language_instruction, ) -from tradingagents.dataflows.config import get_config def create_fundamentals_analyst(llm): diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 87fca70de..d41a28db3 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,12 +1,12 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from tradingagents.agents.utils.agent_utils import ( - get_instrument_context_from_state, get_indicators, + get_instrument_context_from_state, get_language_instruction, get_stock_data, get_verified_market_snapshot, ) -from tradingagents.dataflows.config import get_config def create_market_analyst(llm): diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 996974549..a0d7fde86 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,13 +1,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from tradingagents.agents.utils.agent_utils import ( - get_instrument_context_from_state, get_global_news, + get_instrument_context_from_state, get_language_instruction, get_macro_indicators, get_news, get_prediction_markets, ) -from tradingagents.dataflows.config import get_config def create_news_analyst(llm): diff --git a/tradingagents/agents/schemas.py b/tradingagents/agents/schemas.py index f89878c41..2e852b25f 100644 --- a/tradingagents/agents/schemas.py +++ b/tradingagents/agents/schemas.py @@ -19,11 +19,10 @@ so that: from __future__ import annotations from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field - # --------------------------------------------------------------------------- # Shared rating types # --------------------------------------------------------------------------- @@ -124,15 +123,15 @@ class TraderProposal(BaseModel): "the research plan. Two to four sentences." ), ) - entry_price: Optional[float] = Field( + entry_price: float | None = Field( default=None, description="Optional entry price target in the instrument's quote currency.", ) - stop_loss: Optional[float] = Field( + stop_loss: float | None = Field( default=None, description="Optional stop-loss price in the instrument's quote currency.", ) - position_sizing: Optional[str] = Field( + position_sizing: str | None = Field( default=None, description="Optional sizing guidance, e.g. '5% of portfolio'.", ) @@ -196,11 +195,11 @@ class PortfolioDecision(BaseModel): "incorporate them; otherwise rely solely on the current analysis." ), ) - price_target: Optional[float] = Field( + price_target: float | None = Field( default=None, description="Optional target price in the instrument's quote currency.", ) - time_horizon: Optional[str] = Field( + time_horizon: str | None = Field( default=None, description="Optional recommended holding period, e.g. '3-6 months'.", ) diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index cce73c79d..82504ca03 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,6 +1,7 @@ from typing import Annotated -from typing_extensions import TypedDict + from langgraph.graph import MessagesState +from typing_extensions import TypedDict # Researcher team state diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 73201b785..16f45c178 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -1,37 +1,50 @@ import functools import logging -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any import yfinance as yf from langchain_core.messages import HumanMessage, RemoveMessage # Import tools from separate utility files -from tradingagents.agents.utils.core_stock_tools import ( - get_stock_data -) -from tradingagents.agents.utils.technical_indicators_tools import ( - get_indicators -) +from tradingagents.agents.utils.core_stock_tools import get_stock_data from tradingagents.agents.utils.fundamental_data_tools import ( - get_fundamentals, get_balance_sheet, get_cashflow, - get_income_statement + get_fundamentals, + get_income_statement, ) +from tradingagents.agents.utils.macro_data_tools import get_macro_indicators +from tradingagents.agents.utils.market_data_validation_tools import get_verified_market_snapshot from tradingagents.agents.utils.news_data_tools import ( - get_news, + get_global_news, get_insider_transactions, - get_global_news -) -from tradingagents.agents.utils.macro_data_tools import ( - get_macro_indicators -) -from tradingagents.agents.utils.prediction_markets_tools import ( - get_prediction_markets -) -from tradingagents.agents.utils.market_data_validation_tools import ( - get_verified_market_snapshot + get_news, ) +from tradingagents.agents.utils.prediction_markets_tools import get_prediction_markets +from tradingagents.agents.utils.technical_indicators_tools import get_indicators + +# Public surface: the data tools are imported here so agents and the graph +# import them from one place, plus the instrument/language helpers defined below. +__all__ = [ + "get_stock_data", + "get_indicators", + "get_fundamentals", + "get_balance_sheet", + "get_cashflow", + "get_income_statement", + "get_news", + "get_global_news", + "get_insider_transactions", + "get_macro_indicators", + "get_prediction_markets", + "get_verified_market_snapshot", + "build_instrument_context", + "resolve_instrument_identity", + "get_instrument_context_from_state", + "get_language_instruction", + "create_msg_delete", +] logger = logging.getLogger(__name__) @@ -52,7 +65,7 @@ def get_language_instruction() -> str: return f" Write your entire response in {lang}." -def _clean_identity_value(value: Any) -> Optional[str]: +def _clean_identity_value(value: Any) -> str | None: """Return a trimmed string, or None for empty / placeholder-ish values.""" if not isinstance(value, str): return None @@ -109,7 +122,7 @@ def resolve_instrument_identity(ticker: str) -> dict: def build_instrument_context( ticker: str, asset_type: str = "stock", - identity: Optional[Mapping[str, str]] = None, + identity: Mapping[str, str] | None = None, ) -> str: """Describe the exact instrument so agents preserve identity and ticker. @@ -201,4 +214,4 @@ def create_msg_delete(): return delete_messages - + diff --git a/tradingagents/agents/utils/core_stock_tools.py b/tradingagents/agents/utils/core_stock_tools.py index 3a4166222..bd5cabfad 100644 --- a/tradingagents/agents/utils/core_stock_tools.py +++ b/tradingagents/agents/utils/core_stock_tools.py @@ -1,5 +1,7 @@ -from langchain_core.tools import tool from typing import Annotated + +from langchain_core.tools import tool + from tradingagents.dataflows.interface import route_to_vendor diff --git a/tradingagents/agents/utils/fundamental_data_tools.py b/tradingagents/agents/utils/fundamental_data_tools.py index 47f6f2ebf..aefa0de7c 100644 --- a/tradingagents/agents/utils/fundamental_data_tools.py +++ b/tradingagents/agents/utils/fundamental_data_tools.py @@ -1,5 +1,7 @@ -from langchain_core.tools import tool from typing import Annotated + +from langchain_core.tools import tool + from tradingagents.dataflows.interface import route_to_vendor @@ -74,4 +76,4 @@ def get_income_statement( Returns: str: A formatted report containing income statement data """ - return route_to_vendor("get_income_statement", ticker, freq, curr_date) \ No newline at end of file + return route_to_vendor("get_income_statement", ticker, freq, curr_date) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index c94717556..ff9e94579 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,8 +1,7 @@ """Append-only markdown decision log for TradingAgents.""" -from typing import List, Optional -from pathlib import Path import re +from pathlib import Path from tradingagents.agents.utils.rating import parse_rating @@ -51,7 +50,7 @@ class TradingMemoryLog: # --- Read path (Phase A) --- - def load_entries(self) -> List[dict]: + def load_entries(self) -> list[dict]: """Parse all entries from log. Returns list of dicts.""" if not self._log_path or not self._log_path.exists(): return [] @@ -64,7 +63,7 @@ class TradingMemoryLog: entries.append(parsed) return entries - def get_pending_entries(self) -> List[dict]: + def get_pending_entries(self) -> list[dict]: """Return entries with outcome:pending (for Phase B).""" return [e for e in self.load_entries() if e.get("pending")] @@ -162,7 +161,7 @@ class TradingMemoryLog: tmp_path.write_text(new_text, encoding="utf-8") tmp_path.replace(self._log_path) - def batch_update_with_outcomes(self, updates: List[dict]) -> None: + def batch_update_with_outcomes(self, updates: list[dict]) -> None: """Apply multiple outcome updates in a single read + atomic write. Each element of updates must have keys: ticker, trade_date, @@ -218,7 +217,7 @@ class TradingMemoryLog: # --- Helpers --- - def _apply_rotation(self, blocks: List[str]) -> List[str]: + def _apply_rotation(self, blocks: list[str]) -> list[str]: """Drop oldest resolved blocks when their count exceeds max_entries. Pending blocks are always kept (they represent unprocessed work). @@ -247,7 +246,7 @@ class TradingMemoryLog: return blocks to_drop = resolved_count - self._max_entries - kept: List[str] = [] + kept: list[str] = [] for block, is_resolved in decisions: if is_resolved and to_drop > 0: to_drop -= 1 @@ -255,7 +254,7 @@ class TradingMemoryLog: kept.append(block) return kept - def _parse_entry(self, raw: str) -> Optional[dict]: + def _parse_entry(self, raw: str) -> dict | None: lines = raw.strip().splitlines() if not lines: return None diff --git a/tradingagents/agents/utils/news_data_tools.py b/tradingagents/agents/utils/news_data_tools.py index f503c4d3a..6c6f5352a 100644 --- a/tradingagents/agents/utils/news_data_tools.py +++ b/tradingagents/agents/utils/news_data_tools.py @@ -1,7 +1,10 @@ +from typing import Annotated + from langchain_core.tools import tool -from typing import Annotated, Optional + from tradingagents.dataflows.interface import route_to_vendor + @tool def get_news( ticker: Annotated[str, "Ticker symbol"], @@ -23,8 +26,8 @@ def get_news( @tool def get_global_news( curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], - look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None, - limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None, + look_back_days: Annotated[int | None, "Days to look back; omit to use the configured default"] = None, + limit: Annotated[int | None, "Max articles to return; omit to use the configured default"] = None, ) -> str: """ Retrieve global news data. diff --git a/tradingagents/agents/utils/rating.py b/tradingagents/agents/utils/rating.py index d5032346a..234bc568d 100644 --- a/tradingagents/agents/utils/rating.py +++ b/tradingagents/agents/utils/rating.py @@ -12,11 +12,9 @@ Centralising it here avoids drift between those call sites. from __future__ import annotations import re -from typing import Tuple - # Canonical, ordered 5-tier scale (most bullish to most bearish). -RATINGS_5_TIER: Tuple[str, ...] = ( +RATINGS_5_TIER: tuple[str, ...] = ( "Buy", "Overweight", "Hold", "Underweight", "Sell", ) diff --git a/tradingagents/agents/utils/structured.py b/tradingagents/agents/utils/structured.py index 400e1a82b..28dac8b55 100644 --- a/tradingagents/agents/utils/structured.py +++ b/tradingagents/agents/utils/structured.py @@ -19,7 +19,8 @@ all three agents log the same warnings when fallback fires. from __future__ import annotations import logging -from typing import Any, Callable, Optional, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from pydantic import BaseModel @@ -28,7 +29,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) -def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]: +def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Any | None: """Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported. Logs a warning when the binding fails so the user understands the agent @@ -46,7 +47,7 @@ def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any] def invoke_structured_or_freetext( - structured_llm: Optional[Any], + structured_llm: Any | None, plain_llm: Any, prompt: Any, render: Callable[[T], str], diff --git a/tradingagents/agents/utils/technical_indicators_tools.py b/tradingagents/agents/utils/technical_indicators_tools.py index a3dda5a51..c1c5ad99f 100644 --- a/tradingagents/agents/utils/technical_indicators_tools.py +++ b/tradingagents/agents/utils/technical_indicators_tools.py @@ -1,7 +1,10 @@ -from langchain_core.tools import tool from typing import Annotated + +from langchain_core.tools import tool + from tradingagents.dataflows.interface import route_to_vendor + @tool def get_indicators( symbol: Annotated[str, "ticker symbol of the company"], @@ -29,4 +32,4 @@ def get_indicators( results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days)) except ValueError as e: results.append(str(e)) - return "\n\n".join(results) \ No newline at end of file + return "\n\n".join(results) diff --git a/tradingagents/dataflows/alpha_vantage.py b/tradingagents/dataflows/alpha_vantage.py index b2be1d611..90032a72b 100644 --- a/tradingagents/dataflows/alpha_vantage.py +++ b/tradingagents/dataflows/alpha_vantage.py @@ -1,5 +1,23 @@ -# Import functions from specialized modules -from .alpha_vantage_stock import get_stock +# Aggregates the per-category Alpha Vantage implementations into one module the +# vendor router imports from; the imports below are the public surface. +from .alpha_vantage_fundamentals import ( + get_balance_sheet, + get_cashflow, + get_fundamentals, + get_income_statement, +) from .alpha_vantage_indicator import get_indicator -from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement -from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions \ No newline at end of file +from .alpha_vantage_news import get_global_news, get_insider_transactions, get_news +from .alpha_vantage_stock import get_stock + +__all__ = [ + "get_balance_sheet", + "get_cashflow", + "get_fundamentals", + "get_income_statement", + "get_indicator", + "get_global_news", + "get_insider_transactions", + "get_news", + "get_stock", +] diff --git a/tradingagents/dataflows/alpha_vantage_indicator.py b/tradingagents/dataflows/alpha_vantage_indicator.py index 53623ead4..dcfd17ab6 100644 --- a/tradingagents/dataflows/alpha_vantage_indicator.py +++ b/tradingagents/dataflows/alpha_vantage_indicator.py @@ -1,4 +1,5 @@ -from .alpha_vantage_common import _make_api_request, AlphaVantageNotConfiguredError +from .alpha_vantage_common import AlphaVantageNotConfiguredError, _make_api_request + def get_indicator( symbol: str, @@ -25,6 +26,7 @@ def get_indicator( String containing indicator values and description """ from datetime import datetime + from dateutil.relativedelta import relativedelta supported_indicators = { @@ -98,21 +100,7 @@ def get_indicator( "series_type": series_type, "datatype": "csv" }) - elif indicator == "macd": - data = _make_api_request("MACD", { - "symbol": symbol, - "interval": interval, - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "macds": - data = _make_api_request("MACD", { - "symbol": symbol, - "interval": interval, - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "macdh": + elif indicator == "macd" or indicator == "macds" or indicator == "macdh": data = _make_api_request("MACD", { "symbol": symbol, "interval": interval, diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 4cf7bb0ec..f9c7cfc99 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,5 +1,6 @@ from .alpha_vantage_common import _make_api_request, format_datetime_for_api + def get_news(ticker, start_date, end_date) -> dict[str, str] | str: """Returns live and historical market news & sentiment data from premier news outlets worldwide. @@ -68,4 +69,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str: "symbol": symbol, } - return _make_api_request("INSIDER_TRANSACTIONS", params) \ No newline at end of file + return _make_api_request("INSIDER_TRANSACTIONS", params) diff --git a/tradingagents/dataflows/alpha_vantage_stock.py b/tradingagents/dataflows/alpha_vantage_stock.py index ffd3570b3..43d693b9c 100644 --- a/tradingagents/dataflows/alpha_vantage_stock.py +++ b/tradingagents/dataflows/alpha_vantage_stock.py @@ -1,5 +1,7 @@ from datetime import datetime -from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range + +from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request + def get_stock( symbol: str, @@ -35,4 +37,4 @@ def get_stock( response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params) - return _filter_csv_by_date_range(response, start_date, end_date) \ No newline at end of file + return _filter_csv_by_date_range(response, start_date, end_date) diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index 6f3076aea..c4790e847 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,10 +1,9 @@ from copy import deepcopy -from typing import Dict, Optional import tradingagents.default_config as default_config # Use default config but allow it to be overridden -_config: Optional[Dict] = None +_config: dict | None = None def initialize_config(): @@ -14,7 +13,7 @@ def initialize_config(): _config = deepcopy(default_config.DEFAULT_CONFIG) -def set_config(config: Dict): +def set_config(config: dict): """Update the configuration with custom values. Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a @@ -31,7 +30,7 @@ def set_config(config: Dict): _config[key] = value -def get_config() -> Dict: +def get_config() -> dict: """Get the current configuration.""" if _config is None: initialize_config() diff --git a/tradingagents/dataflows/market_data_validator.py b/tradingagents/dataflows/market_data_validator.py index d1992c0f8..baa5efecf 100644 --- a/tradingagents/dataflows/market_data_validator.py +++ b/tradingagents/dataflows/market_data_validator.py @@ -10,7 +10,7 @@ claim. Deterministic, no LLM involved. from __future__ import annotations -from typing import Iterable, Optional +from collections.abc import Iterable import pandas as pd from stockstats import wrap @@ -63,7 +63,7 @@ def build_verified_market_snapshot( symbol: str, curr_date: str, look_back_days: int = 30, - indicators: Optional[Iterable[str]] = None, + indicators: Iterable[str] | None = None, ) -> str: """Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes.""" # `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/ diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index 0cf6a6032..b5d48d079 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,10 +1,9 @@ -import os import re -import json -import pandas as pd -from datetime import date, timedelta, datetime +from datetime import date, datetime, timedelta from typing import Annotated +import pandas as pd + SavePathType = Annotated[str, "File path to save data. If None, data is not saved."] # Tickers can contain letters, digits, dot, dash, underscore, caret diff --git a/tradingagents/dataflows/yfinance_news.py b/tradingagents/dataflows/yfinance_news.py index fa0c342be..3c8e4ba6f 100644 --- a/tradingagents/dataflows/yfinance_news.py +++ b/tradingagents/dataflows/yfinance_news.py @@ -1,9 +1,9 @@ """yfinance-based news data fetching functions.""" -from typing import Optional +import contextlib +from datetime import datetime import yfinance as yf -from datetime import datetime from dateutil.relativedelta import relativedelta from .config import get_config @@ -28,10 +28,8 @@ def _extract_article_data(article: dict) -> dict: pub_date_str = content.get("pubDate", "") pub_date = None if pub_date_str: - try: + with contextlib.suppress(ValueError, AttributeError): pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00")) - except (ValueError, AttributeError): - pass return { "title": title, @@ -47,10 +45,8 @@ def _extract_article_data(article: dict) -> dict: pub_date = None ts = article.get("providerPublishTime") if ts: - try: + with contextlib.suppress(ValueError, OSError, TypeError): pub_date = datetime.fromtimestamp(ts) - except (ValueError, OSError, TypeError): - pass return { "title": article.get("title", "No title"), "summary": article.get("summary", ""), @@ -131,8 +127,8 @@ def get_news_yfinance( def get_global_news_yfinance( curr_date: str, - look_back_days: Optional[int] = None, - limit: Optional[int] = None, + look_back_days: int | None = None, + limit: int | None = None, ) -> str: """ Retrieve global/macro economic news using yfinance Search. diff --git a/tradingagents/graph/__init__.py b/tradingagents/graph/__init__.py index 80982c199..901edddd8 100644 --- a/tradingagents/graph/__init__.py +++ b/tradingagents/graph/__init__.py @@ -1,11 +1,11 @@ # TradingAgents/graph/__init__.py -from .trading_graph import TradingAgentsGraph from .conditional_logic import ConditionalLogic -from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector +from .setup import GraphSetup from .signal_processing import SignalProcessor +from .trading_graph import TradingAgentsGraph __all__ = [ "TradingAgentsGraph", diff --git a/tradingagents/graph/analyst_execution.py b/tradingagents/graph/analyst_execution.py index 14c375338..6587ea7eb 100644 --- a/tradingagents/graph/analyst_execution.py +++ b/tradingagents/graph/analyst_execution.py @@ -1,6 +1,6 @@ +from collections.abc import Iterable from dataclasses import dataclass from time import monotonic -from typing import Dict, Iterable, List, Optional @dataclass(frozen=True) @@ -14,11 +14,11 @@ class AnalystNodeSpec: @dataclass(frozen=True) class AnalystExecutionPlan: - specs: List[AnalystNodeSpec] + specs: list[AnalystNodeSpec] concurrency_limit: int -ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = { +ANALYST_NODE_SPECS: dict[str, AnalystNodeSpec] = { "market": AnalystNodeSpec( key="market", agent_node="Market Analyst", @@ -61,7 +61,7 @@ def build_analyst_execution_plan( if concurrency_limit < 1: raise ValueError("analyst concurrency limit must be >= 1") - specs: List[AnalystNodeSpec] = [] + specs: list[AnalystNodeSpec] = [] for analyst_key in selected_analysts: spec = ANALYST_NODE_SPECS.get(analyst_key) if spec is None: @@ -81,10 +81,10 @@ def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str: class AnalystWallTimeTracker: def __init__(self, plan: AnalystExecutionPlan): self.plan = plan - self._started_at: Dict[str, float] = {} - self._wall_times: Dict[str, float] = {} + self._started_at: dict[str, float] = {} + self._wall_times: dict[str, float] = {} - def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None: + def mark_started(self, analyst_key: str, started_at: float | None = None) -> None: if analyst_key not in ANALYST_NODE_SPECS: raise ValueError(f"unknown analyst key: {analyst_key}") self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at) @@ -92,7 +92,7 @@ class AnalystWallTimeTracker: def mark_completed( self, analyst_key: str, - completed_at: Optional[float] = None, + completed_at: float | None = None, ) -> None: if analyst_key not in ANALYST_NODE_SPECS: raise ValueError(f"unknown analyst key: {analyst_key}") @@ -104,7 +104,7 @@ class AnalystWallTimeTracker: finished_at = monotonic() if completed_at is None else completed_at self._wall_times[analyst_key] = max(0.0, finished_at - started_at) - def get_wall_times(self) -> Dict[str, float]: + def get_wall_times(self) -> dict[str, float]: return dict(self._wall_times) def format_summary(self) -> str: @@ -121,8 +121,8 @@ class AnalystWallTimeTracker: def sync_analyst_tracker_from_chunk( tracker: AnalystWallTimeTracker, - chunk: Dict[str, str], - now: Optional[float] = None, + chunk: dict[str, str], + now: float | None = None, ) -> None: current_time = monotonic() if now is None else now active_found = False diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py index 3ba19726d..a32cb4cb2 100644 --- a/tradingagents/graph/checkpointer.py +++ b/tradingagents/graph/checkpointer.py @@ -7,9 +7,9 @@ from __future__ import annotations import hashlib import sqlite3 +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Generator from langgraph.checkpoint.sqlite import SqliteSaver diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index edd93f55e..2a6aeddb6 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -1,8 +1,8 @@ # TradingAgents/graph/propagation.py -from typing import Dict, Any, List, Optional +from typing import Any + from tradingagents.agents.utils.agent_states import ( - AgentState, InvestDebateState, RiskDebateState, ) @@ -22,7 +22,7 @@ class Propagator: asset_type: str = "stock", past_context: str = "", instrument_context: str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create the initial state for the agent graph. ``instrument_context`` is the deterministic ticker-identity string @@ -68,7 +68,7 @@ class Propagator: "news_report": "", } - def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]: + def get_graph_args(self, callbacks: list | None = None) -> dict[str, Any]: """Get arguments for the graph invocation. Args: diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 1e6b4da72..8771dd0bb 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,10 +1,25 @@ # TradingAgents/graph/setup.py -from typing import Any, Dict +from typing import Any + from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode -from tradingagents.agents import * +from tradingagents.agents import ( + create_aggressive_debator, + create_bear_researcher, + create_bull_researcher, + create_conservative_debator, + create_fundamentals_analyst, + create_market_analyst, + create_msg_delete, + create_neutral_debator, + create_news_analyst, + create_portfolio_manager, + create_research_manager, + create_sentiment_analyst, + create_trader, +) from tradingagents.agents.utils.agent_states import AgentState from .analyst_execution import build_analyst_execution_plan @@ -18,7 +33,7 @@ class GraphSetup: self, quick_thinking_llm: Any, deep_thinking_llm: Any, - tool_nodes: Dict[str, ToolNode], + tool_nodes: dict[str, ToolNode], conditional_logic: ConditionalLogic, analyst_concurrency_limit: int = 1, ): @@ -30,7 +45,7 @@ class GraphSetup: self.analyst_concurrency_limit = analyst_concurrency_limit def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, selected_analysts=("market", "social", "news", "fundamentals") ): """Set up and compile the agent workflow graph. diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index f9a0d9446..f72ede83d 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,66 +1,57 @@ # TradingAgents/graph/trading_graph.py +import json import logging import os -from pathlib import Path -import json from datetime import datetime, timedelta -from typing import Dict, Any, Tuple, List, Optional +from pathlib import Path +from typing import Any import yfinance as yf - -logger = logging.getLogger(__name__) - from langgraph.prebuilt import ToolNode -from tradingagents.llm_clients import create_llm_client - -from tradingagents.agents import * -from tradingagents.default_config import DEFAULT_CONFIG -from tradingagents.agents.utils.memory import TradingMemoryLog -from tradingagents.dataflows.utils import safe_ticker_component -from tradingagents.agents.utils.agent_states import ( - AgentState, - InvestDebateState, - RiskDebateState, -) -from tradingagents.dataflows.config import set_config - -# Import the new abstract tool methods from agent_utils +# Import the abstract tool methods from agent_utils from tradingagents.agents.utils.agent_utils import ( build_instrument_context, - resolve_instrument_identity, - get_stock_data, - get_indicators, - get_verified_market_snapshot, - get_fundamentals, get_balance_sheet, get_cashflow, - get_income_statement, - get_news, - get_insider_transactions, + get_fundamentals, get_global_news, + get_income_statement, + get_indicators, + get_insider_transactions, get_macro_indicators, - get_prediction_markets + get_news, + get_prediction_markets, + get_stock_data, + get_verified_market_snapshot, + resolve_instrument_identity, ) +from tradingagents.agents.utils.memory import TradingMemoryLog +from tradingagents.dataflows.config import set_config +from tradingagents.dataflows.utils import safe_ticker_component +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.llm_clients import create_llm_client from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id from .conditional_logic import ConditionalLogic -from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector +from .setup import GraphSetup from .signal_processing import SignalProcessor +logger = logging.getLogger(__name__) + class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" def __init__( self, - selected_analysts=["market", "social", "news", "fundamentals"], + selected_analysts=("market", "social", "news", "fundamentals"), debug=False, - config: Dict[str, Any] = None, - callbacks: Optional[List] = None, + config: dict[str, Any] = None, + callbacks: list | None = None, ): """Initialize the trading agents graph and components. @@ -103,7 +94,7 @@ class TradingAgentsGraph: self.deep_thinking_llm = deep_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() - + self.memory_log = TradingMemoryLog(self.config) # Create tool nodes @@ -138,7 +129,7 @@ class TradingAgentsGraph: self.graph = self.workflow.compile() self._checkpointer_ctx = None - def _get_provider_kwargs(self) -> Dict[str, Any]: + def _get_provider_kwargs(self) -> dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" kwargs = {} provider = self.config.get("llm_provider", "").lower() @@ -167,7 +158,7 @@ class TradingAgentsGraph: return kwargs - def _create_tool_nodes(self) -> Dict[str, ToolNode]: + def _create_tool_nodes(self) -> dict[str, ToolNode]: """Create tool nodes for different data sources using abstract methods.""" return { "market": ToolNode( @@ -233,7 +224,7 @@ class TradingAgentsGraph: def _fetch_returns( self, ticker: str, trade_date: str, holding_days: int = 5, benchmark: str = "SPY", - ) -> Tuple[Optional[float], Optional[float], Optional[int]]: + ) -> tuple[float | None, float | None, int | None]: """Fetch raw and alpha return for ticker over holding_days from trade_date. ``benchmark`` is the index used as the alpha baseline (resolved by the diff --git a/tradingagents/llm_clients/anthropic_client.py b/tradingagents/llm_clients/anthropic_client.py index 0ac94b99f..8449ba680 100644 --- a/tradingagents/llm_clients/anthropic_client.py +++ b/tradingagents/llm_clients/anthropic_client.py @@ -1,5 +1,5 @@ import re -from typing import Any, Optional +from typing import Any from langchain_anthropic import ChatAnthropic @@ -43,7 +43,7 @@ class NormalizedChatAnthropic(ChatAnthropic): class AnthropicClient(BaseLLMClient): """Client for Anthropic Claude models.""" - def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + def __init__(self, model: str, base_url: str | None = None, **kwargs): super().__init__(model, base_url, **kwargs) def get_llm(self) -> Any: diff --git a/tradingagents/llm_clients/api_key_env.py b/tradingagents/llm_clients/api_key_env.py index 8521e1602..b97db38aa 100644 --- a/tradingagents/llm_clients/api_key_env.py +++ b/tradingagents/llm_clients/api_key_env.py @@ -11,10 +11,7 @@ prompts for it automatically instead of failing on first API call. from __future__ import annotations -from typing import Optional - - -PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = { +PROVIDER_API_KEY_ENV: dict[str, str | None] = { "openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY", "google": "GOOGLE_API_KEY", @@ -47,7 +44,7 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = { } -def get_api_key_env(provider: str) -> Optional[str]: +def get_api_key_env(provider: str) -> str | None: """Return the env var name for `provider`'s API key, or None if not applicable. Unknown providers also return None — callers should treat that as diff --git a/tradingagents/llm_clients/azure_client.py b/tradingagents/llm_clients/azure_client.py index a17e6fc2b..f6f996edd 100644 --- a/tradingagents/llm_clients/azure_client.py +++ b/tradingagents/llm_clients/azure_client.py @@ -1,10 +1,9 @@ import os -from typing import Any, Optional +from typing import Any from langchain_openai import AzureChatOpenAI from .base_client import BaseLLMClient, normalize_content -from .validators import validate_model _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "api_key", "reasoning_effort", "temperature", @@ -29,7 +28,7 @@ class AzureOpenAIClient(BaseLLMClient): OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview) """ - def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + def __init__(self, model: str, base_url: str | None = None, **kwargs): super().__init__(model, base_url, **kwargs) def get_llm(self) -> Any: diff --git a/tradingagents/llm_clients/base_client.py b/tradingagents/llm_clients/base_client.py index f29d713dd..cb56e25f2 100644 --- a/tradingagents/llm_clients/base_client.py +++ b/tradingagents/llm_clients/base_client.py @@ -1,6 +1,6 @@ -from abc import ABC, abstractmethod -from typing import Any, Optional import warnings +from abc import ABC, abstractmethod +from typing import Any def normalize_content(response): @@ -25,7 +25,7 @@ def normalize_content(response): class BaseLLMClient(ABC): """Abstract base class for LLM clients.""" - def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + def __init__(self, model: str, base_url: str | None = None, **kwargs): self.model = model self.base_url = base_url self.kwargs = kwargs diff --git a/tradingagents/llm_clients/capabilities.py b/tradingagents/llm_clients/capabilities.py index 8c95d2641..d3e9c5782 100644 --- a/tradingagents/llm_clients/capabilities.py +++ b/tradingagents/llm_clients/capabilities.py @@ -18,7 +18,6 @@ import re from dataclasses import dataclass from typing import Literal - StructuredMethod = Literal[ "function_calling", # uses tools; respects supports_tool_choice "json_mode", # uses response_format={"type":"json_object"} diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index d0dbdd63c..02c1cf428 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -1,11 +1,11 @@ -from typing import Optional from .base_client import BaseLLMClient + def create_llm_client( provider: str, model: str, - base_url: Optional[str] = None, + base_url: str | None = None, **kwargs, ) -> BaseLLMClient: """Create an LLM client for the specified provider. diff --git a/tradingagents/llm_clients/google_client.py b/tradingagents/llm_clients/google_client.py index 61a104c7c..df83b6cc8 100644 --- a/tradingagents/llm_clients/google_client.py +++ b/tradingagents/llm_clients/google_client.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from langchain_google_genai import ChatGoogleGenerativeAI @@ -20,7 +20,7 @@ class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI): class GoogleClient(BaseLLMClient): """Client for Google Gemini models.""" - def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + def __init__(self, model: str, base_url: str | None = None, **kwargs): super().__init__(model, base_url, **kwargs) def get_llm(self) -> Any: diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index 418c898a2..10b030ba1 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -2,14 +2,12 @@ from __future__ import annotations -from typing import Dict, List, Tuple - -ModelOption = Tuple[str, str] -ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]] +ModelOption = tuple[str, str] +ProviderModeOptions = dict[str, dict[str, list[ModelOption]]] # Providers that serve many / frequently-changing models: offer only "Custom # model ID" rather than a list that goes stale. -_CUSTOM_ONLY: Dict[str, List[ModelOption]] = { +_CUSTOM_ONLY: dict[str, list[ModelOption]] = { "quick": [("Custom model ID", "custom")], "deep": [("Custom model ID", "custom")], } @@ -18,7 +16,7 @@ _CUSTOM_ONLY: Dict[str, List[ModelOption]] = { # Shared model list for GLM via Z.AI (international) and BigModel (China). # Source: docs.z.ai (GLM Coding Plan supported models + LLM guides). # All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}. -_GLM_MODELS: Dict[str, List[ModelOption]] = { +_GLM_MODELS: dict[str, list[ModelOption]] = { "quick": [ ("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"), ("GLM-4.7 - Previous-gen flagship", "glm-4.7"), @@ -44,7 +42,7 @@ _GLM_MODELS: Dict[str, List[ModelOption]] = { # the backing model. Users who want a specific generation pick it # explicitly; users who really want auto-latest can enter the alias via # "Custom model ID". -_QWEN_MODELS: Dict[str, List[ModelOption]] = { +_QWEN_MODELS: dict[str, list[ModelOption]] = { "quick": [ ("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"), ("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"), @@ -62,7 +60,7 @@ _QWEN_MODELS: Dict[str, List[ModelOption]] = { # Shared model list for MiniMax's global and CN endpoints (same IDs). # Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api. # All M2.x models share a 204,800-token context window. -_MINIMAX_MODELS: Dict[str, List[ModelOption]] = { +_MINIMAX_MODELS: dict[str, list[ModelOption]] = { "quick": [ ("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"), ("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"), @@ -198,12 +196,12 @@ MODEL_OPTIONS: ProviderModeOptions = { } -def get_model_options(provider: str, mode: str) -> List[ModelOption]: +def get_model_options(provider: str, mode: str) -> list[ModelOption]: """Return shared model options for a provider and selection mode.""" return MODEL_OPTIONS[provider.lower()][mode] -def get_known_models() -> Dict[str, List[str]]: +def get_known_models() -> dict[str, list[str]]: """Build known model names from the shared CLI catalog.""" return { provider: sorted( diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 214cf1ee0..b03cfb860 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -2,7 +2,6 @@ from .model_catalog import get_known_models - # Providers whose model names are user-defined (local servers, relays, hosted # OpenAI-compatible endpoints serving many models), so any model string is # accepted without warning.