feat: add LangGraph checkpoint resume for crash recovery (#594)

Long analyses can take many minutes; a crash or interruption forced users
to re-run from scratch and re-pay every LLM call.  This adds an opt-in
checkpoint layer backed by per-ticker SQLite databases so the graph
resumes from the last successful node.

How to use:
- CLI:    tradingagents analyze --checkpoint
- CLI:    tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and
  injects a deterministic thread_id derived from ticker+date so the
  same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never
  leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db
(override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py)
and the GraphSetup now returns the uncompiled workflow so it can be
recompiled with a saver when needed.

Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify
the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
Yijia-Xiao
2026-04-25 08:39:27 +00:00
parent ebd2e12e67
commit 4cbd4b086f
9 changed files with 349 additions and 21 deletions

View File

@@ -0,0 +1,147 @@
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from tradingagents.graph.checkpointer import (
checkpoint_step,
clear_checkpoint,
get_checkpointer,
has_checkpoint,
thread_id,
)
# Mutable flag to simulate crash on first run
_should_crash = False
class _SimpleState(TypedDict):
count: int
def _node_a(state: _SimpleState) -> dict:
return {"count": state["count"] + 1}
def _node_b(state: _SimpleState) -> dict:
if _should_crash:
raise RuntimeError("simulated mid-analysis crash")
return {"count": state["count"] + 10}
def _build_graph() -> StateGraph:
builder = StateGraph(_SimpleState)
builder.add_node("analyst", _node_a)
builder.add_node("trader", _node_b)
builder.set_entry_point("analyst")
builder.add_edge("analyst", "trader")
builder.add_edge("trader", END)
return builder
class TestCheckpointResume(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.ticker = "TEST"
self.date = "2026-04-20"
def test_crash_and_resume(self):
"""Crash at 'trader' node, then resume from checkpoint."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Run 1: crash at trader node
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
# Checkpoint should exist at step 1 (analyst completed)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
self.assertEqual(step, 1)
# Run 2: resume — trader succeeds this time
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke(None, config=cfg)
# analyst added 1, trader added 10 → 11
self.assertEqual(result["count"], 11)
def test_clear_checkpoint_allows_fresh_start(self):
"""After clearing, the graph starts from scratch."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Create a checkpoint by crashing
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Clear it
clear_checkpoint(self.tmpdir, self.ticker, self.date)
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Fresh run succeeds from scratch
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke({"count": 0}, config=cfg)
self.assertEqual(result["count"], 11)
def test_different_date_starts_fresh(self):
"""A different date must NOT resume from an existing checkpoint."""
global _should_crash
builder = _build_graph()
date2 = "2026-04-21"
# Run with date1 — crash to leave a checkpoint
_should_crash = True
tid1 = thread_id(self.ticker, self.date)
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# date2 should have no checkpoint
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
# Run with date2 — should start fresh and succeed
_should_crash = False
tid2 = thread_id(self.ticker, date2)
self.assertNotEqual(tid1, tid2)
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
# Fresh run: analyst +1, trader +10 = 11
self.assertEqual(result["count"], 11)
# Original date checkpoint still exists (untouched)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
if __name__ == "__main__":
unittest.main()

View File

@@ -629,7 +629,9 @@ class TestLegacyRemoval:
create_portfolio_manager(mock_llm, memory=MagicMock())
def test_full_pipeline_no_regression(self, tmp_path):
"""propagate() completes without AttributeError after legacy cleanup."""
"""propagate() completes and stores the decision after the redesign."""
import functools
fake_state = {
"final_trade_decision": "Rating: Buy\nBuy NVDA.",
"company_of_interest": "NVDA",
@@ -660,6 +662,11 @@ class TestLegacyRemoval:
mock_graph.propagator.create_initial_state.return_value = fake_state
mock_graph.propagator.get_graph_args.return_value = {}
mock_graph.signal_processor.process_signal.return_value = "Buy"
# Bind the real _run_graph so propagate's call to self._run_graph executes
# the actual write path instead of the auto-MagicMock.
mock_graph._run_graph = functools.partial(
TradingAgentsGraph._run_graph, mock_graph
)
TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
entries = mock_graph.memory_log.load_entries()
assert len(entries) == 1