mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(sentiment): structured output for the Sentiment Analyst
The analyst previously emitted free-form prose, so its sentiment header varied by provider and by run, and downstream consumers had to maintain regexes that drifted with every model release. This commit extends the structured-output pattern already used by the Trader, Research Manager, and Portfolio Manager: a SentimentReport schema (band + 0-10 score + confidence + narrative) is rendered to a deterministic header, with a free-text fallback for providers that lack native structured output. #796
This commit is contained in:
@@ -1,23 +1,28 @@
|
|||||||
"""Tests for structured-output agents (Trader and Research Manager).
|
"""Tests for structured-output agents (Trader, Research Manager, Sentiment Analyst).
|
||||||
|
|
||||||
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
||||||
(which exercises the full memory-log → PM injection cycle). This file
|
(which exercises the full memory-log → PM injection cycle). This file
|
||||||
covers the parallel schemas, render functions, and graceful-fallback
|
covers the parallel schemas, render functions, and graceful-fallback
|
||||||
behavior we added for the Trader and Research Manager so all three
|
behavior we added for the Trader, Research Manager, and Sentiment Analyst
|
||||||
decision-making agents share the same shape.
|
so they share the same deterministic output shape.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst
|
||||||
from tradingagents.agents.managers.research_manager import create_research_manager
|
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||||
from tradingagents.agents.schemas import (
|
from tradingagents.agents.schemas import (
|
||||||
PortfolioRating,
|
PortfolioRating,
|
||||||
ResearchPlan,
|
ResearchPlan,
|
||||||
|
SentimentBand,
|
||||||
|
SentimentReport,
|
||||||
TraderAction,
|
TraderAction,
|
||||||
TraderProposal,
|
TraderProposal,
|
||||||
render_research_plan,
|
render_research_plan,
|
||||||
|
render_sentiment_report,
|
||||||
render_trader_proposal,
|
render_trader_proposal,
|
||||||
)
|
)
|
||||||
from tradingagents.agents.trader.trader import create_trader
|
from tradingagents.agents.trader.trader import create_trader
|
||||||
@@ -230,3 +235,126 @@ class TestResearchManagerAgent:
|
|||||||
rm = create_research_manager(llm)
|
rm = create_research_manager(llm)
|
||||||
result = rm(_make_rm_state())
|
result = rm(_make_rm_state())
|
||||||
assert result["investment_plan"] == plain_response
|
assert result["investment_plan"] == plain_response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentiment Analyst: schema, render, structured happy path + fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestRenderSentimentReport:
|
||||||
|
def test_header_contains_band_and_score(self):
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH,
|
||||||
|
overall_score=7.2,
|
||||||
|
confidence="high",
|
||||||
|
narrative="Source breakdown here.",
|
||||||
|
)
|
||||||
|
md = render_sentiment_report(report)
|
||||||
|
assert "**Overall Sentiment:** **Bullish**" in md
|
||||||
|
assert "(Score: 7.2/10)" in md
|
||||||
|
|
||||||
|
def test_header_contains_confidence(self):
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.NEUTRAL,
|
||||||
|
overall_score=5.0,
|
||||||
|
confidence="low",
|
||||||
|
narrative="Limited data.",
|
||||||
|
)
|
||||||
|
assert "**Confidence:** Low" in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_narrative_preserved_in_output(self):
|
||||||
|
narrative = "## Breakdown\n\nStockTwits: 70% bullish.\n\n| Signal | Direction |\n|---|---|\n| News | Neutral |"
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.MILDLY_BULLISH,
|
||||||
|
overall_score=6.0,
|
||||||
|
confidence="medium",
|
||||||
|
narrative=narrative,
|
||||||
|
)
|
||||||
|
assert narrative in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_all_six_bands_render(self):
|
||||||
|
for band in SentimentBand:
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=band, overall_score=5.0,
|
||||||
|
confidence="medium", narrative="n",
|
||||||
|
)
|
||||||
|
assert band.value in render_sentiment_report(report)
|
||||||
|
|
||||||
|
def test_score_out_of_range_rejected(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH, overall_score=11.0,
|
||||||
|
confidence="high", narrative="n",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sentiment_state():
|
||||||
|
return {
|
||||||
|
"company_of_interest": "NVDA",
|
||||||
|
"trade_date": "2026-01-15",
|
||||||
|
"asset_type": "stock",
|
||||||
|
"messages": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _structured_sentiment_llm(captured: dict, report: SentimentReport | None = None):
|
||||||
|
"""MagicMock LLM whose structured binding captures the prompt and returns
|
||||||
|
a real SentimentReport so render_sentiment_report works."""
|
||||||
|
if report is None:
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.BULLISH, overall_score=7.5,
|
||||||
|
confidence="high",
|
||||||
|
narrative="StockTwits 75% bullish. News constructive. Reddit upbeat.",
|
||||||
|
)
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = lambda prompt: (
|
||||||
|
captured.__setitem__("prompt", prompt) or report
|
||||||
|
)
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSentimentAnalystAgent:
|
||||||
|
def test_structured_path_produces_rendered_markdown(self):
|
||||||
|
captured = {}
|
||||||
|
report = SentimentReport(
|
||||||
|
overall_band=SentimentBand.MILDLY_BEARISH, overall_score=4.0,
|
||||||
|
confidence="medium", narrative="Mixed signals across sources.",
|
||||||
|
)
|
||||||
|
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured, report))
|
||||||
|
sr = analyst(_make_sentiment_state())["sentiment_report"]
|
||||||
|
assert "**Overall Sentiment:** **Mildly Bearish**" in sr
|
||||||
|
assert "(Score: 4.0/10)" in sr
|
||||||
|
assert "Mixed signals across sources." in sr
|
||||||
|
|
||||||
|
def test_sentiment_report_also_in_messages(self):
|
||||||
|
captured = {}
|
||||||
|
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured))
|
||||||
|
result = analyst(_make_sentiment_state())
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
assert result["sentiment_report"] == result["messages"][0].content
|
||||||
|
|
||||||
|
def test_prompt_contains_ticker(self):
|
||||||
|
captured = {}
|
||||||
|
create_sentiment_analyst(_structured_sentiment_llm(captured))(_make_sentiment_state())
|
||||||
|
assert any("NVDA" in str(m) for m in captured["prompt"])
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||||
|
plain = "**Overall Sentiment:** **Bearish** (Score: 3.0/10)\n**Confidence:** Low\n\nLimited data."
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain)
|
||||||
|
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain
|
||||||
|
|
||||||
|
def test_falls_back_to_freetext_when_structured_call_fails(self):
|
||||||
|
plain = "Fallback free-text sentiment."
|
||||||
|
structured = MagicMock()
|
||||||
|
structured.invoke.side_effect = ValueError("bad JSON from model")
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.with_structured_output.return_value = structured
|
||||||
|
llm.invoke.return_value = MagicMock(content=plain)
|
||||||
|
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain
|
||||||
|
|||||||
@@ -14,19 +14,31 @@ the LLM is invoked and injects them into the prompt as structured blocks:
|
|||||||
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
|
||||||
|
|
||||||
The agent does not use tool-calling; the data is in the prompt from
|
The agent does not use tool-calling; the data is in the prompt from
|
||||||
turn 0. The LLM produces the sentiment report in a single invocation.
|
turn 0. Output uses the structured-output pattern (json_schema for
|
||||||
|
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic), falling
|
||||||
|
back to free-text generation for providers that lack native support, so
|
||||||
|
the sentiment header (band + score + confidence) is deterministic across
|
||||||
|
runs and providers instead of free-form per-model prose.
|
||||||
|
|
||||||
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
See: https://github.com/TauricResearch/TradingAgents/issues/557
|
||||||
|
See: https://github.com/TauricResearch/TradingAgents/issues/796
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
|
from tradingagents.agents.schemas import SentimentReport, render_sentiment_report
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
get_instrument_context_from_state,
|
get_instrument_context_from_state,
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
get_news,
|
get_news,
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.structured import (
|
||||||
|
bind_structured,
|
||||||
|
invoke_structured_or_freetext,
|
||||||
|
)
|
||||||
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
from tradingagents.dataflows.reddit import fetch_reddit_posts
|
||||||
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
|
||||||
|
|
||||||
@@ -39,9 +51,11 @@ def create_sentiment_analyst(llm):
|
|||||||
"""Create a sentiment analyst node for the trading graph.
|
"""Create a sentiment analyst node for the trading graph.
|
||||||
|
|
||||||
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
Pre-fetches news + StockTwits + Reddit data, injects them into the
|
||||||
prompt as structured blocks, and produces a sentiment report in a
|
prompt as structured blocks, and produces a deterministic sentiment
|
||||||
single LLM call.
|
report via structured output (with a free-text fallback for providers
|
||||||
|
that do not support it).
|
||||||
"""
|
"""
|
||||||
|
structured_llm = bind_structured(llm, SentimentReport, "Sentiment Analyst")
|
||||||
|
|
||||||
def sentiment_analyst_node(state):
|
def sentiment_analyst_node(state):
|
||||||
ticker = state["company_of_interest"]
|
ticker = state["company_of_interest"]
|
||||||
@@ -83,14 +97,22 @@ def create_sentiment_analyst(llm):
|
|||||||
prompt = prompt.partial(current_date=end_date)
|
prompt = prompt.partial(current_date=end_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
# No bind_tools — the data is already in the prompt; a single LLM
|
# Format the template into a concrete message list so the structured
|
||||||
# call produces the report directly.
|
# and free-text paths receive the same input. No bind_tools — the
|
||||||
chain = prompt | llm
|
# data is already in the prompt.
|
||||||
result = chain.invoke(state["messages"])
|
formatted_messages = prompt.format_messages(messages=state["messages"])
|
||||||
|
|
||||||
|
report_text = invoke_structured_or_freetext(
|
||||||
|
structured_llm,
|
||||||
|
llm,
|
||||||
|
formatted_messages,
|
||||||
|
render_sentiment_report,
|
||||||
|
"Sentiment Analyst",
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [AIMessage(content=report_text)],
|
||||||
"sentiment_report": result.content,
|
"sentiment_report": report_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
return sentiment_analyst_node
|
return sentiment_analyst_node
|
||||||
@@ -143,21 +165,20 @@ Community discussion. Engagement signal via upvote score and comment count. Subr
|
|||||||
|
|
||||||
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
|
||||||
|
|
||||||
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so.
|
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this explicitly in the `confidence` field and the narrative. If the sources are silent on a given subreddit, say so.
|
||||||
|
|
||||||
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
|
||||||
|
|
||||||
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
|
||||||
|
|
||||||
## Output
|
## Output fields
|
||||||
|
|
||||||
Produce a sentiment report covering, in order:
|
Fill the following fields:
|
||||||
|
|
||||||
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size.
|
- **overall_band**: Exactly one of Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. Use Mixed when sources point in clearly different directions; Neutral only when all sources are genuinely silent.
|
||||||
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts).
|
- **overall_score**: A number from 0 (maximally bearish) to 10 (maximally bullish); 5 is neutral. Keep it consistent with overall_band.
|
||||||
3. **Divergences, alignments, and key narratives** across sources.
|
- **confidence**: low / medium / high, based on data quality and sample size.
|
||||||
4. **Catalysts and risks** surfaced by the data.
|
- **narrative**: Full source-by-source breakdown, divergences, dominant narrative themes, catalysts and risks, and a markdown summary table of key sentiment signals (direction, source, supporting evidence).
|
||||||
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
|
|
||||||
|
|
||||||
{get_language_instruction()}"""
|
{get_language_instruction()}"""
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ so that:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -226,3 +226,92 @@ def render_pm_decision(decision: PortfolioDecision) -> str:
|
|||||||
if decision.time_horizon:
|
if decision.time_horizon:
|
||||||
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentiment Analyst
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentBand(str, Enum):
|
||||||
|
"""Discrete sentiment direction produced by the Sentiment Analyst.
|
||||||
|
|
||||||
|
Six tiers keep the signal granular enough to be actionable while remaining
|
||||||
|
small enough for every provider to map reliably from its JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BULLISH = "Bullish"
|
||||||
|
MILDLY_BULLISH = "Mildly Bullish"
|
||||||
|
NEUTRAL = "Neutral"
|
||||||
|
MIXED = "Mixed"
|
||||||
|
MILDLY_BEARISH = "Mildly Bearish"
|
||||||
|
BEARISH = "Bearish"
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentReport(BaseModel):
|
||||||
|
"""Structured sentiment report produced by the Sentiment Analyst.
|
||||||
|
|
||||||
|
Replaces the previous free-form prose output so downstream consumers
|
||||||
|
(dashboards, audit logs, PDF renderers, other agents) can read
|
||||||
|
``overall_band`` and ``overall_score`` without maintaining fragile regex
|
||||||
|
fallbacks that drift with every model release. ``narrative`` preserves the
|
||||||
|
rich source-by-source analysis; ``render_sentiment_report`` prepends a
|
||||||
|
deterministic header so the saved report stays human-readable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
overall_band: SentimentBand = Field(
|
||||||
|
description=(
|
||||||
|
"Overall sentiment direction. Exactly one of: "
|
||||||
|
"Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. "
|
||||||
|
"Use Mixed when sources point in clearly different directions. "
|
||||||
|
"Use Neutral only when all sources are genuinely silent or non-committal."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
overall_score: float = Field(
|
||||||
|
ge=0.0,
|
||||||
|
le=10.0,
|
||||||
|
description=(
|
||||||
|
"Numeric sentiment intensity on a 0–10 scale. "
|
||||||
|
"0 = maximally bearish, 5 = neutral, 10 = maximally bullish. "
|
||||||
|
"Guideline for consistency with overall_band: "
|
||||||
|
"Bullish ~6.5–10, Mildly Bullish ~5.5–6.4, Neutral/Mixed ~4.5–5.5, "
|
||||||
|
"Mildly Bearish ~3.5–4.4, Bearish ~0–3.4. "
|
||||||
|
"Only the 0–10 bounds are enforced."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
confidence: Literal["low", "medium", "high"] = Field(
|
||||||
|
description=(
|
||||||
|
"Confidence in the assessment based on data quality and sample size. "
|
||||||
|
"Use 'low' when one or more sources returned a placeholder or fewer "
|
||||||
|
"than 5 data points; 'medium' when data is present but sparse; "
|
||||||
|
"'high' when all three sources returned substantive data."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
narrative: str = Field(
|
||||||
|
description=(
|
||||||
|
"Full sentiment report covering, in order: "
|
||||||
|
"(1) source-by-source breakdown with specific evidence (cite message "
|
||||||
|
"counts, ratios, notable posts); "
|
||||||
|
"(2) cross-source divergences and alignments; "
|
||||||
|
"(3) dominant narrative themes; "
|
||||||
|
"(4) catalysts and risks surfaced by the data; "
|
||||||
|
"(5) a markdown table summarising key sentiment signals, their "
|
||||||
|
"direction, source, and supporting evidence."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_sentiment_report(report: SentimentReport) -> str:
|
||||||
|
"""Render a SentimentReport to the markdown shape the rest of the system expects.
|
||||||
|
|
||||||
|
The structured header (band + score + confidence) is prepended to the
|
||||||
|
narrative so the saved report is both human-readable and machine-parseable
|
||||||
|
without regex.
|
||||||
|
"""
|
||||||
|
return "\n".join([
|
||||||
|
f"**Overall Sentiment:** **{report.overall_band.value}** "
|
||||||
|
f"(Score: {report.overall_score:.1f}/10)",
|
||||||
|
f"**Confidence:** {report.confidence.capitalize()}",
|
||||||
|
"",
|
||||||
|
report.narrative,
|
||||||
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user