mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
chore(lint): make the repository ruff-clean under the strict select
Clear the deferred full-repo lint backlog so the whole tree passes the strict ruff select (E,W,F,I,B,UP,C4,SIM). Mechanical fixes dominate: import sorting, pep585/604 annotations, dropped dead imports, and whitespace. The few semantic changes are behavior-preserving: declare __all__ on the agent_utils and alpha_vantage re-export hubs; expand 'from x import *' to explicit names; use immutable tuple defaults instead of mutable list defaults; contextlib.suppress for try/except/pass; and narrow an over-broad assertRaises.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import getpass
|
||||
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
100
cli/main.py
100
cli/main.py
@@ -1,38 +1,53 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import datetime
|
||||
import typer
|
||||
import questionary
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
from rich.columns import Columns
|
||||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.table import Table
|
||||
from collections import deque
|
||||
import os
|
||||
import time
|
||||
from rich.tree import Tree
|
||||
from collections import deque
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich import box
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.layout import Layout
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.spinner import Spinner
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from cli.announcements import display_announcements, fetch_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
from cli.utils import (
|
||||
ask_anthropic_effort,
|
||||
ask_gemini_thinking_config,
|
||||
ask_glm_region,
|
||||
ask_minimax_region,
|
||||
ask_openai_reasoning_effort,
|
||||
ask_output_language,
|
||||
ask_qwen_region,
|
||||
confirm_ollama_endpoint,
|
||||
detect_asset_type,
|
||||
ensure_api_key,
|
||||
get_ticker,
|
||||
prompt_openai_compatible_url,
|
||||
resolve_backend_url,
|
||||
select_analysts,
|
||||
select_deep_thinking_agent,
|
||||
select_llm_provider,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
)
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.graph.analyst_execution import (
|
||||
AnalystWallTimeTracker,
|
||||
build_analyst_execution_plan,
|
||||
get_initial_analyst_node,
|
||||
sync_analyst_tracker_from_chunk,
|
||||
)
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.announcements import fetch_announcements, display_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -169,7 +184,7 @@ class MessageBuffer:
|
||||
if content is not None:
|
||||
latest_section = section
|
||||
latest_content = content
|
||||
|
||||
|
||||
if latest_section and latest_content:
|
||||
# Format the current section for display
|
||||
section_titles = {
|
||||
@@ -466,7 +481,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
||||
def get_user_selections():
|
||||
"""Get all user selections before starting the analysis display."""
|
||||
# Display ASCII art welcome message
|
||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
|
||||
with open(Path(__file__).parent / "static" / "welcome.txt", encoding="utf-8") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
# Create welcome box content
|
||||
@@ -922,9 +937,12 @@ def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
|
||||
message_buffer.update_agent_status(agent_name, "pending")
|
||||
|
||||
# When all analysts complete, transition research team to in_progress
|
||||
if not found_active and selected:
|
||||
if message_buffer.agent_status.get("Bull Researcher") == "pending":
|
||||
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
||||
if (
|
||||
not found_active
|
||||
and selected
|
||||
and message_buffer.agent_status.get("Bull Researcher") == "pending"
|
||||
):
|
||||
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
||||
|
||||
def extract_content_string(content):
|
||||
"""Extract string content from various message formats.
|
||||
@@ -1064,7 +1082,7 @@ def run_analysis(checkpoint: bool = False):
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||
return wrapper
|
||||
|
||||
|
||||
def save_tool_call_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
@@ -1097,7 +1115,7 @@ def run_analysis(checkpoint: bool = False):
|
||||
# Now start the display layout
|
||||
layout = create_layout()
|
||||
|
||||
with Live(layout, refresh_per_second=4) as live:
|
||||
with Live(layout, refresh_per_second=4):
|
||||
# Initial display
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
@@ -1232,16 +1250,15 @@ def run_analysis(checkpoint: bool = False):
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||
)
|
||||
if judge:
|
||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||
)
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
if judge and message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||
)
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
|
||||
# Update the display
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
@@ -1253,7 +1270,6 @@ def run_analysis(checkpoint: bool = False):
|
||||
final_state = {}
|
||||
for chunk in trace:
|
||||
final_state.update(chunk)
|
||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
# Update all agent statuses to completed
|
||||
for agent in message_buffer.agent_status:
|
||||
@@ -1265,7 +1281,7 @@ def run_analysis(checkpoint: bool = False):
|
||||
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
|
||||
|
||||
# Update final report sections
|
||||
for section in message_buffer.report_sections.keys():
|
||||
for section in message_buffer.report_sections:
|
||||
if section in final_state:
|
||||
message_buffer.update_report_section(section, final_state[section])
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class StatsCallbackHandler(BaseCallbackHandler):
|
||||
@@ -19,8 +19,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when an LLM starts."""
|
||||
@@ -29,8 +29,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[Any]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when a chat model starts."""
|
||||
@@ -57,7 +57,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -65,7 +65,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
||||
with self._lock:
|
||||
self.tool_calls += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Return current statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
|
||||
13
cli/utils.py
13
cli/utils.py
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
import questionary
|
||||
from dotenv import find_dotenv, set_key
|
||||
@@ -89,8 +88,8 @@ def detect_asset_type(ticker: str) -> AssetType:
|
||||
|
||||
|
||||
def filter_analysts_for_asset_type(
|
||||
analysts: List[AnalystType], asset_type: AssetType
|
||||
) -> List[AnalystType]:
|
||||
analysts: list[AnalystType], asset_type: AssetType
|
||||
) -> list[AnalystType]:
|
||||
if asset_type != AssetType.CRYPTO:
|
||||
return analysts
|
||||
return [
|
||||
@@ -133,7 +132,7 @@ def get_analysis_date() -> str:
|
||||
return date.strip()
|
||||
|
||||
|
||||
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> List[AnalystType]:
|
||||
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> list[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
available_analysts = filter_analysts_for_asset_type(
|
||||
[value for _, value in ANALYST_ORDER],
|
||||
@@ -197,7 +196,7 @@ def select_research_depth() -> int:
|
||||
return choice
|
||||
|
||||
|
||||
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
||||
def _fetch_openrouter_models() -> list[tuple[str, str]]:
|
||||
"""Fetch available models from the OpenRouter API."""
|
||||
import requests
|
||||
try:
|
||||
@@ -377,7 +376,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
@@ -556,7 +555,7 @@ def confirm_ollama_endpoint(url: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def ensure_api_key(provider: str) -> Optional[str]:
|
||||
def ensure_api_key(provider: str) -> str | None:
|
||||
"""Make sure the API key for `provider` is available in the environment.
|
||||
|
||||
If the env var is already set, returns its value untouched. Otherwise
|
||||
|
||||
2
main.py
2
main.py
@@ -1,5 +1,5 @@
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
||||
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
||||
|
||||
@@ -21,7 +21,6 @@ added, plus the heuristic SignalProcessor.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||
@@ -30,7 +29,6 @@ from tradingagents.agents.trader.trader import create_trader
|
||||
from tradingagents.graph.signal_processing import SignalProcessor
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
|
||||
PROVIDER_DEFAULTS = {
|
||||
"openai": ("gpt-5.4-mini", None),
|
||||
"google": ("gemini-2.5-flash", None),
|
||||
|
||||
5
test.py
5
test.py
@@ -1,5 +1,8 @@
|
||||
import time
|
||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||
|
||||
from tradingagents.dataflows.y_finance import (
|
||||
get_stock_stats_indicators_window,
|
||||
)
|
||||
|
||||
print("Testing optimized implementation with 30-day lookback:")
|
||||
start_time = time.time()
|
||||
|
||||
@@ -3,14 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||
|
||||
|
||||
# ---- Mapping coverage -----------------------------------------------------
|
||||
|
||||
|
||||
@@ -71,6 +69,7 @@ def test_case_insensitive_lookup():
|
||||
def cli_utils(monkeypatch):
|
||||
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||
import importlib
|
||||
|
||||
import cli.utils as cli_utils_module
|
||||
return importlib.reload(cli_utils_module)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Unit tests for the LLM capability table."""
|
||||
|
||||
from dataclasses import FrozenInstanceError
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.capabilities import (
|
||||
ModelCapabilities,
|
||||
get_capabilities,
|
||||
)
|
||||
|
||||
@@ -119,5 +120,5 @@ class TestDefault:
|
||||
def test_capabilities_dataclass_is_frozen():
|
||||
"""Capability rows are immutable so they can be safely shared."""
|
||||
caps = get_capabilities("deepseek-chat")
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
caps.supports_tool_choice = False # type: ignore[misc]
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from tradingagents.graph.checkpointer import (
|
||||
|
||||
@@ -24,7 +24,6 @@ from tradingagents.llm_clients.openai_client import (
|
||||
_input_to_messages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||
# (Gemini bot review note: non-list inputs must also work)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.graph.reflection import Reflector
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||
|
||||
_SEP = TradingMemoryLog._SEPARATOR
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ news injected future articles), #993 (empty-after-filter returned a blank body).
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from unittest import mock
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from tradingagents.dataflows import stockstats_utils, interface
|
||||
from tradingagents.dataflows import interface, stockstats_utils
|
||||
from tradingagents.dataflows.config import set_config
|
||||
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
||||
|
||||
@@ -33,9 +33,9 @@ class TestLoadOhlcvNoPoison(unittest.TestCase):
|
||||
|
||||
def test_empty_download_raises_and_does_not_cache(self):
|
||||
empty = pd.DataFrame()
|
||||
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty) as dl:
|
||||
with self.assertRaises(NoMarketDataError):
|
||||
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
||||
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty), \
|
||||
self.assertRaises(NoMarketDataError):
|
||||
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
||||
# Nothing should have been written to the cache.
|
||||
self.assertEqual(os.listdir(self._tmp), [])
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ def _resync_reloaded_modules():
|
||||
doesn't leak across test modules.
|
||||
"""
|
||||
yield
|
||||
import cli.utils
|
||||
import cli.main
|
||||
import cli.utils
|
||||
importlib.reload(cli.utils)
|
||||
importlib.reload(cli.main)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ Verifies the user-supplied base_url is required and honored, the key is optional
|
||||
(keyless local default), Chat Completions (not the Responses API) is used, any
|
||||
model name is accepted, and the env backend URL precedence (#978).
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import pytest
|
||||
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||
from tradingagents.graph.signal_processing import SignalProcessor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Heuristic parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -27,7 +27,6 @@ from tradingagents.agents.schemas import (
|
||||
)
|
||||
from tradingagents.agents.trader.trader import create_trader
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Render functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -6,8 +6,8 @@ import pytest
|
||||
|
||||
from tradingagents.dataflows.symbol_utils import (
|
||||
NoMarketDataError,
|
||||
normalize_symbol,
|
||||
is_yahoo_safe,
|
||||
normalize_symbol,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ Regressions for #988 (explicit single-vendor config still fell back to others),
|
||||
were swallowed without a trace).
|
||||
"""
|
||||
import copy
|
||||
import logging
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
@@ -76,9 +75,9 @@ class VendorRoutingTests(unittest.TestCase):
|
||||
# #989: primary errors + fallback no-data -> NO_DATA, but the failure
|
||||
# must be visible in logs (broken primary not hidden).
|
||||
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
|
||||
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}):
|
||||
with self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
|
||||
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}), \
|
||||
self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
|
||||
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||
joined = "\n".join(cm.output)
|
||||
self.assertIn("boom", joined) # the real error surfaced in logs
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
||||
@@ -20,10 +21,8 @@ except ImportError:
|
||||
# subclassed warning categories. To suppress a specific warning we must
|
||||
# install our filter AFTER langchain-core has installed its own, so import
|
||||
# it first. The package is a guaranteed transitive dep via langgraph.
|
||||
try:
|
||||
with contextlib.suppress(ImportError):
|
||||
import langchain_core # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
||||
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
from .utils.agent_utils import create_msg_delete
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
@@ -8,18 +5,16 @@ from .analysts.sentiment_analyst import (
|
||||
create_sentiment_analyst,
|
||||
create_social_media_analyst, # deprecated alias kept for back-compat
|
||||
)
|
||||
|
||||
from .managers.portfolio_manager import create_portfolio_manager
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
||||
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
||||
from .risk_mgmt.conservative_debator import create_conservative_debator
|
||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.portfolio_manager import create_portfolio_manager
|
||||
|
||||
from .trader.trader import create_trader
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.agent_utils import create_msg_delete
|
||||
|
||||
__all__ = [
|
||||
"AgentState",
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_instrument_context_from_state,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
get_insider_transactions,
|
||||
get_instrument_context_from_state,
|
||||
get_language_instruction,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
def create_fundamentals_analyst(llm):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_instrument_context_from_state,
|
||||
get_indicators,
|
||||
get_instrument_context_from_state,
|
||||
get_language_instruction,
|
||||
get_stock_data,
|
||||
get_verified_market_snapshot,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_instrument_context_from_state,
|
||||
get_global_news,
|
||||
get_instrument_context_from_state,
|
||||
get_language_instruction,
|
||||
get_macro_indicators,
|
||||
get_news,
|
||||
get_prediction_markets,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
def create_news_analyst(llm):
|
||||
|
||||
@@ -19,11 +19,10 @@ so that:
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared rating types
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -124,15 +123,15 @@ class TraderProposal(BaseModel):
|
||||
"the research plan. Two to four sentences."
|
||||
),
|
||||
)
|
||||
entry_price: Optional[float] = Field(
|
||||
entry_price: float | None = Field(
|
||||
default=None,
|
||||
description="Optional entry price target in the instrument's quote currency.",
|
||||
)
|
||||
stop_loss: Optional[float] = Field(
|
||||
stop_loss: float | None = Field(
|
||||
default=None,
|
||||
description="Optional stop-loss price in the instrument's quote currency.",
|
||||
)
|
||||
position_sizing: Optional[str] = Field(
|
||||
position_sizing: str | None = Field(
|
||||
default=None,
|
||||
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
||||
)
|
||||
@@ -196,11 +195,11 @@ class PortfolioDecision(BaseModel):
|
||||
"incorporate them; otherwise rely solely on the current analysis."
|
||||
),
|
||||
)
|
||||
price_target: Optional[float] = Field(
|
||||
price_target: float | None = Field(
|
||||
default=None,
|
||||
description="Optional target price in the instrument's quote currency.",
|
||||
)
|
||||
time_horizon: Optional[str] = Field(
|
||||
time_horizon: str | None = Field(
|
||||
default=None,
|
||||
description="Optional recommended holding period, e.g. '3-6 months'.",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.graph import MessagesState
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# Researcher team state
|
||||
|
||||
@@ -1,37 +1,50 @@
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Mapping, Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import yfinance as yf
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
# Import tools from separate utility files
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
get_stock_data
|
||||
)
|
||||
from tradingagents.agents.utils.technical_indicators_tools import (
|
||||
get_indicators
|
||||
)
|
||||
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||
from tradingagents.agents.utils.fundamental_data_tools import (
|
||||
get_fundamentals,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_income_statement
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
)
|
||||
from tradingagents.agents.utils.macro_data_tools import get_macro_indicators
|
||||
from tradingagents.agents.utils.market_data_validation_tools import get_verified_market_snapshot
|
||||
from tradingagents.agents.utils.news_data_tools import (
|
||||
get_news,
|
||||
get_global_news,
|
||||
get_insider_transactions,
|
||||
get_global_news
|
||||
)
|
||||
from tradingagents.agents.utils.macro_data_tools import (
|
||||
get_macro_indicators
|
||||
)
|
||||
from tradingagents.agents.utils.prediction_markets_tools import (
|
||||
get_prediction_markets
|
||||
)
|
||||
from tradingagents.agents.utils.market_data_validation_tools import (
|
||||
get_verified_market_snapshot
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.agents.utils.prediction_markets_tools import get_prediction_markets
|
||||
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||
|
||||
# Public surface: the data tools are imported here so agents and the graph
|
||||
# import them from one place, plus the instrument/language helpers defined below.
|
||||
__all__ = [
|
||||
"get_stock_data",
|
||||
"get_indicators",
|
||||
"get_fundamentals",
|
||||
"get_balance_sheet",
|
||||
"get_cashflow",
|
||||
"get_income_statement",
|
||||
"get_news",
|
||||
"get_global_news",
|
||||
"get_insider_transactions",
|
||||
"get_macro_indicators",
|
||||
"get_prediction_markets",
|
||||
"get_verified_market_snapshot",
|
||||
"build_instrument_context",
|
||||
"resolve_instrument_identity",
|
||||
"get_instrument_context_from_state",
|
||||
"get_language_instruction",
|
||||
"create_msg_delete",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,7 +65,7 @@ def get_language_instruction() -> str:
|
||||
return f" Write your entire response in {lang}."
|
||||
|
||||
|
||||
def _clean_identity_value(value: Any) -> Optional[str]:
|
||||
def _clean_identity_value(value: Any) -> str | None:
|
||||
"""Return a trimmed string, or None for empty / placeholder-ish values."""
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
@@ -109,7 +122,7 @@ def resolve_instrument_identity(ticker: str) -> dict:
|
||||
def build_instrument_context(
|
||||
ticker: str,
|
||||
asset_type: str = "stock",
|
||||
identity: Optional[Mapping[str, str]] = None,
|
||||
identity: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Describe the exact instrument so agents preserve identity and ticker.
|
||||
|
||||
@@ -201,4 +214,4 @@ def create_msg_delete():
|
||||
return delete_messages
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
@@ -74,4 +76,4 @@ def get_income_statement(
|
||||
Returns:
|
||||
str: A formatted report containing income statement data
|
||||
"""
|
||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Append-only markdown decision log for TradingAgents."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from tradingagents.agents.utils.rating import parse_rating
|
||||
|
||||
@@ -51,7 +50,7 @@ class TradingMemoryLog:
|
||||
|
||||
# --- Read path (Phase A) ---
|
||||
|
||||
def load_entries(self) -> List[dict]:
|
||||
def load_entries(self) -> list[dict]:
|
||||
"""Parse all entries from log. Returns list of dicts."""
|
||||
if not self._log_path or not self._log_path.exists():
|
||||
return []
|
||||
@@ -64,7 +63,7 @@ class TradingMemoryLog:
|
||||
entries.append(parsed)
|
||||
return entries
|
||||
|
||||
def get_pending_entries(self) -> List[dict]:
|
||||
def get_pending_entries(self) -> list[dict]:
|
||||
"""Return entries with outcome:pending (for Phase B)."""
|
||||
return [e for e in self.load_entries() if e.get("pending")]
|
||||
|
||||
@@ -162,7 +161,7 @@ class TradingMemoryLog:
|
||||
tmp_path.write_text(new_text, encoding="utf-8")
|
||||
tmp_path.replace(self._log_path)
|
||||
|
||||
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
|
||||
def batch_update_with_outcomes(self, updates: list[dict]) -> None:
|
||||
"""Apply multiple outcome updates in a single read + atomic write.
|
||||
|
||||
Each element of updates must have keys: ticker, trade_date,
|
||||
@@ -218,7 +217,7 @@ class TradingMemoryLog:
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def _apply_rotation(self, blocks: List[str]) -> List[str]:
|
||||
def _apply_rotation(self, blocks: list[str]) -> list[str]:
|
||||
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
||||
|
||||
Pending blocks are always kept (they represent unprocessed work).
|
||||
@@ -247,7 +246,7 @@ class TradingMemoryLog:
|
||||
return blocks
|
||||
|
||||
to_drop = resolved_count - self._max_entries
|
||||
kept: List[str] = []
|
||||
kept: list[str] = []
|
||||
for block, is_resolved in decisions:
|
||||
if is_resolved and to_drop > 0:
|
||||
to_drop -= 1
|
||||
@@ -255,7 +254,7 @@ class TradingMemoryLog:
|
||||
kept.append(block)
|
||||
return kept
|
||||
|
||||
def _parse_entry(self, raw: str) -> Optional[dict]:
|
||||
def _parse_entry(self, raw: str) -> dict | None:
|
||||
lines = raw.strip().splitlines()
|
||||
if not lines:
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
@tool
|
||||
def get_news(
|
||||
ticker: Annotated[str, "Ticker symbol"],
|
||||
@@ -23,8 +26,8 @@ def get_news(
|
||||
@tool
|
||||
def get_global_news(
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
|
||||
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
|
||||
look_back_days: Annotated[int | None, "Days to look back; omit to use the configured default"] = None,
|
||||
limit: Annotated[int | None, "Max articles to return; omit to use the configured default"] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global news data.
|
||||
|
||||
@@ -12,11 +12,9 @@ Centralising it here avoids drift between those call sites.
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
||||
RATINGS_5_TIER: Tuple[str, ...] = (
|
||||
RATINGS_5_TIER: tuple[str, ...] = (
|
||||
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ all three agents log the same warnings when fallback fires.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
|
||||
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Any | None:
|
||||
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
||||
|
||||
Logs a warning when the binding fails so the user understands the agent
|
||||
@@ -46,7 +47,7 @@ def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]
|
||||
|
||||
|
||||
def invoke_structured_or_freetext(
|
||||
structured_llm: Optional[Any],
|
||||
structured_llm: Any | None,
|
||||
plain_llm: Any,
|
||||
prompt: Any,
|
||||
render: Callable[[T], str],
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicators(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
@@ -29,4 +32,4 @@ def get_indicators(
|
||||
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
||||
except ValueError as e:
|
||||
results.append(str(e))
|
||||
return "\n\n".join(results)
|
||||
return "\n\n".join(results)
|
||||
|
||||
@@ -1,5 +1,23 @@
|
||||
# Import functions from specialized modules
|
||||
from .alpha_vantage_stock import get_stock
|
||||
# Aggregates the per-category Alpha Vantage implementations into one module the
|
||||
# vendor router imports from; the imports below are the public surface.
|
||||
from .alpha_vantage_fundamentals import (
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_fundamentals,
|
||||
get_income_statement,
|
||||
)
|
||||
from .alpha_vantage_indicator import get_indicator
|
||||
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
||||
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
|
||||
from .alpha_vantage_news import get_global_news, get_insider_transactions, get_news
|
||||
from .alpha_vantage_stock import get_stock
|
||||
|
||||
__all__ = [
|
||||
"get_balance_sheet",
|
||||
"get_cashflow",
|
||||
"get_fundamentals",
|
||||
"get_income_statement",
|
||||
"get_indicator",
|
||||
"get_global_news",
|
||||
"get_insider_transactions",
|
||||
"get_news",
|
||||
"get_stock",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .alpha_vantage_common import _make_api_request, AlphaVantageNotConfiguredError
|
||||
from .alpha_vantage_common import AlphaVantageNotConfiguredError, _make_api_request
|
||||
|
||||
|
||||
def get_indicator(
|
||||
symbol: str,
|
||||
@@ -25,6 +26,7 @@ def get_indicator(
|
||||
String containing indicator values and description
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
supported_indicators = {
|
||||
@@ -98,21 +100,7 @@ def get_indicator(
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macd":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macds":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"series_type": series_type,
|
||||
"datatype": "csv"
|
||||
})
|
||||
elif indicator == "macdh":
|
||||
elif indicator == "macd" or indicator == "macds" or indicator == "macdh":
|
||||
data = _make_api_request("MACD", {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||
|
||||
|
||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||
|
||||
@@ -68,4 +69,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
||||
"symbol": symbol,
|
||||
}
|
||||
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
||||
|
||||
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
|
||||
|
||||
|
||||
def get_stock(
|
||||
symbol: str,
|
||||
@@ -35,4 +37,4 @@ def get_stock(
|
||||
|
||||
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional
|
||||
|
||||
import tradingagents.default_config as default_config
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
_config: dict | None = None
|
||||
|
||||
|
||||
def initialize_config():
|
||||
@@ -14,7 +13,7 @@ def initialize_config():
|
||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
def set_config(config: dict):
|
||||
"""Update the configuration with custom values.
|
||||
|
||||
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||
@@ -31,7 +30,7 @@ def set_config(config: Dict):
|
||||
_config[key] = value
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
def get_config() -> dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
|
||||
@@ -10,7 +10,7 @@ claim. Deterministic, no LLM involved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Optional
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pandas as pd
|
||||
from stockstats import wrap
|
||||
@@ -63,7 +63,7 @@ def build_verified_market_snapshot(
|
||||
symbol: str,
|
||||
curr_date: str,
|
||||
look_back_days: int = 30,
|
||||
indicators: Optional[Iterable[str]] = None,
|
||||
indicators: Iterable[str] | None = None,
|
||||
) -> str:
|
||||
"""Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes."""
|
||||
# `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import pandas as pd
|
||||
from datetime import date, timedelta, datetime
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Annotated
|
||||
|
||||
import pandas as pd
|
||||
|
||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||
|
||||
# Tickers can contain letters, digits, dot, dash, underscore, caret
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""yfinance-based news data fetching functions."""
|
||||
|
||||
from typing import Optional
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
|
||||
import yfinance as yf
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .config import get_config
|
||||
@@ -28,10 +28,8 @@ def _extract_article_data(article: dict) -> dict:
|
||||
pub_date_str = content.get("pubDate", "")
|
||||
pub_date = None
|
||||
if pub_date_str:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, AttributeError):
|
||||
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
@@ -47,10 +45,8 @@ def _extract_article_data(article: dict) -> dict:
|
||||
pub_date = None
|
||||
ts = article.get("providerPublishTime")
|
||||
if ts:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, OSError, TypeError):
|
||||
pub_date = datetime.fromtimestamp(ts)
|
||||
except (ValueError, OSError, TypeError):
|
||||
pass
|
||||
return {
|
||||
"title": article.get("title", "No title"),
|
||||
"summary": article.get("summary", ""),
|
||||
@@ -131,8 +127,8 @@ def get_news_yfinance(
|
||||
|
||||
def get_global_news_yfinance(
|
||||
curr_date: str,
|
||||
look_back_days: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
look_back_days: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global/macro economic news using yfinance Search.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# TradingAgents/graph/__init__.py
|
||||
|
||||
from .trading_graph import TradingAgentsGraph
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .setup import GraphSetup
|
||||
from .signal_processing import SignalProcessor
|
||||
from .trading_graph import TradingAgentsGraph
|
||||
|
||||
__all__ = [
|
||||
"TradingAgentsGraph",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from time import monotonic
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -14,11 +14,11 @@ class AnalystNodeSpec:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnalystExecutionPlan:
|
||||
specs: List[AnalystNodeSpec]
|
||||
specs: list[AnalystNodeSpec]
|
||||
concurrency_limit: int
|
||||
|
||||
|
||||
ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
|
||||
ANALYST_NODE_SPECS: dict[str, AnalystNodeSpec] = {
|
||||
"market": AnalystNodeSpec(
|
||||
key="market",
|
||||
agent_node="Market Analyst",
|
||||
@@ -61,7 +61,7 @@ def build_analyst_execution_plan(
|
||||
if concurrency_limit < 1:
|
||||
raise ValueError("analyst concurrency limit must be >= 1")
|
||||
|
||||
specs: List[AnalystNodeSpec] = []
|
||||
specs: list[AnalystNodeSpec] = []
|
||||
for analyst_key in selected_analysts:
|
||||
spec = ANALYST_NODE_SPECS.get(analyst_key)
|
||||
if spec is None:
|
||||
@@ -81,10 +81,10 @@ def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str:
|
||||
class AnalystWallTimeTracker:
|
||||
def __init__(self, plan: AnalystExecutionPlan):
|
||||
self.plan = plan
|
||||
self._started_at: Dict[str, float] = {}
|
||||
self._wall_times: Dict[str, float] = {}
|
||||
self._started_at: dict[str, float] = {}
|
||||
self._wall_times: dict[str, float] = {}
|
||||
|
||||
def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None:
|
||||
def mark_started(self, analyst_key: str, started_at: float | None = None) -> None:
|
||||
if analyst_key not in ANALYST_NODE_SPECS:
|
||||
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
|
||||
@@ -92,7 +92,7 @@ class AnalystWallTimeTracker:
|
||||
def mark_completed(
|
||||
self,
|
||||
analyst_key: str,
|
||||
completed_at: Optional[float] = None,
|
||||
completed_at: float | None = None,
|
||||
) -> None:
|
||||
if analyst_key not in ANALYST_NODE_SPECS:
|
||||
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||
@@ -104,7 +104,7 @@ class AnalystWallTimeTracker:
|
||||
finished_at = monotonic() if completed_at is None else completed_at
|
||||
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
|
||||
|
||||
def get_wall_times(self) -> Dict[str, float]:
|
||||
def get_wall_times(self) -> dict[str, float]:
|
||||
return dict(self._wall_times)
|
||||
|
||||
def format_summary(self) -> str:
|
||||
@@ -121,8 +121,8 @@ class AnalystWallTimeTracker:
|
||||
|
||||
def sync_analyst_tracker_from_chunk(
|
||||
tracker: AnalystWallTimeTracker,
|
||||
chunk: Dict[str, str],
|
||||
now: Optional[float] = None,
|
||||
chunk: dict[str, str],
|
||||
now: float | None = None,
|
||||
) -> None:
|
||||
current_time = monotonic() if now is None else now
|
||||
active_found = False
|
||||
|
||||
@@ -7,9 +7,9 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# TradingAgents/graph/propagation.py
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
@@ -22,7 +22,7 @@ class Propagator:
|
||||
asset_type: str = "stock",
|
||||
past_context: str = "",
|
||||
instrument_context: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Create the initial state for the agent graph.
|
||||
|
||||
``instrument_context`` is the deterministic ticker-identity string
|
||||
@@ -68,7 +68,7 @@ class Propagator:
|
||||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||
def get_graph_args(self, callbacks: list | None = None) -> dict[str, Any]:
|
||||
"""Get arguments for the graph invocation.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
# TradingAgents/graph/setup.py
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.agents import (
|
||||
create_aggressive_debator,
|
||||
create_bear_researcher,
|
||||
create_bull_researcher,
|
||||
create_conservative_debator,
|
||||
create_fundamentals_analyst,
|
||||
create_market_analyst,
|
||||
create_msg_delete,
|
||||
create_neutral_debator,
|
||||
create_news_analyst,
|
||||
create_portfolio_manager,
|
||||
create_research_manager,
|
||||
create_sentiment_analyst,
|
||||
create_trader,
|
||||
)
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
from .analyst_execution import build_analyst_execution_plan
|
||||
@@ -18,7 +33,7 @@ class GraphSetup:
|
||||
self,
|
||||
quick_thinking_llm: Any,
|
||||
deep_thinking_llm: Any,
|
||||
tool_nodes: Dict[str, ToolNode],
|
||||
tool_nodes: dict[str, ToolNode],
|
||||
conditional_logic: ConditionalLogic,
|
||||
analyst_concurrency_limit: int = 1,
|
||||
):
|
||||
@@ -30,7 +45,7 @@ class GraphSetup:
|
||||
self.analyst_concurrency_limit = analyst_concurrency_limit
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
self, selected_analysts=("market", "social", "news", "fundamentals")
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
||||
@@ -1,66 +1,57 @@
|
||||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yfinance as yf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
from tradingagents.dataflows.config import set_config
|
||||
|
||||
# Import the new abstract tool methods from agent_utils
|
||||
# Import the abstract tool methods from agent_utils
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
resolve_instrument_identity,
|
||||
get_stock_data,
|
||||
get_indicators,
|
||||
get_verified_market_snapshot,
|
||||
get_fundamentals,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
get_income_statement,
|
||||
get_news,
|
||||
get_insider_transactions,
|
||||
get_fundamentals,
|
||||
get_global_news,
|
||||
get_income_statement,
|
||||
get_indicators,
|
||||
get_insider_transactions,
|
||||
get_macro_indicators,
|
||||
get_prediction_markets
|
||||
get_news,
|
||||
get_prediction_markets,
|
||||
get_stock_data,
|
||||
get_verified_market_snapshot,
|
||||
resolve_instrument_identity,
|
||||
)
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.dataflows.config import set_config
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .setup import GraphSetup
|
||||
from .signal_processing import SignalProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
selected_analysts=("market", "social", "news", "fundamentals"),
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
config: dict[str, Any] = None,
|
||||
callbacks: list | None = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
@@ -103,7 +94,7 @@ class TradingAgentsGraph:
|
||||
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
|
||||
|
||||
self.memory_log = TradingMemoryLog(self.config)
|
||||
|
||||
# Create tool nodes
|
||||
@@ -138,7 +129,7 @@ class TradingAgentsGraph:
|
||||
self.graph = self.workflow.compile()
|
||||
self._checkpointer_ctx = None
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
def _get_provider_kwargs(self) -> dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
kwargs = {}
|
||||
provider = self.config.get("llm_provider", "").lower()
|
||||
@@ -167,7 +158,7 @@ class TradingAgentsGraph:
|
||||
|
||||
return kwargs
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
def _create_tool_nodes(self) -> dict[str, ToolNode]:
|
||||
"""Create tool nodes for different data sources using abstract methods."""
|
||||
return {
|
||||
"market": ToolNode(
|
||||
@@ -233,7 +224,7 @@ class TradingAgentsGraph:
|
||||
def _fetch_returns(
|
||||
self, ticker: str, trade_date: str, holding_days: int = 5,
|
||||
benchmark: str = "SPY",
|
||||
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
|
||||
) -> tuple[float | None, float | None, int | None]:
|
||||
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
||||
|
||||
``benchmark`` is the index used as the alpha baseline (resolved by the
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
@@ -43,7 +43,7 @@ class NormalizedChatAnthropic(ChatAnthropic):
|
||||
class AnthropicClient(BaseLLMClient):
|
||||
"""Client for Anthropic Claude models."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
|
||||
@@ -11,10 +11,7 @@ prompts for it automatically instead of failing on first API call.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||
PROVIDER_API_KEY_ENV: dict[str, str | None] = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
@@ -47,7 +44,7 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||
}
|
||||
|
||||
|
||||
def get_api_key_env(provider: str) -> Optional[str]:
|
||||
def get_api_key_env(provider: str) -> str | None:
|
||||
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
||||
|
||||
Unknown providers also return None — callers should treat that as
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
|
||||
@@ -29,7 +28,7 @@ class AzureOpenAIClient(BaseLLMClient):
|
||||
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_content(response):
|
||||
@@ -25,7 +25,7 @@ def normalize_content(response):
|
||||
class BaseLLMClient(ABC):
|
||||
"""Abstract base class for LLM clients."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.kwargs = kwargs
|
||||
|
||||
@@ -18,7 +18,6 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
StructuredMethod = Literal[
|
||||
"function_calling", # uses tools; respects supports_tool_choice
|
||||
"json_mode", # uses response_format={"type":"json_object"}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
provider: str,
|
||||
model: str,
|
||||
base_url: Optional[str] = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> BaseLLMClient:
|
||||
"""Create an LLM client for the specified provider.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
@@ -20,7 +20,7 @@ class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
|
||||
class GoogleClient(BaseLLMClient):
|
||||
"""Client for Google Gemini models."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
ModelOption = Tuple[str, str]
|
||||
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
||||
ModelOption = tuple[str, str]
|
||||
ProviderModeOptions = dict[str, dict[str, list[ModelOption]]]
|
||||
|
||||
# Providers that serve many / frequently-changing models: offer only "Custom
|
||||
# model ID" rather than a list that goes stale.
|
||||
_CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
|
||||
_CUSTOM_ONLY: dict[str, list[ModelOption]] = {
|
||||
"quick": [("Custom model ID", "custom")],
|
||||
"deep": [("Custom model ID", "custom")],
|
||||
}
|
||||
@@ -18,7 +16,7 @@ _CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
|
||||
# Shared model list for GLM via Z.AI (international) and BigModel (China).
|
||||
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
|
||||
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
|
||||
_GLM_MODELS: Dict[str, List[ModelOption]] = {
|
||||
_GLM_MODELS: dict[str, list[ModelOption]] = {
|
||||
"quick": [
|
||||
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
|
||||
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||
@@ -44,7 +42,7 @@ _GLM_MODELS: Dict[str, List[ModelOption]] = {
|
||||
# the backing model. Users who want a specific generation pick it
|
||||
# explicitly; users who really want auto-latest can enter the alias via
|
||||
# "Custom model ID".
|
||||
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
||||
_QWEN_MODELS: dict[str, list[ModelOption]] = {
|
||||
"quick": [
|
||||
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
|
||||
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
|
||||
@@ -62,7 +60,7 @@ _QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
||||
# Shared model list for MiniMax's global and CN endpoints (same IDs).
|
||||
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
|
||||
# All M2.x models share a 204,800-token context window.
|
||||
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
|
||||
_MINIMAX_MODELS: dict[str, list[ModelOption]] = {
|
||||
"quick": [
|
||||
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
|
||||
@@ -198,12 +196,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||
}
|
||||
|
||||
|
||||
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
||||
def get_model_options(provider: str, mode: str) -> list[ModelOption]:
|
||||
"""Return shared model options for a provider and selection mode."""
|
||||
return MODEL_OPTIONS[provider.lower()][mode]
|
||||
|
||||
|
||||
def get_known_models() -> Dict[str, List[str]]:
|
||||
def get_known_models() -> dict[str, list[str]]:
|
||||
"""Build known model names from the shared CLI catalog."""
|
||||
return {
|
||||
provider: sorted(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from .model_catalog import get_known_models
|
||||
|
||||
|
||||
# Providers whose model names are user-defined (local servers, relays, hosted
|
||||
# OpenAI-compatible endpoints serving many models), so any model string is
|
||||
# accepted without warning.
|
||||
|
||||
Reference in New Issue
Block a user