feat: add LangGraph checkpoint resume for crash recovery (#594)

Long analyses can take many minutes; a crash or interruption forced users
to re-run from scratch and re-pay every LLM call.  This adds an opt-in
checkpoint layer backed by per-ticker SQLite databases so the graph
resumes from the last successful node.

How to use:
- CLI:    tradingagents analyze --checkpoint
- CLI:    tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and
  injects a deterministic thread_id derived from ticker+date so the
  same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never
  leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db
(override via TRADINGAGENTS_CACHE_DIR).
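For reference, the id derivation is just a truncated SHA-256 over the upper-cased ticker and the date, mirroring `thread_id()` in the new checkpointer module:

```python
import hashlib

def thread_id(ticker: str, date: str) -> str:
    # Stable across runs and case-insensitive on the ticker, so
    # NVDA/2026-01-15 always maps to the same LangGraph thread.
    return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]

print(thread_id("NVDA", "2026-01-15") == thread_id("nvda", "2026-01-15"))  # True
print(thread_id("NVDA", "2026-01-15") == thread_id("NVDA", "2026-01-16"))  # False
```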

The checkpointer module is new (tradingagents/graph/checkpointer.py)
and the GraphSetup now returns the uncompiled workflow so it can be
recompiled with a saver when needed.

Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify
the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
Yijia-Xiao
2026-04-25 08:39:27 +00:00
parent ebd2e12e67
commit 4cbd4b086f
9 changed files with 349 additions and 21 deletions


@@ -183,7 +183,7 @@ An interface will appear showing results as they load, letting you track the age
 ### Implementation Details
-We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
+We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
 ### Python Usage
@@ -207,7 +207,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
 from tradingagents.default_config import DEFAULT_CONFIG
 config = DEFAULT_CONFIG.copy()
-config["llm_provider"] = "openai"  # openai, google, anthropic, xai, openrouter, ollama
+config["llm_provider"] = "openai"  # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure
 config["deep_think_llm"] = "gpt-5.4"  # Model for complex reasoning
 config["quick_think_llm"] = "gpt-5.4-mini"  # Model for quick tasks
 config["max_debate_rounds"] = 2
@@ -219,6 +219,34 @@ print(decision)
 See `tradingagents/default_config.py` for all configuration options.
+
+## Persistence and Recovery
+
+TradingAgents persists two kinds of state across runs.
+
+### Decision log
+
+The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't.
+
+Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
+
+### Checkpoint resume
+
+Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion.
+
+Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run.
+
+```bash
+tradingagents analyze --checkpoint         # enable for this run
+tradingagents analyze --clear-checkpoints  # reset before running
+```
+
+```python
+config = DEFAULT_CONFIG.copy()
+config["checkpoint_enabled"] = True
+ta = TradingAgentsGraph(config=config)
+_, decision = ta.propagate("NVDA", "2026-01-15")
+```
+
 ## Contributing
 We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).


@@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
         return result[:max_length - 3] + "..."
     return result
-def run_analysis():
+def run_analysis(checkpoint: bool = False):
     # First get all user selections
     selections = get_user_selections()
@@ -943,6 +943,7 @@ def run_analysis():
     config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
     config["anthropic_effort"] = selections.get("anthropic_effort")
     config["output_language"] = selections.get("output_language", "English")
+    config["checkpoint_enabled"] = checkpoint
     # Create stats callback handler for tracking LLM/tool calls
     stats_handler = StatsCallbackHandler()
@@ -1197,8 +1198,23 @@ def run_analysis():
 @app.command()
-def analyze():
-    run_analysis()
+def analyze(
+    checkpoint: bool = typer.Option(
+        False,
+        "--checkpoint",
+        help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
+    ),
+    clear_checkpoints: bool = typer.Option(
+        False,
+        "--clear-checkpoints",
+        help="Delete all saved checkpoints before running (force fresh start).",
+    ),
+):
+    if clear_checkpoints:
+        from tradingagents.graph.checkpointer import clear_all_checkpoints
+        n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
+        console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
+    run_analysis(checkpoint=checkpoint)
 if __name__ == "__main__":


@@ -16,6 +16,7 @@ dependencies = [
     "langchain-google-genai>=4.0.0",
     "langchain-openai>=0.3.23",
     "langgraph>=0.4.8",
+    "langgraph-checkpoint-sqlite>=2.0.0",
     "pandas>=2.3.0",
     "parsel>=1.10.0",
     "pytz>=2025.2",


@@ -0,0 +1,147 @@
+"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
+import sqlite3
+import tempfile
+import unittest
+from pathlib import Path
+from typing import TypedDict
+
+from langgraph.checkpoint.sqlite import SqliteSaver
+from langgraph.graph import END, StateGraph
+
+from tradingagents.graph.checkpointer import (
+    checkpoint_step,
+    clear_checkpoint,
+    get_checkpointer,
+    has_checkpoint,
+    thread_id,
+)
+
+# Mutable flag to simulate crash on first run
+_should_crash = False
+
+
+class _SimpleState(TypedDict):
+    count: int
+
+
+def _node_a(state: _SimpleState) -> dict:
+    return {"count": state["count"] + 1}
+
+
+def _node_b(state: _SimpleState) -> dict:
+    if _should_crash:
+        raise RuntimeError("simulated mid-analysis crash")
+    return {"count": state["count"] + 10}
+
+
+def _build_graph() -> StateGraph:
+    builder = StateGraph(_SimpleState)
+    builder.add_node("analyst", _node_a)
+    builder.add_node("trader", _node_b)
+    builder.set_entry_point("analyst")
+    builder.add_edge("analyst", "trader")
+    builder.add_edge("trader", END)
+    return builder
+
+
+class TestCheckpointResume(unittest.TestCase):
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+        self.ticker = "TEST"
+        self.date = "2026-04-20"
+
+    def test_crash_and_resume(self):
+        """Crash at 'trader' node, then resume from checkpoint."""
+        global _should_crash
+        builder = _build_graph()
+        tid = thread_id(self.ticker, self.date)
+        cfg = {"configurable": {"thread_id": tid}}
+        # Run 1: crash at trader node
+        _should_crash = True
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config=cfg)
+        # Checkpoint should exist at step 1 (analyst completed)
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+        step = checkpoint_step(self.tmpdir, self.ticker, self.date)
+        self.assertEqual(step, 1)
+        # Run 2: resume — trader succeeds this time
+        _should_crash = False
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke(None, config=cfg)
+        # analyst added 1, trader added 10 → 11
+        self.assertEqual(result["count"], 11)
+
+    def test_clear_checkpoint_allows_fresh_start(self):
+        """After clearing, the graph starts from scratch."""
+        global _should_crash
+        builder = _build_graph()
+        tid = thread_id(self.ticker, self.date)
+        cfg = {"configurable": {"thread_id": tid}}
+        # Create a checkpoint by crashing
+        _should_crash = True
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config=cfg)
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+        # Clear it
+        clear_checkpoint(self.tmpdir, self.ticker, self.date)
+        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
+        # Fresh run succeeds from scratch
+        _should_crash = False
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke({"count": 0}, config=cfg)
+        self.assertEqual(result["count"], 11)
+
+    def test_different_date_starts_fresh(self):
+        """A different date must NOT resume from an existing checkpoint."""
+        global _should_crash
+        builder = _build_graph()
+        date2 = "2026-04-21"
+        # Run with date1 — crash to leave a checkpoint
+        _should_crash = True
+        tid1 = thread_id(self.ticker, self.date)
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            with self.assertRaises(RuntimeError):
+                graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+        # date2 should have no checkpoint
+        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
+        # Run with date2 — should start fresh and succeed
+        _should_crash = False
+        tid2 = thread_id(self.ticker, date2)
+        self.assertNotEqual(tid1, tid2)
+        with get_checkpointer(self.tmpdir, self.ticker) as saver:
+            graph = builder.compile(checkpointer=saver)
+            result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
+        # Fresh run: analyst +1, trader +10 = 11
+        self.assertEqual(result["count"], 11)
+        # Original date checkpoint still exists (untouched)
+        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
+
+
+if __name__ == "__main__":
+    unittest.main()


@@ -629,7 +629,9 @@ class TestLegacyRemoval:
             create_portfolio_manager(mock_llm, memory=MagicMock())
     def test_full_pipeline_no_regression(self, tmp_path):
-        """propagate() completes without AttributeError after legacy cleanup."""
+        """propagate() completes and stores the decision after the redesign."""
+        import functools
         fake_state = {
             "final_trade_decision": "Rating: Buy\nBuy NVDA.",
             "company_of_interest": "NVDA",
@@ -660,6 +662,11 @@ class TestLegacyRemoval:
         mock_graph.propagator.create_initial_state.return_value = fake_state
         mock_graph.propagator.get_graph_args.return_value = {}
         mock_graph.signal_processor.process_signal.return_value = "Buy"
+        # Bind the real _run_graph so propagate's call to self._run_graph executes
+        # the actual write path instead of the auto-MagicMock.
+        mock_graph._run_graph = functools.partial(
+            TradingAgentsGraph._run_graph, mock_graph
+        )
         TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
         entries = mock_graph.memory_log.load_entries()
         assert len(entries) == 1


@@ -16,6 +16,9 @@ DEFAULT_CONFIG = {
     "google_thinking_level": None,  # "high", "minimal", etc.
     "openai_reasoning_effort": None,  # "medium", "high", "low"
     "anthropic_effort": None,  # "high", "medium", "low"
+    # Checkpoint/resume: when True, LangGraph saves state after each node
+    # so a crashed run can resume from the last successful step.
+    "checkpoint_enabled": False,
     # Output language for analyst reports and final decision
     # Internal agent debate stays in English for reasoning quality
     "output_language": "English",


@@ -0,0 +1,86 @@
+"""LangGraph checkpoint support for resumable analysis runs.
+
+Per-ticker SQLite databases so concurrent tickers don't contend.
+"""
+from __future__ import annotations
+
+import hashlib
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator
+
+from langgraph.checkpoint.sqlite import SqliteSaver
+
+
+def _db_path(data_dir: str | Path, ticker: str) -> Path:
+    """Return the SQLite checkpoint DB path for a ticker."""
+    p = Path(data_dir) / "checkpoints"
+    p.mkdir(parents=True, exist_ok=True)
+    return p / f"{ticker.upper()}.db"
+
+
+def thread_id(ticker: str, date: str) -> str:
+    """Deterministic thread ID for a ticker+date pair."""
+    return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
+
+
+@contextmanager
+def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
+    """Context manager yielding a SqliteSaver backed by a per-ticker DB."""
+    db = _db_path(data_dir, ticker)
+    conn = sqlite3.connect(str(db), check_same_thread=False)
+    try:
+        saver = SqliteSaver(conn)
+        saver.setup()
+        yield saver
+    finally:
+        conn.close()
+
+
+def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
+    """Check whether a resumable checkpoint exists for ticker+date."""
+    return checkpoint_step(data_dir, ticker, date) is not None
+
+
+def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
+    """Return the step number of the latest checkpoint, or None if none exists."""
+    db = _db_path(data_dir, ticker)
+    if not db.exists():
+        return None
+    tid = thread_id(ticker, date)
+    with get_checkpointer(data_dir, ticker) as saver:
+        config = {"configurable": {"thread_id": tid}}
+        cp = saver.get_tuple(config)
+    if cp is None:
+        return None
+    return cp.metadata.get("step")
+
+
+def clear_all_checkpoints(data_dir: str | Path) -> int:
+    """Remove all checkpoint DBs. Returns number of files deleted."""
+    cp_dir = Path(data_dir) / "checkpoints"
+    if not cp_dir.exists():
+        return 0
+    dbs = list(cp_dir.glob("*.db"))
+    for db in dbs:
+        db.unlink()
+    return len(dbs)
+
+
+def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
+    """Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
+    db = _db_path(data_dir, ticker)
+    if not db.exists():
+        return
+    tid = thread_id(ticker, date)
+    conn = sqlite3.connect(str(db))
+    try:
+        for table in ("writes", "checkpoints"):
+            conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
+        conn.commit()
+    except sqlite3.OperationalError:
+        # Tables not created yet — nothing to clear.
+        pass
+    finally:
+        conn.close()
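The per-thread deletion above can be illustrated with a toy table (the schema here is simplified to two columns for the sketch; the real LangGraph checkpoint tables carry more):

```python
import sqlite3

# Toy per-ticker checkpoint DB: two thread_ids share one file.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE checkpoints (thread_id TEXT, step INTEGER)")
conn.executemany(
    "INSERT INTO checkpoints VALUES (?, ?)",
    [("t1", 0), ("t1", 1), ("t2", 0)],
)
# Deleting by thread_id clears one ticker+date run, leaving the other intact.
conn.execute("DELETE FROM checkpoints WHERE thread_id = ?", ("t1",))
remaining = conn.execute("SELECT thread_id, step FROM checkpoints").fetchall()
print(remaining)  # [('t2', 0)]
```

This is why a single `<TICKER>.db` can safely hold checkpoints for multiple dates: each date maps to its own `thread_id`, and clearing one never touches the others.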


@@ -179,5 +179,4 @@ class GraphSetup:
         workflow.add_edge("Portfolio Manager", END)
-        # Compile and return
-        return workflow.compile()
+        return workflow


@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
     get_global_news
 )
+from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
 from .conditional_logic import ConditionalLogic
 from .setup import GraphSetup
 from .propagation import Propagator
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
         self.ticker = None
         self.log_states_dict = {}  # date to full state dict
-        # Set up the graph
-        self.graph = self.graph_setup.setup_graph(selected_analysts)
+        # Set up the graph: keep the workflow for recompilation with a checkpointer.
+        self.workflow = self.graph_setup.setup_graph(selected_analysts)
+        self.graph = self.workflow.compile()
+        self._checkpointer_ctx = None
     def _get_provider_kwargs(self) -> Dict[str, Any]:
         """Get provider-specific kwargs for LLM client creation."""
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
         self.memory_log.batch_update_with_outcomes(updates)
     def propagate(self, company_name, trade_date):
-        """Run the trading agents graph for a company on a specific date."""
+        """Run the trading agents graph for a company on a specific date.
+
+        When ``checkpoint_enabled`` is set in config, the graph is recompiled
+        with a per-ticker SqliteSaver so a crashed run can resume from the last
+        successful node on a subsequent invocation with the same ticker+date.
+        """
         self.ticker = company_name
-        # Resolve any pending log entries for this ticker before the pipeline runs.
+        # Resolve any pending memory-log entries for this ticker before the pipeline runs.
+        # This adds the outcome + reflection from the previous run at zero latency cost.
         self._resolve_pending_entries(company_name)
-        # Initialize state — inject memory log context for PM
+        # Recompile with a checkpointer if the user opted in.
+        if self.config.get("checkpoint_enabled"):
+            self._checkpointer_ctx = get_checkpointer(
+                self.config["data_cache_dir"], company_name
+            )
+            saver = self._checkpointer_ctx.__enter__()
+            self.graph = self.workflow.compile(checkpointer=saver)
+            step = checkpoint_step(
+                self.config["data_cache_dir"], company_name, str(trade_date)
+            )
+            if step is not None:
+                logger.info(
+                    "Resuming from step %d for %s on %s", step, company_name, trade_date
+                )
+            else:
+                logger.info("Starting fresh for %s on %s", company_name, trade_date)
+        try:
+            return self._run_graph(company_name, trade_date)
+        finally:
+            if self._checkpointer_ctx is not None:
+                self._checkpointer_ctx.__exit__(None, None, None)
+                self._checkpointer_ctx = None
+                self.graph = self.workflow.compile()
+
+    def _run_graph(self, company_name, trade_date):
+        """Execute the graph and write the resulting state to disk and memory log."""
+        # Initialize state — inject memory log context for PM.
         past_context = self.memory_log.get_past_context(company_name)
         init_agent_state = self.propagator.create_initial_state(
             company_name, trade_date, past_context=past_context
         )
         args = self.propagator.get_graph_args()
+        # Inject thread_id so same ticker+date resumes, different date starts fresh.
+        if self.config.get("checkpoint_enabled"):
+            tid = thread_id(company_name, str(trade_date))
+            args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
         if self.debug:
-            # Debug mode with tracing
             trace = []
             for chunk in self.graph.stream(init_agent_state, **args):
                 if len(chunk["messages"]) == 0:
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
                 else:
                     chunk["messages"][-1].pretty_print()
                 trace.append(chunk)
             final_state = trace[-1]
         else:
-            # Standard mode without tracing
             final_state = self.graph.invoke(init_agent_state, **args)
-        # Store current state for reflection
+        # Store current state for reflection.
         self.curr_state = final_state
-        # Log state
+        # Log state to disk.
         self._log_state(trade_date, final_state)
-        # Store decision for deferred reflection.
+        # Store decision for deferred reflection on the next same-ticker run.
         self.memory_log.store_decision(
             ticker=company_name,
             trade_date=trade_date,
             final_trade_decision=final_state["final_trade_decision"],
         )
-        # Return decision and processed signal
+        # Clear checkpoint on successful completion to avoid stale state.
+        if self.config.get("checkpoint_enabled"):
+            clear_checkpoint(
+                self.config["data_cache_dir"], company_name, str(trade_date)
+            )
         return final_state, self.process_signal(final_state["final_trade_decision"])
     def _log_state(self, trade_date, final_state):