mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-01 14:33:10 +03:00
Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay for every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node.

How to use:
- CLI: `tradingagents analyze --checkpoint`
- CLI: `tradingagents analyze --clear-checkpoints`
- Python: `config["checkpoint_enabled"] = True`

Lifecycle:
- `propagate()` recompiles the graph with a `SqliteSaver` when enabled and injects a deterministic `thread_id` derived from ticker + date, so the same ticker + date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a `try`/`finally`, so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override via `TRADINGAGENTS_CACHE_DIR`).

The checkpointer module is new (`tradingagents/graph/checkpointer.py`), and `GraphSetup` now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds the `langgraph-checkpoint-sqlite>=2.0.0` dependency. Three new tests verify the crash/resume cycle and that a different date starts fresh.
148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
|
|
|
import sqlite3
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from typing import TypedDict
|
|
|
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
from langgraph.graph import END, StateGraph
|
|
|
|
from tradingagents.graph.checkpointer import (
|
|
checkpoint_step,
|
|
clear_checkpoint,
|
|
get_checkpointer,
|
|
has_checkpoint,
|
|
thread_id,
|
|
)
|
|
|
|
# Mutable flag to simulate crash on first run
|
|
_should_crash = False
|
|
|
|
|
|
class _SimpleState(TypedDict):
|
|
count: int
|
|
|
|
|
|
def _node_a(state: _SimpleState) -> dict:
    """Analyst stand-in: return a state update with ``count`` bumped by one."""
    bumped = state["count"] + 1
    return dict(count=bumped)
|
|
|
|
|
|
def _node_b(state: _SimpleState) -> dict:
    """Trader stand-in: add ten to ``count``, unless crash mode is on.

    Raises:
        RuntimeError: when the module-level ``_should_crash`` flag is set,
            to simulate a failure after the analyst node has completed.
    """
    if not _should_crash:
        return dict(count=state["count"] + 10)
    raise RuntimeError("simulated mid-analysis crash")
|
|
|
|
|
|
def _build_graph() -> StateGraph:
    """Return the uncompiled two-node graph: analyst -> trader -> END.

    The builder is returned uncompiled so each test can compile it with its
    own checkpointer.
    """
    workflow = StateGraph(_SimpleState)
    for node_name, node_fn in (("analyst", _node_a), ("trader", _node_b)):
        workflow.add_node(node_name, node_fn)
    workflow.set_entry_point("analyst")
    workflow.add_edge("analyst", "trader")
    workflow.add_edge("trader", END)
    return workflow
|
|
|
|
|
|
class TestCheckpointResume(unittest.TestCase):
    """Exercise the SQLite checkpoint crash/resume cycle on a toy graph.

    Each test compiles the two-node graph from ``_build_graph`` with a saver
    from ``get_checkpointer`` rooted in a per-test temporary directory. It
    toggles the module-level ``_should_crash`` flag to make the ``trader``
    node raise on demand.
    """

    def setUp(self):
        # TemporaryDirectory instead of a bare mkdtemp(): the directory (and
        # the per-ticker SQLite files inside it) is removed after each test
        # rather than leaking into the system temp dir. addCleanup callbacks
        # run even when the test body fails.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmpdir = tmp.name
        self.ticker = "TEST"
        self.date = "2026-04-20"
        # A test that fails mid-way must not leave crash mode switched on.
        self.addCleanup(self._set_crash, False)

    # -- helpers ---------------------------------------------------------

    @staticmethod
    def _set_crash(value: bool) -> None:
        """Set the module-level flag that makes the ``trader`` node raise."""
        global _should_crash
        _should_crash = value

    @staticmethod
    def _config(tid: str) -> dict:
        """Return the invoke config that pins the given LangGraph thread id."""
        return {"configurable": {"thread_id": tid}}

    def _crash_run(self, builder, cfg: dict) -> None:
        """Start a run from ``{"count": 0}`` that raises at ``trader``."""
        self._set_crash(True)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)

    def _ok_run(self, builder, cfg: dict, inputs):
        """Invoke with crash mode off and return the final graph state.

        The tests pass ``None`` to resume a thread from its checkpoint, or an
        initial state dict for a fresh run.
        """
        self._set_crash(False)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            return graph.invoke(inputs, config=cfg)

    # -- tests -----------------------------------------------------------

    def test_crash_and_resume(self):
        """Crash at 'trader' node, then resume from checkpoint."""
        builder = _build_graph()
        cfg = self._config(thread_id(self.ticker, self.date))

        # Run 1: crash at trader node
        self._crash_run(builder, cfg)

        # Checkpoint should exist at step 1 (analyst completed)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
        self.assertEqual(
            checkpoint_step(self.tmpdir, self.ticker, self.date), 1
        )

        # Run 2: resume (no new input) — trader succeeds this time
        result = self._ok_run(builder, cfg, None)

        # analyst added 1, trader added 10 → 11
        self.assertEqual(result["count"], 11)

    def test_clear_checkpoint_allows_fresh_start(self):
        """After clearing, the graph starts from scratch."""
        builder = _build_graph()
        cfg = self._config(thread_id(self.ticker, self.date))

        # Create a checkpoint by crashing
        self._crash_run(builder, cfg)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Clear it
        clear_checkpoint(self.tmpdir, self.ticker, self.date)
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Fresh run succeeds from scratch
        result = self._ok_run(builder, cfg, {"count": 0})
        self.assertEqual(result["count"], 11)

    def test_different_date_starts_fresh(self):
        """A different date must NOT resume from an existing checkpoint."""
        builder = _build_graph()
        date2 = "2026-04-21"

        # Run with date1 — crash to leave a checkpoint
        tid1 = thread_id(self.ticker, self.date)
        self._crash_run(builder, self._config(tid1))
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # date2 should have no checkpoint
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))

        # Run with date2 — distinct thread id, should start fresh and succeed
        tid2 = thread_id(self.ticker, date2)
        self.assertNotEqual(tid1, tid2)
        result = self._ok_run(builder, self._config(tid2), {"count": 0})

        # Fresh run: analyst +1, trader +10 = 11
        self.assertEqual(result["count"], 11)

        # Original date checkpoint still exists (untouched)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()
|