"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" import tempfile import unittest from typing import TypedDict from langgraph.graph import END, StateGraph from tradingagents.graph.checkpointer import ( checkpoint_step, clear_checkpoint, get_checkpointer, has_checkpoint, thread_id, ) # Mutable flag to simulate crash on first run _should_crash = False class _SimpleState(TypedDict): count: int def _node_a(state: _SimpleState) -> dict: return {"count": state["count"] + 1} def _node_b(state: _SimpleState) -> dict: if _should_crash: raise RuntimeError("simulated mid-analysis crash") return {"count": state["count"] + 10} def _build_graph() -> StateGraph: builder = StateGraph(_SimpleState) builder.add_node("analyst", _node_a) builder.add_node("trader", _node_b) builder.set_entry_point("analyst") builder.add_edge("analyst", "trader") builder.add_edge("trader", END) return builder class TestCheckpointResume(unittest.TestCase): def setUp(self): self.tmpdir = tempfile.mkdtemp() self.ticker = "TEST" self.date = "2026-04-20" def test_crash_and_resume(self): """Crash at 'trader' node, then resume from checkpoint.""" global _should_crash builder = _build_graph() tid = thread_id(self.ticker, self.date) cfg = {"configurable": {"thread_id": tid}} # Run 1: crash at trader node _should_crash = True with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) with self.assertRaises(RuntimeError): graph.invoke({"count": 0}, config=cfg) # Checkpoint should exist at step 1 (analyst completed) self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) step = checkpoint_step(self.tmpdir, self.ticker, self.date) self.assertEqual(step, 1) # Run 2: resume — trader succeeds this time _should_crash = False with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) result = graph.invoke(None, config=cfg) # analyst added 1, trader added 10 → 11 self.assertEqual(result["count"], 11) def test_clear_checkpoint_allows_fresh_start(self): """After clearing, the graph starts from scratch.""" global _should_crash builder = _build_graph() tid = thread_id(self.ticker, self.date) cfg = {"configurable": {"thread_id": tid}} # Create a checkpoint by crashing _should_crash = True with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) with self.assertRaises(RuntimeError): graph.invoke({"count": 0}, config=cfg) self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) # Clear it clear_checkpoint(self.tmpdir, self.ticker, self.date) self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date)) # Fresh run succeeds from scratch _should_crash = False with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) result = graph.invoke({"count": 0}, config=cfg) self.assertEqual(result["count"], 11) def test_different_date_starts_fresh(self): """A different date must NOT resume from an existing checkpoint.""" global _should_crash builder = _build_graph() date2 = "2026-04-21" # Run with date1 — crash to leave a checkpoint _should_crash = True tid1 = thread_id(self.ticker, self.date) with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) with self.assertRaises(RuntimeError): graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}}) self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) # date2 should have no checkpoint self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2)) # Run with date2 — should start fresh and succeed _should_crash = False tid2 = thread_id(self.ticker, date2) self.assertNotEqual(tid1, tid2) with get_checkpointer(self.tmpdir, self.ticker) as saver: graph = builder.compile(checkpointer=saver) result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}}) # Fresh run: analyst +1, trader +10 = 11 self.assertEqual(result["count"], 11) # Original date checkpoint still exists (untouched) self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) if __name__ == "__main__": unittest.main()