feat: add LangGraph checkpoint resume for crash recovery (#594)

Long analyses can take many minutes, and a crash or interruption
previously forced users to re-run from scratch and pay for every LLM
call again. This adds an opt-in checkpoint layer backed by per-ticker
SQLite databases so the graph resumes from the last successful node.

How to use:
- CLI:    tradingagents analyze --checkpoint
- CLI:    tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and
  injects a deterministic thread_id derived from ticker+date so the
  same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never
  leaks the SQLite connection or leaves the graph in checkpoint mode.
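The deterministic thread_id described above can be sketched as follows
(a minimal illustration consistent with the new helper in
tradingagents/graph/checkpointer.py):

```python
import hashlib

def thread_id(ticker: str, date: str) -> str:
    # Same ticker+date always hashes to the same ID, so re-running
    # resumes the existing LangGraph thread; a different date yields
    # a different ID and therefore a fresh run.
    return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
```

Uppercasing the ticker makes "nvda" and "NVDA" share a thread.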

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db
(override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py)
and the GraphSetup now returns the uncompiled workflow so it can be
recompiled with a saver when needed.

Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify
the crash/resume cycle and that a different date starts fresh.
Author: Yijia-Xiao
Date:   2026-04-25 08:39:27 +00:00
Parent: ebd2e12e67
Commit: 4cbd4b086f
9 changed files with 349 additions and 21 deletions

View File

@@ -183,7 +183,7 @@ An interface will appear showing results as they load, letting you track the age
### Implementation Details
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise.
### Python Usage
@@ -207,7 +207,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
config["max_debate_rounds"] = 2
@@ -219,6 +219,34 @@ print(decision)
See `tradingagents/default_config.py` for all configuration options.
## Persistence and Recovery
TradingAgents persists two kinds of state across runs.
### Decision log
The decision log is always on. Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt. Each analysis therefore carries forward what worked and what didn't.
Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`.
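The realised-return bookkeeping reduces to simple arithmetic (a hedged
sketch of the computation described above; the function and parameter
names here are illustrative, not the project's API):

```python
def realised_returns(entry_price: float, exit_price: float,
                     spy_entry: float, spy_exit: float) -> tuple[float, float]:
    # Raw return of the position, and alpha relative to SPY over the
    # same holding window (both as fractions, e.g. 0.05 == 5%).
    raw = (exit_price - entry_price) / entry_price
    spy = (spy_exit - spy_entry) / spy_entry
    return raw, raw - spy
```

A position up 10% while SPY rose 5% yields a raw return of 0.10 and an
alpha of 0.05.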
### Checkpoint resume
Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <date>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion.
Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run.
```bash
tradingagents analyze --checkpoint # enable for this run
tradingagents analyze --clear-checkpoints # reset before running
```
```python
config = DEFAULT_CONFIG.copy()
config["checkpoint_enabled"] = True
ta = TradingAgentsGraph(config=config)
_, decision = ta.propagate("NVDA", "2026-01-15")
```
## Contributing
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).

View File

@@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
return result[:max_length - 3] + "..."
return result
def run_analysis():
def run_analysis(checkpoint: bool = False):
# First get all user selections
selections = get_user_selections()
@@ -943,6 +943,7 @@ def run_analysis():
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
config["anthropic_effort"] = selections.get("anthropic_effort")
config["output_language"] = selections.get("output_language", "English")
config["checkpoint_enabled"] = checkpoint
# Create stats callback handler for tracking LLM/tool calls
stats_handler = StatsCallbackHandler()
@@ -1197,8 +1198,23 @@ def run_analysis():
@app.command()
def analyze():
run_analysis()
def analyze(
checkpoint: bool = typer.Option(
False,
"--checkpoint",
help="Enable checkpoint/resume: save state after each node so a crashed run can resume.",
),
clear_checkpoints: bool = typer.Option(
False,
"--clear-checkpoints",
help="Delete all saved checkpoints before running (force fresh start).",
),
):
if clear_checkpoints:
from tradingagents.graph.checkpointer import clear_all_checkpoints
n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]")
run_analysis(checkpoint=checkpoint)
if __name__ == "__main__":

View File

@@ -16,6 +16,7 @@ dependencies = [
"langchain-google-genai>=4.0.0",
"langchain-openai>=0.3.23",
"langgraph>=0.4.8",
"langgraph-checkpoint-sqlite>=2.0.0",
"pandas>=2.3.0",
"parsel>=1.10.0",
"pytz>=2025.2",

View File

@@ -0,0 +1,147 @@
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from tradingagents.graph.checkpointer import (
checkpoint_step,
clear_checkpoint,
get_checkpointer,
has_checkpoint,
thread_id,
)
# Mutable flag to simulate crash on first run
_should_crash = False
class _SimpleState(TypedDict):
count: int
def _node_a(state: _SimpleState) -> dict:
return {"count": state["count"] + 1}
def _node_b(state: _SimpleState) -> dict:
if _should_crash:
raise RuntimeError("simulated mid-analysis crash")
return {"count": state["count"] + 10}
def _build_graph() -> StateGraph:
builder = StateGraph(_SimpleState)
builder.add_node("analyst", _node_a)
builder.add_node("trader", _node_b)
builder.set_entry_point("analyst")
builder.add_edge("analyst", "trader")
builder.add_edge("trader", END)
return builder
class TestCheckpointResume(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.ticker = "TEST"
self.date = "2026-04-20"
def test_crash_and_resume(self):
"""Crash at 'trader' node, then resume from checkpoint."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Run 1: crash at trader node
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
# Checkpoint should exist at step 1 (analyst completed)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
self.assertEqual(step, 1)
# Run 2: resume — trader succeeds this time
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke(None, config=cfg)
# analyst added 1, trader added 10 → 11
self.assertEqual(result["count"], 11)
def test_clear_checkpoint_allows_fresh_start(self):
"""After clearing, the graph starts from scratch."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Create a checkpoint by crashing
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Clear it
clear_checkpoint(self.tmpdir, self.ticker, self.date)
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Fresh run succeeds from scratch
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke({"count": 0}, config=cfg)
self.assertEqual(result["count"], 11)
def test_different_date_starts_fresh(self):
"""A different date must NOT resume from an existing checkpoint."""
global _should_crash
builder = _build_graph()
date2 = "2026-04-21"
# Run with date1 — crash to leave a checkpoint
_should_crash = True
tid1 = thread_id(self.ticker, self.date)
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# date2 should have no checkpoint
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
# Run with date2 — should start fresh and succeed
_should_crash = False
tid2 = thread_id(self.ticker, date2)
self.assertNotEqual(tid1, tid2)
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
# Fresh run: analyst +1, trader +10 = 11
self.assertEqual(result["count"], 11)
# Original date checkpoint still exists (untouched)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
if __name__ == "__main__":
unittest.main()

View File

@@ -629,7 +629,9 @@ class TestLegacyRemoval:
create_portfolio_manager(mock_llm, memory=MagicMock())
def test_full_pipeline_no_regression(self, tmp_path):
"""propagate() completes without AttributeError after legacy cleanup."""
"""propagate() completes and stores the decision after the redesign."""
import functools
fake_state = {
"final_trade_decision": "Rating: Buy\nBuy NVDA.",
"company_of_interest": "NVDA",
@@ -660,6 +662,11 @@ class TestLegacyRemoval:
mock_graph.propagator.create_initial_state.return_value = fake_state
mock_graph.propagator.get_graph_args.return_value = {}
mock_graph.signal_processor.process_signal.return_value = "Buy"
# Bind the real _run_graph so propagate's call to self._run_graph executes
# the actual write path instead of the auto-MagicMock.
mock_graph._run_graph = functools.partial(
TradingAgentsGraph._run_graph, mock_graph
)
TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
entries = mock_graph.memory_log.load_entries()
assert len(entries) == 1

View File

@@ -16,6 +16,9 @@ DEFAULT_CONFIG = {
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
"anthropic_effort": None, # "high", "medium", "low"
# Checkpoint/resume: when True, LangGraph saves state after each node
# so a crashed run can resume from the last successful step.
"checkpoint_enabled": False,
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",

View File

@@ -0,0 +1,86 @@
"""LangGraph checkpoint support for resumable analysis runs.
Uses per-ticker SQLite databases so concurrent tickers don't contend.
"""
from __future__ import annotations
import hashlib
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver
def _db_path(data_dir: str | Path, ticker: str) -> Path:
"""Return the SQLite checkpoint DB path for a ticker."""
p = Path(data_dir) / "checkpoints"
p.mkdir(parents=True, exist_ok=True)
return p / f"{ticker.upper()}.db"
def thread_id(ticker: str, date: str) -> str:
"""Deterministic thread ID for a ticker+date pair."""
return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16]
@contextmanager
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
"""Context manager yielding a SqliteSaver backed by a per-ticker DB."""
db = _db_path(data_dir, ticker)
conn = sqlite3.connect(str(db), check_same_thread=False)
try:
saver = SqliteSaver(conn)
saver.setup()
yield saver
finally:
conn.close()
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
"""Check whether a resumable checkpoint exists for ticker+date."""
return checkpoint_step(data_dir, ticker, date) is not None
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
"""Return the step number of the latest checkpoint, or None if none exists."""
db = _db_path(data_dir, ticker)
if not db.exists():
return None
tid = thread_id(ticker, date)
with get_checkpointer(data_dir, ticker) as saver:
config = {"configurable": {"thread_id": tid}}
cp = saver.get_tuple(config)
if cp is None:
return None
return cp.metadata.get("step")
def clear_all_checkpoints(data_dir: str | Path) -> int:
"""Remove all checkpoint DBs. Returns number of files deleted."""
cp_dir = Path(data_dir) / "checkpoints"
if not cp_dir.exists():
return 0
dbs = list(cp_dir.glob("*.db"))
for db in dbs:
db.unlink()
return len(dbs)
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
"""Remove checkpoint for a specific ticker+date by deleting the thread's rows."""
db = _db_path(data_dir, ticker)
if not db.exists():
return
tid = thread_id(ticker, date)
conn = sqlite3.connect(str(db))
try:
for table in ("writes", "checkpoints"):
conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
conn.commit()
except sqlite3.OperationalError:
pass
finally:
conn.close()

View File

@@ -179,5 +179,4 @@ class GraphSetup:
workflow.add_edge("Portfolio Manager", END)
# Compile and return
return workflow.compile()
return workflow

View File

@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
# Set up the graph: keep the workflow for recompilation with a checkpointer.
self.workflow = self.graph_setup.setup_graph(selected_analysts)
self.graph = self.workflow.compile()
self._checkpointer_ctx = None
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
self.memory_log.batch_update_with_outcomes(updates)
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date."""
"""Run the trading agents graph for a company on a specific date.
When ``checkpoint_enabled`` is set in config, the graph is recompiled
with a per-ticker SqliteSaver so a crashed run can resume from the last
successful node on a subsequent invocation with the same ticker+date.
"""
self.ticker = company_name
# Resolve any pending log entries for this ticker before the pipeline runs.
# This adds the outcome + reflection from the previous run at zero latency cost.
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
self._resolve_pending_entries(company_name)
# Initialize state — inject memory log context for PM
# Recompile with a checkpointer if the user opted in.
if self.config.get("checkpoint_enabled"):
self._checkpointer_ctx = get_checkpointer(
self.config["data_cache_dir"], company_name
)
saver = self._checkpointer_ctx.__enter__()
self.graph = self.workflow.compile(checkpointer=saver)
step = checkpoint_step(
self.config["data_cache_dir"], company_name, str(trade_date)
)
if step is not None:
logger.info(
"Resuming from step %d for %s on %s", step, company_name, trade_date
)
else:
logger.info("Starting fresh for %s on %s", company_name, trade_date)
try:
return self._run_graph(company_name, trade_date)
finally:
if self._checkpointer_ctx is not None:
self._checkpointer_ctx.__exit__(None, None, None)
self._checkpointer_ctx = None
self.graph = self.workflow.compile()
def _run_graph(self, company_name, trade_date):
"""Execute the graph and write the resulting state to disk and memory log."""
# Initialize state — inject memory log context for PM.
past_context = self.memory_log.get_past_context(company_name)
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date, past_context=past_context
)
args = self.propagator.get_graph_args()
# Inject thread_id so same ticker+date resumes, different date starts fresh.
if self.config.get("checkpoint_enabled"):
tid = thread_id(company_name, str(trade_date))
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
# Store current state for reflection
# Store current state for reflection.
self.curr_state = final_state
# Log state
# Log state to disk.
self._log_state(trade_date, final_state)
# Store decision for deferred reflection.
# Store decision for deferred reflection on the next same-ticker run.
self.memory_log.store_decision(
ticker=company_name,
trade_date=trade_date,
final_trade_decision=final_state["final_trade_decision"],
)
# Return decision and processed signal
# Clear checkpoint on successful completion to avoid stale state.
if self.config.get("checkpoint_enabled"):
clear_checkpoint(
self.config["data_cache_dir"], company_name, str(trade_date)
)
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):