diff --git a/tests/test_analyst_execution.py b/tests/test_analyst_execution.py index 7d4e851a9..896c9e21e 100644 --- a/tests/test_analyst_execution.py +++ b/tests/test_analyst_execution.py @@ -34,6 +34,17 @@ class AnalystExecutionPlanTests(unittest.TestCase): "Fundamentals Analyst", ) + def test_social_key_displays_as_sentiment_analyst(self): + # The wire key stays "social" for saved-config back-compat, but the + # user-visible agent_node label must match the v0.2.5 rename so the + # wall-time summary and any future consumer of agent_node says + # "Sentiment Analyst" rather than the legacy "Social Analyst". + plan = build_analyst_execution_plan(["social"]) + spec = plan.specs[0] + self.assertEqual(spec.key, "social") + self.assertEqual(spec.agent_node, "Sentiment Analyst") + self.assertEqual(spec.report_key, "sentiment_report") + class AnalystWallTimeTrackerTests(unittest.TestCase): def test_records_wall_time_when_analyst_completes(self): diff --git a/tradingagents/graph/analyst_execution.py b/tradingagents/graph/analyst_execution.py index 080668f16..14c375338 100644 --- a/tradingagents/graph/analyst_execution.py +++ b/tradingagents/graph/analyst_execution.py @@ -27,9 +27,13 @@ ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = { report_key="market_report", ), "social": AnalystNodeSpec( + # Wire key stays "social" for saved-config back-compat; the + # user-facing label is "Sentiment Analyst" to match the rename + # that landed in v0.2.5 (sentiment_analyst now ingests news + + # StockTwits + Reddit, not just social media). key="social", - agent_node="Social Analyst", - clear_node="Msg Clear Social", + agent_node="Sentiment Analyst", + clear_node="Msg Clear Sentiment", tool_node="tools_social", report_key="sentiment_report", ), diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index 483717935..b03273564 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -20,12 +20,18 @@ class ConditionalLogic: return "Msg Clear Market" def should_continue_social(self, state: AgentState): - """Determine if social media analysis should continue.""" + """Determine if sentiment-analyst tool round should continue. + + Method name keeps the legacy ``social`` suffix to match the + ``AnalystType.SOCIAL = "social"`` wire value (saved-config + back-compat); the returned ``clear_node`` label uses the v0.2.5 + rename so it matches the node registered by the execution plan. + """ messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: return "tools_social" - return "Msg Clear Social" + return "Msg Clear Sentiment" def should_continue_news(self, state: AgentState): """Determine if news analysis should continue.""" diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 5910a548d..e3d1f125b 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -292,11 +292,14 @@ class TradingAgentsGraph: if updates: self.memory_log.batch_update_with_outcomes(updates) - def propagate(self, company_name, trade_date): + def propagate(self, company_name, trade_date, asset_type: str = "stock"): """Run the trading agents graph for a company on a specific date. - When ``checkpoint_enabled`` is set in config, the graph is recompiled - with a per-ticker SqliteSaver so a crashed run can resume from the last + ``asset_type`` selects between the stock pipeline (default) and the + crypto pipeline (``"crypto"``) shipped in #567 — the CLI auto-detects + from the ticker; programmatic callers pass it explicitly. When + ``checkpoint_enabled`` is set in config, the graph is recompiled with + a per-ticker SqliteSaver so a crashed run can resume from the last successful node on a subsequent invocation with the same ticker+date. """ self.ticker = company_name @@ -323,19 +326,19 @@ class TradingAgentsGraph: logger.info("Starting fresh for %s on %s", company_name, trade_date) try: - return self._run_graph(company_name, trade_date) + return self._run_graph(company_name, trade_date, asset_type=asset_type) finally: if self._checkpointer_ctx is not None: self._checkpointer_ctx.__exit__(None, None, None) self._checkpointer_ctx = None self.graph = self.workflow.compile() - def _run_graph(self, company_name, trade_date): + def _run_graph(self, company_name, trade_date, asset_type: str = "stock"): """Execute the graph and write the resulting state to disk and memory log.""" # Initialize state — inject memory log context for PM. past_context = self.memory_log.get_past_context(company_name) init_agent_state = self.propagator.create_initial_state( - company_name, trade_date, past_context=past_context + company_name, trade_date, asset_type=asset_type, past_context=past_context ) args = self.propagator.get_graph_args()