mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-01 14:33:10 +03:00
feat: add LangGraph checkpoint resume for crash recovery (#594)
Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node.

How to use:
- CLI: tradingagents analyze --checkpoint
- CLI: tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and injects a deterministic thread_id derived from ticker+date, so the same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db (override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py), and GraphSetup now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds the langgraph-checkpoint-sqlite>=2.0.0 dependency. Three new tests verify the crash/resume cycle and that a different date starts fresh.
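For reference, the deterministic thread ID described above is just a short hash of the upper-cased ticker and the date; this sketch mirrors the `thread_id` helper added in `tradingagents/graph/checkpointer.py` below.

```python
import hashlib

def thread_id(ticker: str, date: str) -> str:
    # Same ticker+date always hashes to the same thread, so a re-run resumes;
    # a different date (or ticker) yields a new thread and a fresh run.
    return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]

# Case-insensitive on the ticker, sensitive to the date.
print(thread_id("nvda", "2026-01-15") == thread_id("NVDA", "2026-01-15"))  # True
print(thread_id("NVDA", "2026-01-15") == thread_id("NVDA", "2026-01-16"))  # False
```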
README.md | 32
@@ -183,7 +183,7 @@ An interface will appear showing results as they load, letting you track the age
 
 ### Implementation Details
 
-We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
+We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
 
 ### Python Usage
 
@@ -207,7 +207,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
 from tradingagents.default_config import DEFAULT_CONFIG
 
 config = DEFAULT_CONFIG.copy()
-config["llm_provider"] = "openai"  # openai, google, anthropic, xai, openrouter, ollama
+config["llm_provider"] = "openai"  # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure
 config["deep_think_llm"] = "gpt-5.4"  # Model for complex reasoning
 config["quick_think_llm"] = "gpt-5.4-mini"  # Model for quick tasks
 config["max_debate_rounds"] = 2
@@ -219,6 +219,34 @@ print(decision)
 
 See `tradingagents/default_config.py` for all configuration options.
 
+## Persistence and Recovery
+
+TradingAgents persists two kinds of state across runs.
+
+### Decision log
+
+The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't.
+
+Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
+
+### Checkpoint resume
+
+Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion.
+
+Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run.
+
+```bash
+tradingagents analyze --checkpoint         # enable for this run
+tradingagents analyze --clear-checkpoints  # reset before running
+```
+
+```python
+config = DEFAULT_CONFIG.copy()
+config["checkpoint_enabled"] = True
+ta = TradingAgentsGraph(config=config)
+_, decision = ta.propagate("NVDA", "2026-01-15")
+```
+
 ## Contributing
 
 We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
cli/main.py | 22
@@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
        return result[:max_length - 3] + "..."
    return result
 
-def run_analysis():
+def run_analysis(checkpoint: bool = False):
    # First get all user selections
    selections = get_user_selections()
 
@@ -943,6 +943,7 @@ def run_analysis():
    config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
    config["anthropic_effort"] = selections.get("anthropic_effort")
    config["output_language"] = selections.get("output_language", "English")
+    config["checkpoint_enabled"] = checkpoint
 
    # Create stats callback handler for tracking LLM/tool calls
    stats_handler = StatsCallbackHandler()
@@ -1197,8 +1198,23 @@ def run_analysis():
 
 
 @app.command()
-def analyze():
-    run_analysis()
+def analyze(
+    checkpoint: bool = typer.Option(
+        False,
+        "--checkpoint",
+        help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
+    ),
+    clear_checkpoints: bool = typer.Option(
+        False,
+        "--clear-checkpoints",
+        help="Delete all saved checkpoints before running (force fresh start).",
+    ),
+):
+    if clear_checkpoints:
+        from tradingagents.graph.checkpointer import clear_all_checkpoints
+        n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
+        console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
+    run_analysis(checkpoint=checkpoint)
 
 
 if __name__ == "__main__":
@@ -16,6 +16,7 @@ dependencies = [
     "langchain-google-genai>=4.0.0",
     "langchain-openai>=0.3.23",
     "langgraph>=0.4.8",
+    "langgraph-checkpoint-sqlite>=2.0.0",
     "pandas>=2.3.0",
     "parsel>=1.10.0",
     "pytz>=2025.2",
tests/test_checkpoint_resume.py | 147 (new file)
@@ -0,0 +1,147 @@
+"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
+
+import sqlite3
+import tempfile
+import unittest
+from pathlib import Path
+from typing import TypedDict
+
+from langgraph.checkpoint.sqlite import SqliteSaver
+from langgraph.graph import END, StateGraph
+
+from tradingagents.graph.checkpointer import (
+    checkpoint_step,
+    clear_checkpoint,
+    get_checkpointer,
+    has_checkpoint,
+    thread_id,
+)
+
+# Mutable flag to simulate crash on first run
+_should_crash = False
+
+
+class _SimpleState(TypedDict):
+    count: int
+
+
+def _node_a(state: _SimpleState) -> dict:
+    return {"count": state["count"] + 1}
+
+
+def _node_b(state: _SimpleState) -> dict:
+    if _should_crash:
+        raise RuntimeError("simulated mid-analysis crash")
+    return {"count": state["count"] + 10}
+
+
+def _build_graph() -> StateGraph:
+    builder = StateGraph(_SimpleState)
+    builder.add_node("analyst", _node_a)
+    builder.add_node("trader", _node_b)
+    builder.set_entry_point("analyst")
+    builder.add_edge("analyst", "trader")
+    builder.add_edge("trader", END)
+    return builder
+
+
+class TestCheckpointResume(unittest.TestCase):
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+        self.ticker = "TEST"
+        self.date = "2026-04-20"
+
+    def test_crash_and_resume(self):
+        """Crash at 'trader' node, then resume from checkpoint."""
+        global _should_crash
+        builder = _build_graph()
+        tid = thread_id(self.ticker, self.date)
+        cfg = {"configurable": {"thread_id": tid}}
+
+        # Run 1: crash at trader node
+        _should_crash = True
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config=cfg)
+
+        # Checkpoint should exist at step 1 (analyst completed)
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+        step = checkpoint_step(self.tmpdir, self.ticker, self.date)
+        self.assertEqual(step, 1)
+
+        # Run 2: resume — trader succeeds this time
+        _should_crash = False
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke(None, config=cfg)
+
+        # analyst added 1, trader added 10 → 11
+        self.assertEqual(result["count"], 11)
+
+    def test_clear_checkpoint_allows_fresh_start(self):
+        """After clearing, the graph starts from scratch."""
+        global _should_crash
+        builder = _build_graph()
+        tid = thread_id(self.ticker, self.date)
+        cfg = {"configurable": {"thread_id": tid}}
+
+        # Create a checkpoint by crashing
+        _should_crash = True
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config=cfg)
+
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+
+        # Clear it
+        clear_checkpoint(self.tmpdir, self.ticker, self.date)
+        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
+
+        # Fresh run succeeds from scratch
+        _should_crash = False
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke({"count": 0}, config=cfg)
+
+        self.assertEqual(result["count"], 11)
+
+
+    def test_different_date_starts_fresh(self):
+        """A different date must NOT resume from an existing checkpoint."""
+        global _should_crash
+        builder = _build_graph()
+        date2 = "2026-04-21"
+
+        # Run with date1 — crash to leave a checkpoint
+        _should_crash = True
+        tid1 = thread_id(self.ticker, self.date)
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
+
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+
+        # date2 should have no checkpoint
+        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
+
+        # Run with date2 — should start fresh and succeed
+        _should_crash = False
+        tid2 = thread_id(self.ticker, date2)
+        self.assertNotEqual(tid1, tid2)
+
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
+
+        # Fresh run: analyst +1, trader +10 = 11
+        self.assertEqual(result["count"], 11)
+
+        # Original date checkpoint still exists (untouched)
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -629,7 +629,9 @@ class TestLegacyRemoval:
         create_portfolio_manager(mock_llm, memory=MagicMock())
 
     def test_full_pipeline_no_regression(self, tmp_path):
-        """propagate() completes without AttributeError after legacy cleanup."""
+        """propagate() completes and stores the decision after the redesign."""
+        import functools
+
         fake_state = {
             "final_trade_decision": "Rating: Buy\nBuy NVDA.",
             "company_of_interest": "NVDA",
@@ -660,6 +662,11 @@ class TestLegacyRemoval:
         mock_graph.propagator.create_initial_state.return_value = fake_state
         mock_graph.propagator.get_graph_args.return_value = {}
         mock_graph.signal_processor.process_signal.return_value = "Buy"
+        # Bind the real _run_graph so propagate's call to self._run_graph executes
+        # the actual write path instead of the auto-MagicMock.
+        mock_graph._run_graph = functools.partial(
+            TradingAgentsGraph._run_graph, mock_graph
+        )
         TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
         entries = mock_graph.memory_log.load_entries()
         assert len(entries) == 1
@@ -16,6 +16,9 @@ DEFAULT_CONFIG = {
     "google_thinking_level": None,  # "high", "minimal", etc.
     "openai_reasoning_effort": None,  # "medium", "high", "low"
     "anthropic_effort": None,  # "high", "medium", "low"
+    # Checkpoint/resume: when True, LangGraph saves state after each node
+    # so a crashed run can resume from the last successful step.
+    "checkpoint_enabled": False,
     # Output language for analyst reports and final decision
     # Internal agent debate stays in English for reasoning quality
     "output_language": "English",
tradingagents/graph/checkpointer.py | 86 (new file)
@@ -0,0 +1,86 @@
+"""LangGraph checkpoint support for resumable analysis runs.
+
+Per-ticker SQLite databases so concurrent tickers don't contend.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator
+
+from langgraph.checkpoint.sqlite import SqliteSaver
+
+
+def _db_path(data_dir: str | Path, ticker: str) -> Path:
+    """Return the SQLite checkpoint DB path for a ticker."""
+    p = Path(data_dir) / "checkpoints"
+    p.mkdir(parents=True, exist_ok=True)
+    return p / f"{ticker.upper()}.db"
+
+
+def thread_id(ticker: str, date: str) -> str:
+    """Deterministic thread ID for a ticker+date pair."""
+    return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
+
+
+@contextmanager
+def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
+    """Context manager yielding a SqliteSaver backed by a per-ticker DB."""
+    db = _db_path(data_dir, ticker)
+    conn = sqlite3.connect(str(db), check_same_thread=False)
+    try:
+        saver = SqliteSaver(conn)
+        saver.setup()
+        yield saver
+    finally:
+        conn.close()
+
+
+def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
+    """Check whether a resumable checkpoint exists for ticker+date."""
+    return checkpoint_step(data_dir, ticker, date) is not None
+
+
+def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
+    """Return the step number of the latest checkpoint, or None if none exists."""
+    db = _db_path(data_dir, ticker)
+    if not db.exists():
+        return None
+    tid = thread_id(ticker, date)
+    with get_checkpointer(data_dir, ticker) as saver:
+        config = {"configurable": {"thread_id": tid}}
+        cp = saver.get_tuple(config)
+        if cp is None:
+            return None
+        return cp.metadata.get("step")
+
+
+def clear_all_checkpoints(data_dir: str | Path) -> int:
+    """Remove all checkpoint DBs. Returns number of files deleted."""
+    cp_dir = Path(data_dir) / "checkpoints"
+    if not cp_dir.exists():
+        return 0
+    dbs = list(cp_dir.glob("*.db"))
+    for db in dbs:
+        db.unlink()
+    return len(dbs)
+
+
+def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
+    """Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
+    db = _db_path(data_dir, ticker)
+    if not db.exists():
+        return
+    tid = thread_id(ticker, date)
+    conn = sqlite3.connect(str(db))
+    try:
+        for table in ("writes", "checkpoints"):
+            conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
+        conn.commit()
+    except sqlite3.OperationalError:
+        pass
+    finally:
+        conn.close()
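The pure-filesystem helpers in this module can be exercised without LangGraph installed; below is a stdlib-only sketch of `clear_all_checkpoints` as defined above, where a temp directory stands in for `data_cache_dir`.

```python
import tempfile
from pathlib import Path

def clear_all_checkpoints(data_dir) -> int:
    # Same logic as the module above: delete every per-ticker *.db
    # under <data_dir>/checkpoints and report how many were removed.
    cp_dir = Path(data_dir) / "checkpoints"
    if not cp_dir.exists():
        return 0
    dbs = list(cp_dir.glob("*.db"))
    for db in dbs:
        db.unlink()
    return len(dbs)

data_dir = tempfile.mkdtemp()
(Path(data_dir) / "checkpoints").mkdir()
for ticker in ("NVDA", "AAPL"):
    (Path(data_dir) / "checkpoints" / f"{ticker}.db").touch()

print(clear_all_checkpoints(data_dir))  # 2 on the first call
print(clear_all_checkpoints(data_dir))  # 0 once the directory is empty
```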
@@ -179,5 +179,4 @@ class GraphSetup:
 
         workflow.add_edge("Portfolio Manager", END)
 
-        # Compile and return
-        return workflow.compile()
+        return workflow
@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
     get_global_news
 )
 
+from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
 from .conditional_logic import ConditionalLogic
 from .setup import GraphSetup
 from .propagation import Propagator
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
         self.ticker = None
         self.log_states_dict = {}  # date to full state dict
 
-        # Set up the graph
-        self.graph = self.graph_setup.setup_graph(selected_analysts)
+        # Set up the graph: keep the workflow for recompilation with a checkpointer.
+        self.workflow = self.graph_setup.setup_graph(selected_analysts)
+        self.graph = self.workflow.compile()
+        self._checkpointer_ctx = None
 
     def _get_provider_kwargs(self) -> Dict[str, Any]:
         """Get provider-specific kwargs for LLM client creation."""
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
         self.memory_log.batch_update_with_outcomes(updates)
 
     def propagate(self, company_name, trade_date):
-        """Run the trading agents graph for a company on a specific date."""
+        """Run the trading agents graph for a company on a specific date.
+
+        When ``checkpoint_enabled`` is set in config, the graph is recompiled
+        with a per-ticker SqliteSaver so a crashed run can resume from the last
+        successful node on a subsequent invocation with the same ticker+date.
+        """
         self.ticker = company_name
 
-        # Resolve any pending log entries for this ticker before the pipeline runs.
-        # This adds the outcome + reflection from the previous run at zero latency cost.
+        # Resolve any pending memory-log entries for this ticker before the pipeline runs.
         self._resolve_pending_entries(company_name)
 
-        # Initialize state — inject memory log context for PM
+        # Recompile with a checkpointer if the user opted in.
+        if self.config.get("checkpoint_enabled"):
+            self._checkpointer_ctx = get_checkpointer(
+                self.config["data_cache_dir"], company_name
+            )
+            saver = self._checkpointer_ctx.__enter__()
+            self.graph = self.workflow.compile(checkpointer=saver)
+
+            step = checkpoint_step(
+                self.config["data_cache_dir"], company_name, str(trade_date)
+            )
+            if step is not None:
+                logger.info(
+                    "Resuming from step %d for %s on %s", step, company_name, trade_date
+                )
+            else:
+                logger.info("Starting fresh for %s on %s", company_name, trade_date)
+
+        try:
+            return self._run_graph(company_name, trade_date)
+        finally:
+            if self._checkpointer_ctx is not None:
+                self._checkpointer_ctx.__exit__(None, None, None)
+                self._checkpointer_ctx = None
+                self.graph = self.workflow.compile()
+
+    def _run_graph(self, company_name, trade_date):
+        """Execute the graph and write the resulting state to disk and memory log."""
+        # Initialize state — inject memory log context for PM.
         past_context = self.memory_log.get_past_context(company_name)
         init_agent_state = self.propagator.create_initial_state(
             company_name, trade_date, past_context=past_context
         )
         args = self.propagator.get_graph_args()
+
+        # Inject thread_id so same ticker+date resumes, different date starts fresh.
+        if self.config.get("checkpoint_enabled"):
+            tid = thread_id(company_name, str(trade_date))
+            args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
+
         if self.debug:
-            # Debug mode with tracing
             trace = []
             for chunk in self.graph.stream(init_agent_state, **args):
                 if len(chunk["messages"]) == 0:
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
                 else:
                     chunk["messages"][-1].pretty_print()
                 trace.append(chunk)
 
             final_state = trace[-1]
         else:
-            # Standard mode without tracing
             final_state = self.graph.invoke(init_agent_state, **args)
 
-        # Store current state for reflection
+        # Store current state for reflection.
         self.curr_state = final_state
 
-        # Log state
+        # Log state to disk.
        self._log_state(trade_date, final_state)
 
-        # Store decision for deferred reflection.
+        # Store decision for deferred reflection on the next same-ticker run.
         self.memory_log.store_decision(
             ticker=company_name,
             trade_date=trade_date,
             final_trade_decision=final_state["final_trade_decision"],
         )
 
-        # Return decision and processed signal
+        # Clear checkpoint on successful completion to avoid stale state.
+        if self.config.get("checkpoint_enabled"):
+            clear_checkpoint(
+                self.config["data_cache_dir"], company_name, str(trade_date)
+            )
+
         return final_state, self.process_signal(final_state["final_trade_decision"])
 
     def _log_state(self, trade_date, final_state):