From 4cbd4b086fd94074d79987c8ca31daec8d33902c Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sat, 25 Apr 2026 08:39:27 +0000 Subject: [PATCH] feat: add LangGraph checkpoint resume for crash recovery (#594) Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node. How to use: - CLI: tradingagents analyze --checkpoint - CLI: tradingagents analyze --clear-checkpoints - Python: config["checkpoint_enabled"] = True Lifecycle: - propagate() recompiles the graph with a SqliteSaver when enabled and injects a deterministic thread_id derived from ticker+date so the same ticker+date resumes while a different date starts fresh. - On successful completion the per-thread checkpoint rows are cleared. - The context manager is closed in a try/finally so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode. Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db (override via TRADINGAGENTS_CACHE_DIR). The checkpointer module is new (tradingagents/graph/checkpointer.py) and the GraphSetup now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify the crash/resume cycle and that a different date starts fresh. 
--- README.md | 32 +++++- cli/main.py | 22 +++- pyproject.toml | 1 + tests/test_checkpoint_resume.py | 147 +++++++++++++++++++++++++++ tests/test_memory_log.py | 9 +- tradingagents/default_config.py | 3 + tradingagents/graph/checkpointer.py | 86 ++++++++++++++++ tradingagents/graph/setup.py | 3 +- tradingagents/graph/trading_graph.py | 67 +++++++++--- 9 files changed, 349 insertions(+), 21 deletions(-) create mode 100644 tests/test_checkpoint_resume.py create mode 100644 tradingagents/graph/checkpointer.py diff --git a/README.md b/README.md index 97cbde48..6c8f644e 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama. +We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Qwen (Alibaba DashScope), GLM (Zhipu), OpenRouter, Ollama for local models, and Azure OpenAI for enterprise. ### Python Usage @@ -207,7 +207,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama +config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, qwen, glm, openrouter, ollama, azure config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks config["max_debate_rounds"] = 2 @@ -219,6 +219,34 @@ print(decision) See `tradingagents/default_config.py` for all configuration options. +## Persistence and Recovery + +TradingAgents persists two kinds of state across runs. + +### Decision log + +The decision log is always on. 
Each completed run appends its decision to `~/.tradingagents/memory/trading_memory.md`. On the next run for the same ticker, TradingAgents fetches the realised return (raw and alpha vs SPY), generates a one-paragraph reflection, and injects the most recent same-ticker decisions plus recent cross-ticker lessons into the Portfolio Manager prompt, so each analysis carries forward what worked and what didn't. + +Override the path with `TRADINGAGENTS_MEMORY_LOG_PATH`. + +### Checkpoint resume + +Checkpoint resume is opt-in via `--checkpoint`. When enabled, LangGraph saves state after each node so a crashed or interrupted run resumes from the last successful step instead of starting over. On a resume run you will see `Resuming from step N for <TICKER> on <DATE>` in the logs; on a new run you will see `Starting fresh`. Checkpoints are cleared automatically on successful completion. + +Per-ticker SQLite databases live at `~/.tradingagents/cache/checkpoints/<TICKER>.db` (override the base with `TRADINGAGENTS_CACHE_DIR`). Use `--clear-checkpoints` to reset all of them before a run. + +```bash +tradingagents analyze --checkpoint # enable for this run +tradingagents analyze --clear-checkpoints # reset before running +``` + +```python +config = DEFAULT_CONFIG.copy() +config["checkpoint_enabled"] = True +ta = TradingAgentsGraph(config=config) +_, decision = ta.propagate("NVDA", "2026-01-15") +``` + ## Contributing We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). diff --git a/cli/main.py b/cli/main.py index 6e838fc8..534f5037 100644 --- a/cli/main.py +++ b/cli/main.py @@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str: return result[:max_length - 3] + "..." 
return result -def run_analysis(): +def run_analysis(checkpoint: bool = False): # First get all user selections selections = get_user_selections() @@ -943,6 +943,7 @@ def run_analysis(): config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["anthropic_effort"] = selections.get("anthropic_effort") config["output_language"] = selections.get("output_language", "English") + config["checkpoint_enabled"] = checkpoint # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() @@ -1197,8 +1198,23 @@ def run_analysis(): @app.command() -def analyze(): - run_analysis() +def analyze( + checkpoint: bool = typer.Option( + False, + "--checkpoint", + help="Enable checkpoint/resume: save state after each node so a crashed run can resume.", + ), + clear_checkpoints: bool = typer.Option( + False, + "--clear-checkpoints", + help="Delete all saved checkpoints before running (force fresh start).", + ), +): + if clear_checkpoints: + from tradingagents.graph.checkpointer import clear_all_checkpoints + n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"]) + console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]") + run_analysis(checkpoint=checkpoint) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index b3dbc6fe..b569504e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "langchain-google-genai>=4.0.0", "langchain-openai>=0.3.23", "langgraph>=0.4.8", + "langgraph-checkpoint-sqlite>=2.0.0", "pandas>=2.3.0", "parsel>=1.10.0", "pytz>=2025.2", diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py new file mode 100644 index 00000000..6f2692bd --- /dev/null +++ b/tests/test_checkpoint_resume.py @@ -0,0 +1,147 @@ +"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" + +import sqlite3 +import tempfile +import unittest +from pathlib import Path +from typing import TypedDict + +from langgraph.checkpoint.sqlite 
import SqliteSaver +from langgraph.graph import END, StateGraph + +from tradingagents.graph.checkpointer import ( + checkpoint_step, + clear_checkpoint, + get_checkpointer, + has_checkpoint, + thread_id, +) + +# Mutable flag to simulate crash on first run +_should_crash = False + + +class _SimpleState(TypedDict): + count: int + + +def _node_a(state: _SimpleState) -> dict: + return {"count": state["count"] + 1} + + +def _node_b(state: _SimpleState) -> dict: + if _should_crash: + raise RuntimeError("simulated mid-analysis crash") + return {"count": state["count"] + 10} + + +def _build_graph() -> StateGraph: + builder = StateGraph(_SimpleState) + builder.add_node("analyst", _node_a) + builder.add_node("trader", _node_b) + builder.set_entry_point("analyst") + builder.add_edge("analyst", "trader") + builder.add_edge("trader", END) + return builder + + +class TestCheckpointResume(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.ticker = "TEST" + self.date = "2026-04-20" + + def test_crash_and_resume(self): + """Crash at 'trader' node, then resume from checkpoint.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Run 1: crash at trader node + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + # Checkpoint should exist at step 1 (analyst completed) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + step = checkpoint_step(self.tmpdir, self.ticker, self.date) + self.assertEqual(step, 1) + + # Run 2: resume — trader succeeds this time + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke(None, config=cfg) + + # analyst added 1, trader added 10 → 11 + 
self.assertEqual(result["count"], 11) + + def test_clear_checkpoint_allows_fresh_start(self): + """After clearing, the graph starts from scratch.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Create a checkpoint by crashing + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Clear it + clear_checkpoint(self.tmpdir, self.ticker, self.date) + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Fresh run succeeds from scratch + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config=cfg) + + self.assertEqual(result["count"], 11) + + + def test_different_date_starts_fresh(self): + """A different date must NOT resume from an existing checkpoint.""" + global _should_crash + builder = _build_graph() + date2 = "2026-04-21" + + # Run with date1 — crash to leave a checkpoint + _should_crash = True + tid1 = thread_id(self.ticker, self.date) + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}}) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # date2 should have no checkpoint + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2)) + + # Run with date2 — should start fresh and succeed + _should_crash = False + tid2 = thread_id(self.ticker, date2) + self.assertNotEqual(tid1, tid2) + + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = 
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}}) + + # Fresh run: analyst +1, trader +10 = 11 + self.assertEqual(result["count"], 11) + + # Original date checkpoint still exists (untouched) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py index 915bc3b1..ccd1ca7e 100644 --- a/tests/test_memory_log.py +++ b/tests/test_memory_log.py @@ -629,7 +629,9 @@ class TestLegacyRemoval: create_portfolio_manager(mock_llm, memory=MagicMock()) def test_full_pipeline_no_regression(self, tmp_path): - """propagate() completes without AttributeError after legacy cleanup.""" + """propagate() completes and stores the decision after the redesign.""" + import functools + fake_state = { "final_trade_decision": "Rating: Buy\nBuy NVDA.", "company_of_interest": "NVDA", @@ -660,6 +662,11 @@ class TestLegacyRemoval: mock_graph.propagator.create_initial_state.return_value = fake_state mock_graph.propagator.get_graph_args.return_value = {} mock_graph.signal_processor.process_signal.return_value = "Buy" + # Bind the real _run_graph so propagate's call to self._run_graph executes + # the actual write path instead of the auto-MagicMock. + mock_graph._run_graph = functools.partial( + TradingAgentsGraph._run_graph, mock_graph + ) TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10") entries = mock_graph.memory_log.load_entries() assert len(entries) == 1 diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 19dbe1c7..89b51765 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -16,6 +16,9 @@ DEFAULT_CONFIG = { "google_thinking_level": None, # "high", "minimal", etc. 
"openai_reasoning_effort": None, # "medium", "high", "low" "anthropic_effort": None, # "high", "medium", "low" + # Checkpoint/resume: when True, LangGraph saves state after each node + # so a crashed run can resume from the last successful step. + "checkpoint_enabled": False, # Output language for analyst reports and final decision # Internal agent debate stays in English for reasoning quality "output_language": "English", diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py new file mode 100644 index 00000000..7a73ee44 --- /dev/null +++ b/tradingagents/graph/checkpointer.py @@ -0,0 +1,86 @@ +"""LangGraph checkpoint support for resumable analysis runs. + +Per-ticker SQLite databases so concurrent tickers don't contend. +""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Generator + +from langgraph.checkpoint.sqlite import SqliteSaver + + +def _db_path(data_dir: str | Path, ticker: str) -> Path: + """Return the SQLite checkpoint DB path for a ticker.""" + p = Path(data_dir) / "checkpoints" + p.mkdir(parents=True, exist_ok=True) + return p / f"{ticker.upper()}.db" + + +def thread_id(ticker: str, date: str) -> str: + """Deterministic thread ID for a ticker+date pair.""" + return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16] + + +@contextmanager +def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]: + """Context manager yielding a SqliteSaver backed by a per-ticker DB.""" + db = _db_path(data_dir, ticker) + conn = sqlite3.connect(str(db), check_same_thread=False) + try: + saver = SqliteSaver(conn) + saver.setup() + yield saver + finally: + conn.close() + + +def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: + """Check whether a resumable checkpoint exists for ticker+date.""" + return checkpoint_step(data_dir, ticker, date) is not None + + 
+def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None: + """Return the step number of the latest checkpoint, or None if none exists.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return None + tid = thread_id(ticker, date) + with get_checkpointer(data_dir, ticker) as saver: + config = {"configurable": {"thread_id": tid}} + cp = saver.get_tuple(config) + if cp is None: + return None + return cp.metadata.get("step") + + +def clear_all_checkpoints(data_dir: str | Path) -> int: + """Remove all checkpoint DBs. Returns number of files deleted.""" + cp_dir = Path(data_dir) / "checkpoints" + if not cp_dir.exists(): + return 0 + dbs = list(cp_dir.glob("*.db")) + for db in dbs: + db.unlink() + return len(dbs) + + +def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None: + """Remove checkpoint for a specific ticker+date by deleting the thread's rows.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return + tid = thread_id(ticker, date) + conn = sqlite3.connect(str(db)) + try: + for table in ("writes", "checkpoints"): + conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,)) + conn.commit() + except sqlite3.OperationalError: + pass + finally: + conn.close() diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 1686fc5b..45d6bfd3 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -179,5 +179,4 @@ class GraphSetup: workflow.add_edge("Portfolio Manager", END) - # Compile and return - return workflow.compile() + return workflow diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 4f92b188..bd6f1fc5 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) +from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id from .conditional_logic import 
ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -123,8 +124,10 @@ class TradingAgentsGraph: self.ticker = None self.log_states_dict = {} # date to full state dict - # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) + # Set up the graph: keep the workflow for recompilation with a checkpointer. + self.workflow = self.graph_setup.setup_graph(selected_analysts) + self.graph = self.workflow.compile() + self._checkpointer_ctx = None def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" @@ -259,23 +262,58 @@ class TradingAgentsGraph: self.memory_log.batch_update_with_outcomes(updates) def propagate(self, company_name, trade_date): - """Run the trading agents graph for a company on a specific date.""" + """Run the trading agents graph for a company on a specific date. + When ``checkpoint_enabled`` is set in config, the graph is recompiled + with a per-ticker SqliteSaver so a crashed run can resume from the last + successful node on a subsequent invocation with the same ticker+date. + """ self.ticker = company_name - # Resolve any pending log entries for this ticker before the pipeline runs. - # This adds the outcome + reflection from the previous run at zero latency cost. + # Resolve any pending memory-log entries for this ticker before the pipeline runs. self._resolve_pending_entries(company_name) - # Initialize state — inject memory log context for PM + # Recompile with a checkpointer if the user opted in. 
+ if self.config.get("checkpoint_enabled"): + self._checkpointer_ctx = get_checkpointer( + self.config["data_cache_dir"], company_name + ) + saver = self._checkpointer_ctx.__enter__() + self.graph = self.workflow.compile(checkpointer=saver) + + step = checkpoint_step( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + if step is not None: + logger.info( + "Resuming from step %d for %s on %s", step, company_name, trade_date + ) + else: + logger.info("Starting fresh for %s on %s", company_name, trade_date) + + try: + return self._run_graph(company_name, trade_date) + finally: + if self._checkpointer_ctx is not None: + self._checkpointer_ctx.__exit__(None, None, None) + self._checkpointer_ctx = None + self.graph = self.workflow.compile() + + def _run_graph(self, company_name, trade_date): + """Execute the graph and write the resulting state to disk and memory log.""" + # Initialize state — inject memory log context for PM. past_context = self.memory_log.get_past_context(company_name) init_agent_state = self.propagator.create_initial_state( company_name, trade_date, past_context=past_context ) args = self.propagator.get_graph_args() + # Inject thread_id so same ticker+date resumes, different date starts fresh. + if self.config.get("checkpoint_enabled"): + tid = thread_id(company_name, str(trade_date)) + args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid + if self.debug: - # Debug mode with tracing trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: @@ -283,26 +321,30 @@ else: chunk["messages"][-1].pretty_print() trace.append(chunk) final_state = trace[-1] else: - # Standard mode without tracing final_state = self.graph.invoke(init_agent_state, **args) - # Store current state for reflection + # Store current state for reflection. self.curr_state = final_state - # Log state + # Log state to disk. 
self._log_state(trade_date, final_state) - # Store decision for deferred reflection. + # Store decision for deferred reflection on the next same-ticker run. self.memory_log.store_decision( ticker=company_name, trade_date=trade_date, final_trade_decision=final_state["final_trade_decision"], ) - # Return decision and processed signal + # Clear checkpoint on successful completion to avoid stale state. + if self.config.get("checkpoint_enabled"): + clear_checkpoint( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + return final_state, self.process_signal(final_state["final_trade_decision"]) def _log_state(self, trade_date, final_state):