mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
chore(lint): make the repository ruff-clean under the strict select
Clear the deferred full-repo lint backlog so the whole tree passes the strict ruff select (E,W,F,I,B,UP,C4,SIM). Mechanical fixes dominate: import sorting, pep585/604 annotations, dropped dead imports, and whitespace. The few semantic changes are behavior-preserving: declare __all__ on the agent_utils and alpha_vantage re-export hubs; expand 'from x import *' to explicit names; use immutable tuple defaults instead of mutable list defaults; contextlib.suppress for try/except/pass; and narrow an over-broad assertRaises.
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import getpass
|
import getpass
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|||||||
78
cli/main.py
78
cli/main.py
@@ -1,38 +1,53 @@
|
|||||||
from typing import Optional
|
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
import typer
|
import os
|
||||||
import questionary
|
|
||||||
from pathlib import Path
|
|
||||||
from functools import wraps
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.spinner import Spinner
|
|
||||||
from rich.live import Live
|
|
||||||
from rich.columns import Columns
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.layout import Layout
|
|
||||||
from rich.text import Text
|
|
||||||
from rich.table import Table
|
|
||||||
from collections import deque
|
|
||||||
import time
|
import time
|
||||||
from rich.tree import Tree
|
from collections import deque
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
from rich import box
|
from rich import box
|
||||||
from rich.align import Align
|
from rich.align import Align
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.layout import Layout
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
|
from rich.spinner import Spinner
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from cli.announcements import display_announcements, fetch_announcements
|
||||||
|
from cli.stats_handler import StatsCallbackHandler
|
||||||
|
from cli.utils import (
|
||||||
|
ask_anthropic_effort,
|
||||||
|
ask_gemini_thinking_config,
|
||||||
|
ask_glm_region,
|
||||||
|
ask_minimax_region,
|
||||||
|
ask_openai_reasoning_effort,
|
||||||
|
ask_output_language,
|
||||||
|
ask_qwen_region,
|
||||||
|
confirm_ollama_endpoint,
|
||||||
|
detect_asset_type,
|
||||||
|
ensure_api_key,
|
||||||
|
get_ticker,
|
||||||
|
prompt_openai_compatible_url,
|
||||||
|
resolve_backend_url,
|
||||||
|
select_analysts,
|
||||||
|
select_deep_thinking_agent,
|
||||||
|
select_llm_provider,
|
||||||
|
select_research_depth,
|
||||||
|
select_shallow_thinking_agent,
|
||||||
|
)
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from tradingagents.graph.analyst_execution import (
|
from tradingagents.graph.analyst_execution import (
|
||||||
AnalystWallTimeTracker,
|
AnalystWallTimeTracker,
|
||||||
build_analyst_execution_plan,
|
build_analyst_execution_plan,
|
||||||
get_initial_analyst_node,
|
get_initial_analyst_node,
|
||||||
sync_analyst_tracker_from_chunk,
|
sync_analyst_tracker_from_chunk,
|
||||||
)
|
)
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from cli.models import AnalystType
|
|
||||||
from cli.utils import *
|
|
||||||
from cli.announcements import fetch_announcements, display_announcements
|
|
||||||
from cli.stats_handler import StatsCallbackHandler
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -466,7 +481,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
|||||||
def get_user_selections():
|
def get_user_selections():
|
||||||
"""Get all user selections before starting the analysis display."""
|
"""Get all user selections before starting the analysis display."""
|
||||||
# Display ASCII art welcome message
|
# Display ASCII art welcome message
|
||||||
with open(Path(__file__).parent / "static" / "welcome.txt", "r", encoding="utf-8") as f:
|
with open(Path(__file__).parent / "static" / "welcome.txt", encoding="utf-8") as f:
|
||||||
welcome_ascii = f.read()
|
welcome_ascii = f.read()
|
||||||
|
|
||||||
# Create welcome box content
|
# Create welcome box content
|
||||||
@@ -922,8 +937,11 @@ def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None):
|
|||||||
message_buffer.update_agent_status(agent_name, "pending")
|
message_buffer.update_agent_status(agent_name, "pending")
|
||||||
|
|
||||||
# When all analysts complete, transition research team to in_progress
|
# When all analysts complete, transition research team to in_progress
|
||||||
if not found_active and selected:
|
if (
|
||||||
if message_buffer.agent_status.get("Bull Researcher") == "pending":
|
not found_active
|
||||||
|
and selected
|
||||||
|
and message_buffer.agent_status.get("Bull Researcher") == "pending"
|
||||||
|
):
|
||||||
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
message_buffer.update_agent_status("Bull Researcher", "in_progress")
|
||||||
|
|
||||||
def extract_content_string(content):
|
def extract_content_string(content):
|
||||||
@@ -1097,7 +1115,7 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
layout = create_layout()
|
layout = create_layout()
|
||||||
|
|
||||||
with Live(layout, refresh_per_second=4) as live:
|
with Live(layout, refresh_per_second=4):
|
||||||
# Initial display
|
# Initial display
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
@@ -1232,8 +1250,7 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||||
)
|
)
|
||||||
if judge:
|
if judge and message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
|
||||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||||
message_buffer.update_report_section(
|
message_buffer.update_report_section(
|
||||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||||
@@ -1253,7 +1270,6 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
final_state = {}
|
final_state = {}
|
||||||
for chunk in trace:
|
for chunk in trace:
|
||||||
final_state.update(chunk)
|
final_state.update(chunk)
|
||||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
|
||||||
|
|
||||||
# Update all agent statuses to completed
|
# Update all agent statuses to completed
|
||||||
for agent in message_buffer.agent_status:
|
for agent in message_buffer.agent_status:
|
||||||
@@ -1265,7 +1281,7 @@ def run_analysis(checkpoint: bool = False):
|
|||||||
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
|
message_buffer.add_message("System", analyst_wall_time_tracker.format_summary())
|
||||||
|
|
||||||
# Update final report sections
|
# Update final report sections
|
||||||
for section in message_buffer.report_sections.keys():
|
for section in message_buffer.report_sections:
|
||||||
if section in final_state:
|
if section in final_state:
|
||||||
message_buffer.update_report_section(section, final_state[section])
|
message_buffer.update_report_section(section, final_state[section])
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Dict
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class AnalystType(str, Enum):
|
class AnalystType(str, Enum):
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.outputs import LLMResult
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
|
||||||
class StatsCallbackHandler(BaseCallbackHandler):
|
class StatsCallbackHandler(BaseCallbackHandler):
|
||||||
@@ -19,8 +19,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Increment LLM call counter when an LLM starts."""
|
"""Increment LLM call counter when an LLM starts."""
|
||||||
@@ -29,8 +29,8 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
messages: List[List[Any]],
|
messages: list[list[Any]],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Increment LLM call counter when a chat model starts."""
|
"""Increment LLM call counter when a chat model starts."""
|
||||||
@@ -57,7 +57,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
input_str: str,
|
input_str: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -65,7 +65,7 @@ class StatsCallbackHandler(BaseCallbackHandler):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self.tool_calls += 1
|
self.tool_calls += 1
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""Return current statistics."""
|
"""Return current statistics."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return {
|
return {
|
||||||
|
|||||||
11
cli/utils.py
11
cli/utils.py
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict
|
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
from dotenv import find_dotenv, set_key
|
from dotenv import find_dotenv, set_key
|
||||||
@@ -89,8 +88,8 @@ def detect_asset_type(ticker: str) -> AssetType:
|
|||||||
|
|
||||||
|
|
||||||
def filter_analysts_for_asset_type(
|
def filter_analysts_for_asset_type(
|
||||||
analysts: List[AnalystType], asset_type: AssetType
|
analysts: list[AnalystType], asset_type: AssetType
|
||||||
) -> List[AnalystType]:
|
) -> list[AnalystType]:
|
||||||
if asset_type != AssetType.CRYPTO:
|
if asset_type != AssetType.CRYPTO:
|
||||||
return analysts
|
return analysts
|
||||||
return [
|
return [
|
||||||
@@ -133,7 +132,7 @@ def get_analysis_date() -> str:
|
|||||||
return date.strip()
|
return date.strip()
|
||||||
|
|
||||||
|
|
||||||
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> List[AnalystType]:
|
def select_analysts(asset_type: AssetType = AssetType.STOCK) -> list[AnalystType]:
|
||||||
"""Select analysts using an interactive checkbox."""
|
"""Select analysts using an interactive checkbox."""
|
||||||
available_analysts = filter_analysts_for_asset_type(
|
available_analysts = filter_analysts_for_asset_type(
|
||||||
[value for _, value in ANALYST_ORDER],
|
[value for _, value in ANALYST_ORDER],
|
||||||
@@ -197,7 +196,7 @@ def select_research_depth() -> int:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def _fetch_openrouter_models() -> List[Tuple[str, str]]:
|
def _fetch_openrouter_models() -> list[tuple[str, str]]:
|
||||||
"""Fetch available models from the OpenRouter API."""
|
"""Fetch available models from the OpenRouter API."""
|
||||||
import requests
|
import requests
|
||||||
try:
|
try:
|
||||||
@@ -556,7 +555,7 @@ def confirm_ollama_endpoint(url: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_api_key(provider: str) -> Optional[str]:
|
def ensure_api_key(provider: str) -> str | None:
|
||||||
"""Make sure the API key for `provider` is available in the environment.
|
"""Make sure the API key for `provider` is available in the environment.
|
||||||
|
|
||||||
If the env var is already set, returns its value untouched. Otherwise
|
If the env var is already set, returns its value untouched. Otherwise
|
||||||
|
|||||||
2
main.py
2
main.py
@@ -1,5 +1,5 @@
|
|||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
# DEFAULT_CONFIG already applies TRADINGAGENTS_* env-var overrides
|
||||||
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
# (llm_provider, deep_think_llm, quick_think_llm, backend_url, etc.),
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ added, plus the heuristic SignalProcessor.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
@@ -30,7 +29,6 @@ from tradingagents.agents.trader.trader import create_trader
|
|||||||
from tradingagents.graph.signal_processing import SignalProcessor
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
from tradingagents.llm_clients import create_llm_client
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_DEFAULTS = {
|
PROVIDER_DEFAULTS = {
|
||||||
"openai": ("gpt-5.4-mini", None),
|
"openai": ("gpt-5.4-mini", None),
|
||||||
"google": ("gemini-2.5-flash", None),
|
"google": ("gemini-2.5-flash", None),
|
||||||
|
|||||||
5
test.py
5
test.py
@@ -1,5 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
|
||||||
|
from tradingagents.dataflows.y_finance import (
|
||||||
|
get_stock_stats_indicators_window,
|
||||||
|
)
|
||||||
|
|
||||||
print("Testing optimized implementation with 30-day lookback:")
|
print("Testing optimized implementation with 30-day lookback:")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
@@ -3,14 +3,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||||
|
|
||||||
|
|
||||||
# ---- Mapping coverage -----------------------------------------------------
|
# ---- Mapping coverage -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +69,7 @@ def test_case_insensitive_lookup():
|
|||||||
def cli_utils(monkeypatch):
|
def cli_utils(monkeypatch):
|
||||||
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import cli.utils as cli_utils_module
|
import cli.utils as cli_utils_module
|
||||||
return importlib.reload(cli_utils_module)
|
return importlib.reload(cli_utils_module)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Unit tests for the LLM capability table."""
|
"""Unit tests for the LLM capability table."""
|
||||||
|
|
||||||
|
from dataclasses import FrozenInstanceError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.capabilities import (
|
from tradingagents.llm_clients.capabilities import (
|
||||||
ModelCapabilities,
|
|
||||||
get_capabilities,
|
get_capabilities,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,5 +120,5 @@ class TestDefault:
|
|||||||
def test_capabilities_dataclass_is_frozen():
|
def test_capabilities_dataclass_is_frozen():
|
||||||
"""Capability rows are immutable so they can be safely shared."""
|
"""Capability rows are immutable so they can be safely shared."""
|
||||||
caps = get_capabilities("deepseek-chat")
|
caps = get_capabilities("deepseek-chat")
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(FrozenInstanceError):
|
||||||
caps.supports_tool_choice = False # type: ignore[misc]
|
caps.supports_tool_choice = False # type: ignore[misc]
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
||||||
from langgraph.graph import END, StateGraph
|
from langgraph.graph import END, StateGraph
|
||||||
|
|
||||||
from tradingagents.graph.checkpointer import (
|
from tradingagents.graph.checkpointer import (
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from tradingagents.llm_clients.openai_client import (
|
|||||||
_input_to_messages,
|
_input_to_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||||
# (Gemini bot review note: non-list inputs must also work)
|
# (Gemini bot review note: non-list inputs must also work)
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pandas as pd
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||||
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||||
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.graph.propagation import Propagator
|
||||||
from tradingagents.graph.reflection import Reflector
|
from tradingagents.graph.reflection import Reflector
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.graph.propagation import Propagator
|
|
||||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
|
||||||
|
|
||||||
_SEP = TradingMemoryLog._SEPARATOR
|
_SEP = TradingMemoryLog._SEPARATOR
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ news injected future articles), #993 (empty-after-filter returned a blank body).
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from unittest import mock
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.dataflows import stockstats_utils, interface
|
from tradingagents.dataflows import interface, stockstats_utils
|
||||||
from tradingagents.dataflows.config import set_config
|
from tradingagents.dataflows.config import set_config
|
||||||
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
from tradingagents.dataflows.symbol_utils import NoMarketDataError
|
||||||
|
|
||||||
@@ -33,8 +33,8 @@ class TestLoadOhlcvNoPoison(unittest.TestCase):
|
|||||||
|
|
||||||
def test_empty_download_raises_and_does_not_cache(self):
|
def test_empty_download_raises_and_does_not_cache(self):
|
||||||
empty = pd.DataFrame()
|
empty = pd.DataFrame()
|
||||||
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty) as dl:
|
with mock.patch.object(stockstats_utils.yf, "download", return_value=empty), \
|
||||||
with self.assertRaises(NoMarketDataError):
|
self.assertRaises(NoMarketDataError):
|
||||||
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
stockstats_utils.load_ohlcv("FAKE", "2026-01-01")
|
||||||
# Nothing should have been written to the cache.
|
# Nothing should have been written to the cache.
|
||||||
self.assertEqual(os.listdir(self._tmp), [])
|
self.assertEqual(os.listdir(self._tmp), [])
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ def _resync_reloaded_modules():
|
|||||||
doesn't leak across test modules.
|
doesn't leak across test modules.
|
||||||
"""
|
"""
|
||||||
yield
|
yield
|
||||||
import cli.utils
|
|
||||||
import cli.main
|
import cli.main
|
||||||
|
import cli.utils
|
||||||
importlib.reload(cli.utils)
|
importlib.reload(cli.utils)
|
||||||
importlib.reload(cli.main)
|
importlib.reload(cli.main)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Verifies the user-supplied base_url is required and honored, the key is optional
|
|||||||
(keyless local default), Chat Completions (not the Responses API) is used, any
|
(keyless local default), Chat Completions (not the Responses API) is used, any
|
||||||
model name is accepted, and the env backend URL precedence (#978).
|
model name is accepted, and the env backend URL precedence (#978).
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import pytest
|
|||||||
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||||
from tradingagents.graph.signal_processing import SignalProcessor
|
from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Heuristic parser
|
# Heuristic parser
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from tradingagents.agents.schemas import (
|
|||||||
)
|
)
|
||||||
from tradingagents.agents.trader.trader import create_trader
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Render functions
|
# Render functions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import pytest
|
|||||||
|
|
||||||
from tradingagents.dataflows.symbol_utils import (
|
from tradingagents.dataflows.symbol_utils import (
|
||||||
NoMarketDataError,
|
NoMarketDataError,
|
||||||
normalize_symbol,
|
|
||||||
is_yahoo_safe,
|
is_yahoo_safe,
|
||||||
|
normalize_symbol,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Regressions for #988 (explicit single-vendor config still fell back to others),
|
|||||||
were swallowed without a trace).
|
were swallowed without a trace).
|
||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@@ -76,8 +75,8 @@ class VendorRoutingTests(unittest.TestCase):
|
|||||||
# #989: primary errors + fallback no-data -> NO_DATA, but the failure
|
# #989: primary errors + fallback no-data -> NO_DATA, but the failure
|
||||||
# must be visible in logs (broken primary not hidden).
|
# must be visible in logs (broken primary not hidden).
|
||||||
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
|
set_config({"data_vendors": {"core_stock_apis": "yfinance,alpha_vantage"}})
|
||||||
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}):
|
with self._route({"yfinance": _raises(ValueError("boom")), "alpha_vantage": _no_data}), \
|
||||||
with self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
|
self.assertLogs("tradingagents.dataflows.interface", level="WARNING") as cm:
|
||||||
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
result = interface.route_to_vendor("get_stock_data", "AAPL", "2026-01-01", "2026-01-10")
|
||||||
self.assertIn("NO_DATA_AVAILABLE", result)
|
self.assertIn("NO_DATA_AVAILABLE", result)
|
||||||
joined = "\n".join(cm.output)
|
joined = "\n".join(cm.output)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import contextlib
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
# Load .env files at package import so DEFAULT_CONFIG's env-var overlay
|
||||||
@@ -20,10 +21,8 @@ except ImportError:
|
|||||||
# subclassed warning categories. To suppress a specific warning we must
|
# subclassed warning categories. To suppress a specific warning we must
|
||||||
# install our filter AFTER langchain-core has installed its own, so import
|
# install our filter AFTER langchain-core has installed its own, so import
|
||||||
# it first. The package is a guaranteed transitive dep via langgraph.
|
# it first. The package is a guaranteed transitive dep via langgraph.
|
||||||
try:
|
with contextlib.suppress(ImportError):
|
||||||
import langchain_core # noqa: F401
|
import langchain_core # noqa: F401
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
# langgraph-checkpoint 4.0.3 calls Reviver() at module load without an
|
||||||
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
# explicit allowed_objects, which triggers a noisy pending-deprecation
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
from .utils.agent_utils import create_msg_delete
|
|
||||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
|
||||||
|
|
||||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||||
from .analysts.market_analyst import create_market_analyst
|
from .analysts.market_analyst import create_market_analyst
|
||||||
from .analysts.news_analyst import create_news_analyst
|
from .analysts.news_analyst import create_news_analyst
|
||||||
@@ -8,18 +5,16 @@ from .analysts.sentiment_analyst import (
|
|||||||
create_sentiment_analyst,
|
create_sentiment_analyst,
|
||||||
create_social_media_analyst, # deprecated alias kept for back-compat
|
create_social_media_analyst, # deprecated alias kept for back-compat
|
||||||
)
|
)
|
||||||
|
from .managers.portfolio_manager import create_portfolio_manager
|
||||||
|
from .managers.research_manager import create_research_manager
|
||||||
from .researchers.bear_researcher import create_bear_researcher
|
from .researchers.bear_researcher import create_bear_researcher
|
||||||
from .researchers.bull_researcher import create_bull_researcher
|
from .researchers.bull_researcher import create_bull_researcher
|
||||||
|
|
||||||
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
from .risk_mgmt.aggressive_debator import create_aggressive_debator
|
||||||
from .risk_mgmt.conservative_debator import create_conservative_debator
|
from .risk_mgmt.conservative_debator import create_conservative_debator
|
||||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||||
|
|
||||||
from .managers.research_manager import create_research_manager
|
|
||||||
from .managers.portfolio_manager import create_portfolio_manager
|
|
||||||
|
|
||||||
from .trader.trader import create_trader
|
from .trader.trader import create_trader
|
||||||
|
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||||
|
from .utils.agent_utils import create_msg_delete
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentState",
|
"AgentState",
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
get_instrument_context_from_state,
|
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
get_cashflow,
|
get_cashflow,
|
||||||
get_fundamentals,
|
get_fundamentals,
|
||||||
get_income_statement,
|
get_income_statement,
|
||||||
get_insider_transactions,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_fundamentals_analyst(llm):
|
def create_fundamentals_analyst(llm):
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
get_instrument_context_from_state,
|
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
get_verified_market_snapshot,
|
get_verified_market_snapshot,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_market_analyst(llm):
|
def create_market_analyst(llm):
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
get_instrument_context_from_state,
|
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
get_macro_indicators,
|
get_macro_indicators,
|
||||||
get_news,
|
get_news,
|
||||||
get_prediction_markets,
|
get_prediction_markets,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_news_analyst(llm):
|
def create_news_analyst(llm):
|
||||||
|
|||||||
@@ -19,11 +19,10 @@ so that:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Shared rating types
|
# Shared rating types
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -124,15 +123,15 @@ class TraderProposal(BaseModel):
|
|||||||
"the research plan. Two to four sentences."
|
"the research plan. Two to four sentences."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
entry_price: Optional[float] = Field(
|
entry_price: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional entry price target in the instrument's quote currency.",
|
description="Optional entry price target in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
stop_loss: Optional[float] = Field(
|
stop_loss: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional stop-loss price in the instrument's quote currency.",
|
description="Optional stop-loss price in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
position_sizing: Optional[str] = Field(
|
position_sizing: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
description="Optional sizing guidance, e.g. '5% of portfolio'.",
|
||||||
)
|
)
|
||||||
@@ -196,11 +195,11 @@ class PortfolioDecision(BaseModel):
|
|||||||
"incorporate them; otherwise rely solely on the current analysis."
|
"incorporate them; otherwise rely solely on the current analysis."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
price_target: Optional[float] = Field(
|
price_target: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional target price in the instrument's quote currency.",
|
description="Optional target price in the instrument's quote currency.",
|
||||||
)
|
)
|
||||||
time_horizon: Optional[str] = Field(
|
time_horizon: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional recommended holding period, e.g. '3-6 months'.",
|
description="Optional recommended holding period, e.g. '3-6 months'.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from langgraph.graph import MessagesState
|
from langgraph.graph import MessagesState
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
|
|||||||
@@ -1,37 +1,50 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Mapping, Optional
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||||
|
|
||||||
# Import tools from separate utility files
|
# Import tools from separate utility files
|
||||||
from tradingagents.agents.utils.core_stock_tools import (
|
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||||
get_stock_data
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.technical_indicators_tools import (
|
|
||||||
get_indicators
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.fundamental_data_tools import (
|
from tradingagents.agents.utils.fundamental_data_tools import (
|
||||||
get_fundamentals,
|
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
get_cashflow,
|
get_cashflow,
|
||||||
get_income_statement
|
get_fundamentals,
|
||||||
|
get_income_statement,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.macro_data_tools import get_macro_indicators
|
||||||
|
from tradingagents.agents.utils.market_data_validation_tools import get_verified_market_snapshot
|
||||||
from tradingagents.agents.utils.news_data_tools import (
|
from tradingagents.agents.utils.news_data_tools import (
|
||||||
get_news,
|
get_global_news,
|
||||||
get_insider_transactions,
|
get_insider_transactions,
|
||||||
get_global_news
|
get_news,
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.macro_data_tools import (
|
|
||||||
get_macro_indicators
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.prediction_markets_tools import (
|
|
||||||
get_prediction_markets
|
|
||||||
)
|
|
||||||
from tradingagents.agents.utils.market_data_validation_tools import (
|
|
||||||
get_verified_market_snapshot
|
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.prediction_markets_tools import get_prediction_markets
|
||||||
|
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||||
|
|
||||||
|
# Public surface: the data tools are imported here so agents and the graph
|
||||||
|
# import them from one place, plus the instrument/language helpers defined below.
|
||||||
|
__all__ = [
|
||||||
|
"get_stock_data",
|
||||||
|
"get_indicators",
|
||||||
|
"get_fundamentals",
|
||||||
|
"get_balance_sheet",
|
||||||
|
"get_cashflow",
|
||||||
|
"get_income_statement",
|
||||||
|
"get_news",
|
||||||
|
"get_global_news",
|
||||||
|
"get_insider_transactions",
|
||||||
|
"get_macro_indicators",
|
||||||
|
"get_prediction_markets",
|
||||||
|
"get_verified_market_snapshot",
|
||||||
|
"build_instrument_context",
|
||||||
|
"resolve_instrument_identity",
|
||||||
|
"get_instrument_context_from_state",
|
||||||
|
"get_language_instruction",
|
||||||
|
"create_msg_delete",
|
||||||
|
]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -52,7 +65,7 @@ def get_language_instruction() -> str:
|
|||||||
return f" Write your entire response in {lang}."
|
return f" Write your entire response in {lang}."
|
||||||
|
|
||||||
|
|
||||||
def _clean_identity_value(value: Any) -> Optional[str]:
|
def _clean_identity_value(value: Any) -> str | None:
|
||||||
"""Return a trimmed string, or None for empty / placeholder-ish values."""
|
"""Return a trimmed string, or None for empty / placeholder-ish values."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
return None
|
return None
|
||||||
@@ -109,7 +122,7 @@ def resolve_instrument_identity(ticker: str) -> dict:
|
|||||||
def build_instrument_context(
|
def build_instrument_context(
|
||||||
ticker: str,
|
ticker: str,
|
||||||
asset_type: str = "stock",
|
asset_type: str = "stock",
|
||||||
identity: Optional[Mapping[str, str]] = None,
|
identity: Mapping[str, str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Describe the exact instrument so agents preserve identity and ticker.
|
"""Describe the exact instrument so agents preserve identity and ticker.
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Append-only markdown decision log for TradingAgents."""
|
"""Append-only markdown decision log for TradingAgents."""
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
import re
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from tradingagents.agents.utils.rating import parse_rating
|
from tradingagents.agents.utils.rating import parse_rating
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ class TradingMemoryLog:
|
|||||||
|
|
||||||
# --- Read path (Phase A) ---
|
# --- Read path (Phase A) ---
|
||||||
|
|
||||||
def load_entries(self) -> List[dict]:
|
def load_entries(self) -> list[dict]:
|
||||||
"""Parse all entries from log. Returns list of dicts."""
|
"""Parse all entries from log. Returns list of dicts."""
|
||||||
if not self._log_path or not self._log_path.exists():
|
if not self._log_path or not self._log_path.exists():
|
||||||
return []
|
return []
|
||||||
@@ -64,7 +63,7 @@ class TradingMemoryLog:
|
|||||||
entries.append(parsed)
|
entries.append(parsed)
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
def get_pending_entries(self) -> List[dict]:
|
def get_pending_entries(self) -> list[dict]:
|
||||||
"""Return entries with outcome:pending (for Phase B)."""
|
"""Return entries with outcome:pending (for Phase B)."""
|
||||||
return [e for e in self.load_entries() if e.get("pending")]
|
return [e for e in self.load_entries() if e.get("pending")]
|
||||||
|
|
||||||
@@ -162,7 +161,7 @@ class TradingMemoryLog:
|
|||||||
tmp_path.write_text(new_text, encoding="utf-8")
|
tmp_path.write_text(new_text, encoding="utf-8")
|
||||||
tmp_path.replace(self._log_path)
|
tmp_path.replace(self._log_path)
|
||||||
|
|
||||||
def batch_update_with_outcomes(self, updates: List[dict]) -> None:
|
def batch_update_with_outcomes(self, updates: list[dict]) -> None:
|
||||||
"""Apply multiple outcome updates in a single read + atomic write.
|
"""Apply multiple outcome updates in a single read + atomic write.
|
||||||
|
|
||||||
Each element of updates must have keys: ticker, trade_date,
|
Each element of updates must have keys: ticker, trade_date,
|
||||||
@@ -218,7 +217,7 @@ class TradingMemoryLog:
|
|||||||
|
|
||||||
# --- Helpers ---
|
# --- Helpers ---
|
||||||
|
|
||||||
def _apply_rotation(self, blocks: List[str]) -> List[str]:
|
def _apply_rotation(self, blocks: list[str]) -> list[str]:
|
||||||
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
"""Drop oldest resolved blocks when their count exceeds max_entries.
|
||||||
|
|
||||||
Pending blocks are always kept (they represent unprocessed work).
|
Pending blocks are always kept (they represent unprocessed work).
|
||||||
@@ -247,7 +246,7 @@ class TradingMemoryLog:
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
to_drop = resolved_count - self._max_entries
|
to_drop = resolved_count - self._max_entries
|
||||||
kept: List[str] = []
|
kept: list[str] = []
|
||||||
for block, is_resolved in decisions:
|
for block, is_resolved in decisions:
|
||||||
if is_resolved and to_drop > 0:
|
if is_resolved and to_drop > 0:
|
||||||
to_drop -= 1
|
to_drop -= 1
|
||||||
@@ -255,7 +254,7 @@ class TradingMemoryLog:
|
|||||||
kept.append(block)
|
kept.append(block)
|
||||||
return kept
|
return kept
|
||||||
|
|
||||||
def _parse_entry(self, raw: str) -> Optional[dict]:
|
def _parse_entry(self, raw: str) -> dict | None:
|
||||||
lines = raw.strip().splitlines()
|
lines = raw.strip().splitlines()
|
||||||
if not lines:
|
if not lines:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from typing import Annotated, Optional
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_news(
|
def get_news(
|
||||||
ticker: Annotated[str, "Ticker symbol"],
|
ticker: Annotated[str, "Ticker symbol"],
|
||||||
@@ -23,8 +26,8 @@ def get_news(
|
|||||||
@tool
|
@tool
|
||||||
def get_global_news(
|
def get_global_news(
|
||||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||||
look_back_days: Annotated[Optional[int], "Days to look back; omit to use the configured default"] = None,
|
look_back_days: Annotated[int | None, "Days to look back; omit to use the configured default"] = None,
|
||||||
limit: Annotated[Optional[int], "Max articles to return; omit to use the configured default"] = None,
|
limit: Annotated[int | None, "Max articles to return; omit to use the configured default"] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve global news data.
|
Retrieve global news data.
|
||||||
|
|||||||
@@ -12,11 +12,9 @@ Centralising it here avoids drift between those call sites.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
|
|
||||||
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
# Canonical, ordered 5-tier scale (most bullish to most bearish).
|
||||||
RATINGS_5_TIER: Tuple[str, ...] = (
|
RATINGS_5_TIER: tuple[str, ...] = (
|
||||||
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
"Buy", "Overweight", "Hold", "Underweight", "Sell",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ all three agents log the same warnings when fallback fires.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from collections.abc import Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]:
|
def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Any | None:
|
||||||
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
"""Return ``llm.with_structured_output(schema)`` or ``None`` if unsupported.
|
||||||
|
|
||||||
Logs a warning when the binding fails so the user understands the agent
|
Logs a warning when the binding fails so the user understands the agent
|
||||||
@@ -46,7 +47,7 @@ def bind_structured(llm: Any, schema: type[T], agent_name: str) -> Optional[Any]
|
|||||||
|
|
||||||
|
|
||||||
def invoke_structured_or_freetext(
|
def invoke_structured_or_freetext(
|
||||||
structured_llm: Optional[Any],
|
structured_llm: Any | None,
|
||||||
plain_llm: Any,
|
plain_llm: Any,
|
||||||
prompt: Any,
|
prompt: Any,
|
||||||
render: Callable[[T], str],
|
render: Callable[[T], str],
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from langchain_core.tools import tool
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from tradingagents.dataflows.interface import route_to_vendor
|
from tradingagents.dataflows.interface import route_to_vendor
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_indicators(
|
def get_indicators(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
|||||||
@@ -1,5 +1,23 @@
|
|||||||
# Import functions from specialized modules
|
# Aggregates the per-category Alpha Vantage implementations into one module the
|
||||||
from .alpha_vantage_stock import get_stock
|
# vendor router imports from; the imports below are the public surface.
|
||||||
|
from .alpha_vantage_fundamentals import (
|
||||||
|
get_balance_sheet,
|
||||||
|
get_cashflow,
|
||||||
|
get_fundamentals,
|
||||||
|
get_income_statement,
|
||||||
|
)
|
||||||
from .alpha_vantage_indicator import get_indicator
|
from .alpha_vantage_indicator import get_indicator
|
||||||
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
|
from .alpha_vantage_news import get_global_news, get_insider_transactions, get_news
|
||||||
from .alpha_vantage_news import get_news, get_global_news, get_insider_transactions
|
from .alpha_vantage_stock import get_stock
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_balance_sheet",
|
||||||
|
"get_cashflow",
|
||||||
|
"get_fundamentals",
|
||||||
|
"get_income_statement",
|
||||||
|
"get_indicator",
|
||||||
|
"get_global_news",
|
||||||
|
"get_insider_transactions",
|
||||||
|
"get_news",
|
||||||
|
"get_stock",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .alpha_vantage_common import _make_api_request, AlphaVantageNotConfiguredError
|
from .alpha_vantage_common import AlphaVantageNotConfiguredError, _make_api_request
|
||||||
|
|
||||||
|
|
||||||
def get_indicator(
|
def get_indicator(
|
||||||
symbol: str,
|
symbol: str,
|
||||||
@@ -25,6 +26,7 @@ def get_indicator(
|
|||||||
String containing indicator values and description
|
String containing indicator values and description
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
supported_indicators = {
|
supported_indicators = {
|
||||||
@@ -98,21 +100,7 @@ def get_indicator(
|
|||||||
"series_type": series_type,
|
"series_type": series_type,
|
||||||
"datatype": "csv"
|
"datatype": "csv"
|
||||||
})
|
})
|
||||||
elif indicator == "macd":
|
elif indicator == "macd" or indicator == "macds" or indicator == "macdh":
|
||||||
data = _make_api_request("MACD", {
|
|
||||||
"symbol": symbol,
|
|
||||||
"interval": interval,
|
|
||||||
"series_type": series_type,
|
|
||||||
"datatype": "csv"
|
|
||||||
})
|
|
||||||
elif indicator == "macds":
|
|
||||||
data = _make_api_request("MACD", {
|
|
||||||
"symbol": symbol,
|
|
||||||
"interval": interval,
|
|
||||||
"series_type": series_type,
|
|
||||||
"datatype": "csv"
|
|
||||||
})
|
|
||||||
elif indicator == "macdh":
|
|
||||||
data = _make_api_request("MACD", {
|
data = _make_api_request("MACD", {
|
||||||
"symbol": symbol,
|
"symbol": symbol,
|
||||||
"interval": interval,
|
"interval": interval,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||||
|
|
||||||
|
|
||||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
|
||||||
|
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
|
||||||
|
|
||||||
|
|
||||||
def get_stock(
|
def get_stock(
|
||||||
symbol: str,
|
symbol: str,
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import tradingagents.default_config as default_config
|
import tradingagents.default_config as default_config
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
@@ -14,7 +13,7 @@ def initialize_config():
|
|||||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: dict):
|
||||||
"""Update the configuration with custom values.
|
"""Update the configuration with custom values.
|
||||||
|
|
||||||
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
Dict-valued keys (e.g. ``data_vendors``) are merged one level deep so a
|
||||||
@@ -31,7 +30,7 @@ def set_config(config: Dict):
|
|||||||
_config[key] = value
|
_config[key] = value
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ claim. Deterministic, no LLM involved.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable, Optional
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from stockstats import wrap
|
from stockstats import wrap
|
||||||
@@ -63,7 +63,7 @@ def build_verified_market_snapshot(
|
|||||||
symbol: str,
|
symbol: str,
|
||||||
curr_date: str,
|
curr_date: str,
|
||||||
look_back_days: int = 30,
|
look_back_days: int = 30,
|
||||||
indicators: Optional[Iterable[str]] = None,
|
indicators: Iterable[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes."""
|
"""Render a ground-truth snapshot: latest OHLCV row, indicators, recent closes."""
|
||||||
# `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/
|
# `df` keeps the original capitalized OHLCV columns (Open/High/Low/Close/
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import json
|
from datetime import date, datetime, timedelta
|
||||||
import pandas as pd
|
|
||||||
from datetime import date, timedelta, datetime
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
# Tickers can contain letters, digits, dot, dash, underscore, caret
|
# Tickers can contain letters, digits, dot, dash, underscore, caret
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""yfinance-based news data fetching functions."""
|
"""yfinance-based news data fetching functions."""
|
||||||
|
|
||||||
from typing import Optional
|
import contextlib
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from datetime import datetime
|
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
@@ -28,10 +28,8 @@ def _extract_article_data(article: dict) -> dict:
|
|||||||
pub_date_str = content.get("pubDate", "")
|
pub_date_str = content.get("pubDate", "")
|
||||||
pub_date = None
|
pub_date = None
|
||||||
if pub_date_str:
|
if pub_date_str:
|
||||||
try:
|
with contextlib.suppress(ValueError, AttributeError):
|
||||||
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
|
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
|
||||||
except (ValueError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"title": title,
|
"title": title,
|
||||||
@@ -47,10 +45,8 @@ def _extract_article_data(article: dict) -> dict:
|
|||||||
pub_date = None
|
pub_date = None
|
||||||
ts = article.get("providerPublishTime")
|
ts = article.get("providerPublishTime")
|
||||||
if ts:
|
if ts:
|
||||||
try:
|
with contextlib.suppress(ValueError, OSError, TypeError):
|
||||||
pub_date = datetime.fromtimestamp(ts)
|
pub_date = datetime.fromtimestamp(ts)
|
||||||
except (ValueError, OSError, TypeError):
|
|
||||||
pass
|
|
||||||
return {
|
return {
|
||||||
"title": article.get("title", "No title"),
|
"title": article.get("title", "No title"),
|
||||||
"summary": article.get("summary", ""),
|
"summary": article.get("summary", ""),
|
||||||
@@ -131,8 +127,8 @@ def get_news_yfinance(
|
|||||||
|
|
||||||
def get_global_news_yfinance(
|
def get_global_news_yfinance(
|
||||||
curr_date: str,
|
curr_date: str,
|
||||||
look_back_days: Optional[int] = None,
|
look_back_days: int | None = None,
|
||||||
limit: Optional[int] = None,
|
limit: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve global/macro economic news using yfinance Search.
|
Retrieve global/macro economic news using yfinance Search.
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
# TradingAgents/graph/__init__.py
|
# TradingAgents/graph/__init__.py
|
||||||
|
|
||||||
from .trading_graph import TradingAgentsGraph
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
|
from .setup import GraphSetup
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
from .trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TradingAgentsGraph",
|
"TradingAgentsGraph",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
from typing import Dict, Iterable, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -14,11 +14,11 @@ class AnalystNodeSpec:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class AnalystExecutionPlan:
|
class AnalystExecutionPlan:
|
||||||
specs: List[AnalystNodeSpec]
|
specs: list[AnalystNodeSpec]
|
||||||
concurrency_limit: int
|
concurrency_limit: int
|
||||||
|
|
||||||
|
|
||||||
ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
|
ANALYST_NODE_SPECS: dict[str, AnalystNodeSpec] = {
|
||||||
"market": AnalystNodeSpec(
|
"market": AnalystNodeSpec(
|
||||||
key="market",
|
key="market",
|
||||||
agent_node="Market Analyst",
|
agent_node="Market Analyst",
|
||||||
@@ -61,7 +61,7 @@ def build_analyst_execution_plan(
|
|||||||
if concurrency_limit < 1:
|
if concurrency_limit < 1:
|
||||||
raise ValueError("analyst concurrency limit must be >= 1")
|
raise ValueError("analyst concurrency limit must be >= 1")
|
||||||
|
|
||||||
specs: List[AnalystNodeSpec] = []
|
specs: list[AnalystNodeSpec] = []
|
||||||
for analyst_key in selected_analysts:
|
for analyst_key in selected_analysts:
|
||||||
spec = ANALYST_NODE_SPECS.get(analyst_key)
|
spec = ANALYST_NODE_SPECS.get(analyst_key)
|
||||||
if spec is None:
|
if spec is None:
|
||||||
@@ -81,10 +81,10 @@ def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str:
|
|||||||
class AnalystWallTimeTracker:
|
class AnalystWallTimeTracker:
|
||||||
def __init__(self, plan: AnalystExecutionPlan):
|
def __init__(self, plan: AnalystExecutionPlan):
|
||||||
self.plan = plan
|
self.plan = plan
|
||||||
self._started_at: Dict[str, float] = {}
|
self._started_at: dict[str, float] = {}
|
||||||
self._wall_times: Dict[str, float] = {}
|
self._wall_times: dict[str, float] = {}
|
||||||
|
|
||||||
def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None:
|
def mark_started(self, analyst_key: str, started_at: float | None = None) -> None:
|
||||||
if analyst_key not in ANALYST_NODE_SPECS:
|
if analyst_key not in ANALYST_NODE_SPECS:
|
||||||
raise ValueError(f"unknown analyst key: {analyst_key}")
|
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||||
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
|
self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at)
|
||||||
@@ -92,7 +92,7 @@ class AnalystWallTimeTracker:
|
|||||||
def mark_completed(
|
def mark_completed(
|
||||||
self,
|
self,
|
||||||
analyst_key: str,
|
analyst_key: str,
|
||||||
completed_at: Optional[float] = None,
|
completed_at: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if analyst_key not in ANALYST_NODE_SPECS:
|
if analyst_key not in ANALYST_NODE_SPECS:
|
||||||
raise ValueError(f"unknown analyst key: {analyst_key}")
|
raise ValueError(f"unknown analyst key: {analyst_key}")
|
||||||
@@ -104,7 +104,7 @@ class AnalystWallTimeTracker:
|
|||||||
finished_at = monotonic() if completed_at is None else completed_at
|
finished_at = monotonic() if completed_at is None else completed_at
|
||||||
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
|
self._wall_times[analyst_key] = max(0.0, finished_at - started_at)
|
||||||
|
|
||||||
def get_wall_times(self) -> Dict[str, float]:
|
def get_wall_times(self) -> dict[str, float]:
|
||||||
return dict(self._wall_times)
|
return dict(self._wall_times)
|
||||||
|
|
||||||
def format_summary(self) -> str:
|
def format_summary(self) -> str:
|
||||||
@@ -121,8 +121,8 @@ class AnalystWallTimeTracker:
|
|||||||
|
|
||||||
def sync_analyst_tracker_from_chunk(
|
def sync_analyst_tracker_from_chunk(
|
||||||
tracker: AnalystWallTimeTracker,
|
tracker: AnalystWallTimeTracker,
|
||||||
chunk: Dict[str, str],
|
chunk: dict[str, str],
|
||||||
now: Optional[float] = None,
|
now: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
current_time = monotonic() if now is None else now
|
current_time = monotonic() if now is None else now
|
||||||
active_found = False
|
active_found = False
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
# TradingAgents/graph/propagation.py
|
# TradingAgents/graph/propagation.py
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
RiskDebateState,
|
RiskDebateState,
|
||||||
)
|
)
|
||||||
@@ -22,7 +22,7 @@ class Propagator:
|
|||||||
asset_type: str = "stock",
|
asset_type: str = "stock",
|
||||||
past_context: str = "",
|
past_context: str = "",
|
||||||
instrument_context: str = "",
|
instrument_context: str = "",
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph.
|
"""Create the initial state for the agent graph.
|
||||||
|
|
||||||
``instrument_context`` is the deterministic ticker-identity string
|
``instrument_context`` is the deterministic ticker-identity string
|
||||||
@@ -68,7 +68,7 @@ class Propagator:
|
|||||||
"news_report": "",
|
"news_report": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
def get_graph_args(self, callbacks: list | None = None) -> dict[str, Any]:
|
||||||
"""Get arguments for the graph invocation.
|
"""Get arguments for the graph invocation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,10 +1,25 @@
|
|||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.graph import END, START, StateGraph
|
from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import (
|
||||||
|
create_aggressive_debator,
|
||||||
|
create_bear_researcher,
|
||||||
|
create_bull_researcher,
|
||||||
|
create_conservative_debator,
|
||||||
|
create_fundamentals_analyst,
|
||||||
|
create_market_analyst,
|
||||||
|
create_msg_delete,
|
||||||
|
create_neutral_debator,
|
||||||
|
create_news_analyst,
|
||||||
|
create_portfolio_manager,
|
||||||
|
create_research_manager,
|
||||||
|
create_sentiment_analyst,
|
||||||
|
create_trader,
|
||||||
|
)
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
from .analyst_execution import build_analyst_execution_plan
|
from .analyst_execution import build_analyst_execution_plan
|
||||||
@@ -18,7 +33,7 @@ class GraphSetup:
|
|||||||
self,
|
self,
|
||||||
quick_thinking_llm: Any,
|
quick_thinking_llm: Any,
|
||||||
deep_thinking_llm: Any,
|
deep_thinking_llm: Any,
|
||||||
tool_nodes: Dict[str, ToolNode],
|
tool_nodes: dict[str, ToolNode],
|
||||||
conditional_logic: ConditionalLogic,
|
conditional_logic: ConditionalLogic,
|
||||||
analyst_concurrency_limit: int = 1,
|
analyst_concurrency_limit: int = 1,
|
||||||
):
|
):
|
||||||
@@ -30,7 +45,7 @@ class GraphSetup:
|
|||||||
self.analyst_concurrency_limit = analyst_concurrency_limit
|
self.analyst_concurrency_limit = analyst_concurrency_limit
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=("market", "social", "news", "fundamentals")
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|||||||
@@ -1,66 +1,57 @@
|
|||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.llm_clients import create_llm_client
|
# Import the abstract tool methods from agent_utils
|
||||||
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
|
||||||
from tradingagents.dataflows.utils import safe_ticker_component
|
|
||||||
from tradingagents.agents.utils.agent_states import (
|
|
||||||
AgentState,
|
|
||||||
InvestDebateState,
|
|
||||||
RiskDebateState,
|
|
||||||
)
|
|
||||||
from tradingagents.dataflows.config import set_config
|
|
||||||
|
|
||||||
# Import the new abstract tool methods from agent_utils
|
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
resolve_instrument_identity,
|
|
||||||
get_stock_data,
|
|
||||||
get_indicators,
|
|
||||||
get_verified_market_snapshot,
|
|
||||||
get_fundamentals,
|
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
get_cashflow,
|
get_cashflow,
|
||||||
get_income_statement,
|
get_fundamentals,
|
||||||
get_news,
|
|
||||||
get_insider_transactions,
|
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
get_income_statement,
|
||||||
|
get_indicators,
|
||||||
|
get_insider_transactions,
|
||||||
get_macro_indicators,
|
get_macro_indicators,
|
||||||
get_prediction_markets
|
get_news,
|
||||||
|
get_prediction_markets,
|
||||||
|
get_stock_data,
|
||||||
|
get_verified_market_snapshot,
|
||||||
|
resolve_instrument_identity,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
|
from .setup import GraphSetup
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TradingAgentsGraph:
|
class TradingAgentsGraph:
|
||||||
"""Main class that orchestrates the trading agents framework."""
|
"""Main class that orchestrates the trading agents framework."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=("market", "social", "news", "fundamentals"),
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: dict[str, Any] = None,
|
||||||
callbacks: Optional[List] = None,
|
callbacks: list | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
@@ -138,7 +129,7 @@ class TradingAgentsGraph:
|
|||||||
self.graph = self.workflow.compile()
|
self.graph = self.workflow.compile()
|
||||||
self._checkpointer_ctx = None
|
self._checkpointer_ctx = None
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self) -> dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
provider = self.config.get("llm_provider", "").lower()
|
provider = self.config.get("llm_provider", "").lower()
|
||||||
@@ -167,7 +158,7 @@ class TradingAgentsGraph:
|
|||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> dict[str, ToolNode]:
|
||||||
"""Create tool nodes for different data sources using abstract methods."""
|
"""Create tool nodes for different data sources using abstract methods."""
|
||||||
return {
|
return {
|
||||||
"market": ToolNode(
|
"market": ToolNode(
|
||||||
@@ -233,7 +224,7 @@ class TradingAgentsGraph:
|
|||||||
def _fetch_returns(
|
def _fetch_returns(
|
||||||
self, ticker: str, trade_date: str, holding_days: int = 5,
|
self, ticker: str, trade_date: str, holding_days: int = 5,
|
||||||
benchmark: str = "SPY",
|
benchmark: str = "SPY",
|
||||||
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
|
) -> tuple[float | None, float | None, int | None]:
|
||||||
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
||||||
|
|
||||||
``benchmark`` is the index used as the alpha baseline (resolved by the
|
``benchmark`` is the index used as the alpha baseline (resolved by the
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ class NormalizedChatAnthropic(ChatAnthropic):
|
|||||||
class AnthropicClient(BaseLLMClient):
|
class AnthropicClient(BaseLLMClient):
|
||||||
"""Client for Anthropic Claude models."""
|
"""Client for Anthropic Claude models."""
|
||||||
|
|
||||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||||
super().__init__(model, base_url, **kwargs)
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
|
|||||||
@@ -11,10 +11,7 @@ prompts for it automatically instead of failing on first API call.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
PROVIDER_API_KEY_ENV: dict[str, str | None] = {
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
|
||||||
"openai": "OPENAI_API_KEY",
|
"openai": "OPENAI_API_KEY",
|
||||||
"anthropic": "ANTHROPIC_API_KEY",
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
"google": "GOOGLE_API_KEY",
|
"google": "GOOGLE_API_KEY",
|
||||||
@@ -47,7 +44,7 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_api_key_env(provider: str) -> Optional[str]:
|
def get_api_key_env(provider: str) -> str | None:
|
||||||
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
||||||
|
|
||||||
Unknown providers also return None — callers should treat that as
|
Unknown providers also return None — callers should treat that as
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from langchain_openai import AzureChatOpenAI
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
from .base_client import BaseLLMClient, normalize_content
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
from .validators import validate_model
|
|
||||||
|
|
||||||
_PASSTHROUGH_KWARGS = (
|
_PASSTHROUGH_KWARGS = (
|
||||||
"timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
|
"timeout", "max_retries", "api_key", "reasoning_effort", "temperature",
|
||||||
@@ -29,7 +28,7 @@ class AzureOpenAIClient(BaseLLMClient):
|
|||||||
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||||
super().__init__(model, base_url, **kwargs)
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Optional
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def normalize_content(response):
|
def normalize_content(response):
|
||||||
@@ -25,7 +25,7 @@ def normalize_content(response):
|
|||||||
class BaseLLMClient(ABC):
|
class BaseLLMClient(ABC):
|
||||||
"""Abstract base class for LLM clients."""
|
"""Abstract base class for LLM clients."""
|
||||||
|
|
||||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import re
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
StructuredMethod = Literal[
|
StructuredMethod = Literal[
|
||||||
"function_calling", # uses tools; respects supports_tool_choice
|
"function_calling", # uses tools; respects supports_tool_choice
|
||||||
"json_mode", # uses response_format={"type":"json_object"}
|
"json_mode", # uses response_format={"type":"json_object"}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
provider: str,
|
provider: str,
|
||||||
model: str,
|
model: str,
|
||||||
base_url: Optional[str] = None,
|
base_url: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BaseLLMClient:
|
) -> BaseLLMClient:
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
|
|||||||
class GoogleClient(BaseLLMClient):
|
class GoogleClient(BaseLLMClient):
|
||||||
"""Client for Google Gemini models."""
|
"""Client for Google Gemini models."""
|
||||||
|
|
||||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
def __init__(self, model: str, base_url: str | None = None, **kwargs):
|
||||||
super().__init__(model, base_url, **kwargs)
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
|
|||||||
@@ -2,14 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Tuple
|
ModelOption = tuple[str, str]
|
||||||
|
ProviderModeOptions = dict[str, dict[str, list[ModelOption]]]
|
||||||
ModelOption = Tuple[str, str]
|
|
||||||
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
|
||||||
|
|
||||||
# Providers that serve many / frequently-changing models: offer only "Custom
|
# Providers that serve many / frequently-changing models: offer only "Custom
|
||||||
# model ID" rather than a list that goes stale.
|
# model ID" rather than a list that goes stale.
|
||||||
_CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
|
_CUSTOM_ONLY: dict[str, list[ModelOption]] = {
|
||||||
"quick": [("Custom model ID", "custom")],
|
"quick": [("Custom model ID", "custom")],
|
||||||
"deep": [("Custom model ID", "custom")],
|
"deep": [("Custom model ID", "custom")],
|
||||||
}
|
}
|
||||||
@@ -18,7 +16,7 @@ _CUSTOM_ONLY: Dict[str, List[ModelOption]] = {
|
|||||||
# Shared model list for GLM via Z.AI (international) and BigModel (China).
|
# Shared model list for GLM via Z.AI (international) and BigModel (China).
|
||||||
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
|
# Source: docs.z.ai (GLM Coding Plan supported models + LLM guides).
|
||||||
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
|
# All GLM 4.7+ entries support thinking mode via thinking={"type":"enabled"}.
|
||||||
_GLM_MODELS: Dict[str, List[ModelOption]] = {
|
_GLM_MODELS: dict[str, list[ModelOption]] = {
|
||||||
"quick": [
|
"quick": [
|
||||||
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
|
("GLM-5-Turbo - Fast, switchable thinking modes", "glm-5-turbo"),
|
||||||
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
("GLM-4.7 - Previous-gen flagship", "glm-4.7"),
|
||||||
@@ -44,7 +42,7 @@ _GLM_MODELS: Dict[str, List[ModelOption]] = {
|
|||||||
# the backing model. Users who want a specific generation pick it
|
# the backing model. Users who want a specific generation pick it
|
||||||
# explicitly; users who really want auto-latest can enter the alias via
|
# explicitly; users who really want auto-latest can enter the alias via
|
||||||
# "Custom model ID".
|
# "Custom model ID".
|
||||||
_QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
_QWEN_MODELS: dict[str, list[ModelOption]] = {
|
||||||
"quick": [
|
"quick": [
|
||||||
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
|
("Qwen 3.6 Flash - Latest fast, agentic coding + vision-language", "qwen3.6-flash"),
|
||||||
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
|
("Qwen 3.5 Flash - Previous-gen fast", "qwen3.5-flash"),
|
||||||
@@ -62,7 +60,7 @@ _QWEN_MODELS: Dict[str, List[ModelOption]] = {
|
|||||||
# Shared model list for MiniMax's global and CN endpoints (same IDs).
|
# Shared model list for MiniMax's global and CN endpoints (same IDs).
|
||||||
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
|
# Full official lineup per platform.minimax.io/docs/api-reference/text-openai-api.
|
||||||
# All M2.x models share a 204,800-token context window.
|
# All M2.x models share a 204,800-token context window.
|
||||||
_MINIMAX_MODELS: Dict[str, List[ModelOption]] = {
|
_MINIMAX_MODELS: dict[str, list[ModelOption]] = {
|
||||||
"quick": [
|
"quick": [
|
||||||
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
("MiniMax-M2.7-highspeed - Faster M2.7, 204K ctx, ~100 TPS", "MiniMax-M2.7-highspeed"),
|
||||||
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
|
("MiniMax-M2.5-highspeed - Previous-gen highspeed, 204K ctx", "MiniMax-M2.5-highspeed"),
|
||||||
@@ -198,12 +196,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
def get_model_options(provider: str, mode: str) -> list[ModelOption]:
|
||||||
"""Return shared model options for a provider and selection mode."""
|
"""Return shared model options for a provider and selection mode."""
|
||||||
return MODEL_OPTIONS[provider.lower()][mode]
|
return MODEL_OPTIONS[provider.lower()][mode]
|
||||||
|
|
||||||
|
|
||||||
def get_known_models() -> Dict[str, List[str]]:
|
def get_known_models() -> dict[str, list[str]]:
|
||||||
"""Build known model names from the shared CLI catalog."""
|
"""Build known model names from the shared CLI catalog."""
|
||||||
return {
|
return {
|
||||||
provider: sorted(
|
provider: sorted(
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from .model_catalog import get_known_models
|
from .model_catalog import get_known_models
|
||||||
|
|
||||||
|
|
||||||
# Providers whose model names are user-defined (local servers, relays, hosted
|
# Providers whose model names are user-defined (local servers, relays, hosted
|
||||||
# OpenAI-compatible endpoints serving many models), so any model string is
|
# OpenAI-compatible endpoints serving many models), so any model string is
|
||||||
# accepted without warning.
|
# accepted without warning.
|
||||||
|
|||||||
Reference in New Issue
Block a user