mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
- analyst_execution.py: rename "Social Analyst" / "Msg Clear Social" to "Sentiment Analyst" / "Msg Clear Sentiment" to match v0.2.5. - conditional_logic.should_continue_social returns the renamed route. - TradingAgentsGraph.propagate accepts asset_type and threads through to Propagator.create_initial_state. - Regression test on the Sentiment Analyst label. Verified end-to-end (NVDA stock + BTC-USD crypto) on gpt-5.4-mini.
This commit is contained in:
@@ -34,6 +34,17 @@ class AnalystExecutionPlanTests(unittest.TestCase):
|
|||||||
"Fundamentals Analyst",
|
"Fundamentals Analyst",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_social_key_displays_as_sentiment_analyst(self):
|
||||||
|
# The wire key stays "social" for saved-config back-compat, but the
|
||||||
|
# user-visible agent_node label must match the v0.2.5 rename so the
|
||||||
|
# wall-time summary and any future consumer of agent_node says
|
||||||
|
# "Sentiment Analyst" rather than the legacy "Social Analyst".
|
||||||
|
plan = build_analyst_execution_plan(["social"])
|
||||||
|
spec = plan.specs[0]
|
||||||
|
self.assertEqual(spec.key, "social")
|
||||||
|
self.assertEqual(spec.agent_node, "Sentiment Analyst")
|
||||||
|
self.assertEqual(spec.report_key, "sentiment_report")
|
||||||
|
|
||||||
|
|
||||||
class AnalystWallTimeTrackerTests(unittest.TestCase):
|
class AnalystWallTimeTrackerTests(unittest.TestCase):
|
||||||
def test_records_wall_time_when_analyst_completes(self):
|
def test_records_wall_time_when_analyst_completes(self):
|
||||||
|
|||||||
@@ -27,9 +27,13 @@ ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = {
|
|||||||
report_key="market_report",
|
report_key="market_report",
|
||||||
),
|
),
|
||||||
"social": AnalystNodeSpec(
|
"social": AnalystNodeSpec(
|
||||||
|
# Wire key stays "social" for saved-config back-compat; the
|
||||||
|
# user-facing label is "Sentiment Analyst" to match the rename
|
||||||
|
# that landed in v0.2.5 (sentiment_analyst now ingests news +
|
||||||
|
# StockTwits + Reddit, not just social media).
|
||||||
key="social",
|
key="social",
|
||||||
agent_node="Social Analyst",
|
agent_node="Sentiment Analyst",
|
||||||
clear_node="Msg Clear Social",
|
clear_node="Msg Clear Sentiment",
|
||||||
tool_node="tools_social",
|
tool_node="tools_social",
|
||||||
report_key="sentiment_report",
|
report_key="sentiment_report",
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -20,12 +20,18 @@ class ConditionalLogic:
|
|||||||
return "Msg Clear Market"
|
return "Msg Clear Market"
|
||||||
|
|
||||||
def should_continue_social(self, state: AgentState):
|
def should_continue_social(self, state: AgentState):
|
||||||
"""Determine if social media analysis should continue."""
|
"""Determine if sentiment-analyst tool round should continue.
|
||||||
|
|
||||||
|
Method name keeps the legacy ``social`` suffix to match the
|
||||||
|
``AnalystType.SOCIAL = "social"`` wire value (saved-config
|
||||||
|
back-compat); the returned ``clear_node`` label uses the v0.2.5
|
||||||
|
rename so it matches the node registered by the execution plan.
|
||||||
|
"""
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
return "tools_social"
|
return "tools_social"
|
||||||
return "Msg Clear Social"
|
return "Msg Clear Sentiment"
|
||||||
|
|
||||||
def should_continue_news(self, state: AgentState):
|
def should_continue_news(self, state: AgentState):
|
||||||
"""Determine if news analysis should continue."""
|
"""Determine if news analysis should continue."""
|
||||||
|
|||||||
@@ -292,11 +292,14 @@ class TradingAgentsGraph:
|
|||||||
if updates:
|
if updates:
|
||||||
self.memory_log.batch_update_with_outcomes(updates)
|
self.memory_log.batch_update_with_outcomes(updates)
|
||||||
|
|
||||||
def propagate(self, company_name, trade_date):
|
def propagate(self, company_name, trade_date, asset_type: str = "stock"):
|
||||||
"""Run the trading agents graph for a company on a specific date.
|
"""Run the trading agents graph for a company on a specific date.
|
||||||
|
|
||||||
When ``checkpoint_enabled`` is set in config, the graph is recompiled
|
``asset_type`` selects between the stock pipeline (default) and the
|
||||||
with a per-ticker SqliteSaver so a crashed run can resume from the last
|
crypto pipeline (``"crypto"``) shipped in #567 — the CLI auto-detects
|
||||||
|
from the ticker; programmatic callers pass it explicitly. When
|
||||||
|
``checkpoint_enabled`` is set in config, the graph is recompiled with
|
||||||
|
a per-ticker SqliteSaver so a crashed run can resume from the last
|
||||||
successful node on a subsequent invocation with the same ticker+date.
|
successful node on a subsequent invocation with the same ticker+date.
|
||||||
"""
|
"""
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
@@ -323,19 +326,19 @@ class TradingAgentsGraph:
|
|||||||
logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self._run_graph(company_name, trade_date)
|
return self._run_graph(company_name, trade_date, asset_type=asset_type)
|
||||||
finally:
|
finally:
|
||||||
if self._checkpointer_ctx is not None:
|
if self._checkpointer_ctx is not None:
|
||||||
self._checkpointer_ctx.__exit__(None, None, None)
|
self._checkpointer_ctx.__exit__(None, None, None)
|
||||||
self._checkpointer_ctx = None
|
self._checkpointer_ctx = None
|
||||||
self.graph = self.workflow.compile()
|
self.graph = self.workflow.compile()
|
||||||
|
|
||||||
def _run_graph(self, company_name, trade_date):
|
def _run_graph(self, company_name, trade_date, asset_type: str = "stock"):
|
||||||
"""Execute the graph and write the resulting state to disk and memory log."""
|
"""Execute the graph and write the resulting state to disk and memory log."""
|
||||||
# Initialize state — inject memory log context for PM.
|
# Initialize state — inject memory log context for PM.
|
||||||
past_context = self.memory_log.get_past_context(company_name)
|
past_context = self.memory_log.get_past_context(company_name)
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date, past_context=past_context
|
company_name, trade_date, asset_type=asset_type, past_context=past_context
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user