feat(sentiment): structured output for the Sentiment Analyst

The analyst emitted free-form prose, so its sentiment header varied by
provider and by run, and downstream consumers had to maintain regexes
that kept drifting. Extend the structured-output pattern the trio
(Trader, Research Manager, Portfolio Manager) already uses: a
SentimentReport schema (band + 0-10 score + confidence + narrative)
rendered to a deterministic header, with a free-text fallback for
providers that lack native structured output.

#796
This commit is contained in:
Yijia-Xiao
2026-05-31 01:45:25 +00:00
parent a66aa8fb94
commit e80636fc0e
3 changed files with 259 additions and 21 deletions

View File

@@ -1,23 +1,28 @@
"""Tests for structured-output agents (Trader and Research Manager). """Tests for structured-output agents (Trader, Research Manager, Sentiment Analyst).
The Portfolio Manager has its own coverage in tests/test_memory_log.py The Portfolio Manager has its own coverage in tests/test_memory_log.py
(which exercises the full memory-log → PM injection cycle). This file (which exercises the full memory-log → PM injection cycle). This file
covers the parallel schemas, render functions, and graceful-fallback covers the parallel schemas, render functions, and graceful-fallback
behavior we added for the Trader and Research Manager so all three behavior we added for the Trader, Research Manager, and Sentiment Analyst
decision-making agents share the same shape. so they share the same deterministic output shape.
""" """
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pydantic import ValidationError
from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst
from tradingagents.agents.managers.research_manager import create_research_manager from tradingagents.agents.managers.research_manager import create_research_manager
from tradingagents.agents.schemas import ( from tradingagents.agents.schemas import (
PortfolioRating, PortfolioRating,
ResearchPlan, ResearchPlan,
SentimentBand,
SentimentReport,
TraderAction, TraderAction,
TraderProposal, TraderProposal,
render_research_plan, render_research_plan,
render_sentiment_report,
render_trader_proposal, render_trader_proposal,
) )
from tradingagents.agents.trader.trader import create_trader from tradingagents.agents.trader.trader import create_trader
@@ -230,3 +235,126 @@ class TestResearchManagerAgent:
rm = create_research_manager(llm) rm = create_research_manager(llm)
result = rm(_make_rm_state()) result = rm(_make_rm_state())
assert result["investment_plan"] == plain_response assert result["investment_plan"] == plain_response
# ---------------------------------------------------------------------------
# Sentiment Analyst: schema, render, structured happy path + fallback
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRenderSentimentReport:
    """Schema and rendering checks for ``SentimentReport``.

    Covers the markdown header emitted by ``render_sentiment_report``
    (band, score, confidence), pass-through of the narrative, and the
    0-10 bounds on ``overall_score``.
    """

    def test_header_contains_band_and_score(self):
        """The header carries the bold band label and a one-decimal score."""
        report = SentimentReport(
            overall_band=SentimentBand.BULLISH,
            overall_score=7.2,
            confidence="high",
            narrative="Source breakdown here.",
        )
        md = render_sentiment_report(report)
        assert "**Overall Sentiment:** **Bullish**" in md
        assert "(Score: 7.2/10)" in md

    def test_header_contains_confidence(self):
        """The lowercase confidence literal is rendered capitalised."""
        report = SentimentReport(
            overall_band=SentimentBand.NEUTRAL,
            overall_score=5.0,
            confidence="low",
            narrative="Limited data.",
        )
        assert "**Confidence:** Low" in render_sentiment_report(report)

    def test_narrative_preserved_in_output(self):
        """Multi-line markdown (headings, tables) survives rendering verbatim."""
        narrative = "## Breakdown\n\nStockTwits: 70% bullish.\n\n| Signal | Direction |\n|---|---|\n| News | Neutral |"
        report = SentimentReport(
            overall_band=SentimentBand.MILDLY_BULLISH,
            overall_score=6.0,
            confidence="medium",
            narrative=narrative,
        )
        assert narrative in render_sentiment_report(report)

    def test_all_six_bands_render(self):
        """Every band renders as its own exact label in the header."""
        for band in SentimentBand:
            report = SentimentReport(
                overall_band=band, overall_score=5.0,
                confidence="medium", narrative="n",
            )
            # Match the whole bold header, not the bare label: "Bullish" is a
            # substring of "Mildly Bullish" (likewise for Bearish), so a bare
            # substring check would pass even if the wrong band were rendered.
            expected_header = f"**Overall Sentiment:** **{band.value}**"
            assert expected_header in render_sentiment_report(report)

    def test_score_out_of_range_rejected(self):
        """Scores outside the 0-10 bounds fail validation on both sides."""
        for bad_score in (11.0, -0.1):
            with pytest.raises(ValidationError):
                SentimentReport(
                    overall_band=SentimentBand.BULLISH, overall_score=bad_score,
                    confidence="high", narrative="n",
                )
def _make_sentiment_state():
    """Return a fresh state dict for the sentiment analyst tests.

    Carries the ticker, trade date and asset type, plus an empty message
    history.
    """
    state = dict(
        company_of_interest="NVDA",
        trade_date="2026-01-15",
        asset_type="stock",
    )
    # A new list on every call, so tests never share message history.
    state["messages"] = []
    return state
def _structured_sentiment_llm(captured: dict, report: SentimentReport | None = None):
    """Build a MagicMock LLM whose structured binding records its prompt.

    ``with_structured_output`` returns a mock whose ``invoke`` stores the
    prompt it receives under ``captured["prompt"]`` and returns ``report``.
    When no report is supplied, a real bullish ``SentimentReport`` is used
    so ``render_sentiment_report`` can render the result.
    """
    if report is None:
        report = SentimentReport(
            overall_band=SentimentBand.BULLISH,
            overall_score=7.5,
            confidence="high",
            narrative="StockTwits 75% bullish. News constructive. Reddit upbeat.",
        )

    def _record_and_reply(prompt):
        # Keep the prompt for later assertions, then answer with the report.
        captured["prompt"] = prompt
        return report

    structured_binding = MagicMock()
    structured_binding.invoke.side_effect = _record_and_reply

    fake_llm = MagicMock()
    fake_llm.with_structured_output.return_value = structured_binding
    return fake_llm
@pytest.mark.unit
class TestSentimentAnalystAgent:
    """Checks of the sentiment analyst node against mocked LLMs.

    Covers the structured happy path and both free-text fallbacks (binding
    unavailable, structured call raising).

    NOTE(review): the node is described as pre-fetching news, StockTwits and
    Reddit data; nothing in this class patches those fetchers. Confirm they
    are stubbed elsewhere (or degrade gracefully) so these tests stay offline.
    """

    def test_structured_path_produces_rendered_markdown(self):
        """A structured reply is rendered as header plus narrative."""
        expected = SentimentReport(
            overall_band=SentimentBand.MILDLY_BEARISH,
            overall_score=4.0,
            confidence="medium",
            narrative="Mixed signals across sources.",
        )
        node = create_sentiment_analyst(_structured_sentiment_llm({}, expected))
        output = node(_make_sentiment_state())
        rendered = output["sentiment_report"]
        assert "**Overall Sentiment:** **Mildly Bearish**" in rendered
        assert "(Score: 4.0/10)" in rendered
        assert "Mixed signals across sources." in rendered

    def test_sentiment_report_also_in_messages(self):
        """The rendered report is also returned as the single message."""
        node = create_sentiment_analyst(_structured_sentiment_llm({}))
        output = node(_make_sentiment_state())
        messages = output["messages"]
        assert len(messages) == 1
        assert output["sentiment_report"] == messages[0].content

    def test_prompt_contains_ticker(self):
        """The ticker from the state appears in the prompt sent to the LLM."""
        seen = {}
        node = create_sentiment_analyst(_structured_sentiment_llm(seen))
        node(_make_sentiment_state())
        assert any("NVDA" in str(message) for message in seen["prompt"])

    def test_falls_back_to_freetext_when_structured_unavailable(self):
        """If binding structured output raises, the plain LLM reply is used."""
        plain = "**Overall Sentiment:** **Bearish** (Score: 3.0/10)\n**Confidence:** Low\n\nLimited data."
        fake_llm = MagicMock()
        fake_llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
        fake_llm.invoke.return_value = MagicMock(content=plain)
        node = create_sentiment_analyst(fake_llm)
        output = node(_make_sentiment_state())
        assert output["sentiment_report"] == plain

    def test_falls_back_to_freetext_when_structured_call_fails(self):
        """If the structured invoke raises, the plain LLM reply is used."""
        plain = "Fallback free-text sentiment."
        failing_binding = MagicMock()
        failing_binding.invoke.side_effect = ValueError("bad JSON from model")
        fake_llm = MagicMock()
        fake_llm.with_structured_output.return_value = failing_binding
        fake_llm.invoke.return_value = MagicMock(content=plain)
        node = create_sentiment_analyst(fake_llm)
        output = node(_make_sentiment_state())
        assert output["sentiment_report"] == plain

View File

@@ -14,19 +14,31 @@ the LLM is invoked and injects them into the prompt as structured blocks:
3. Reddit posts — r/wallstreetbets, r/stocks, r/investing 3. Reddit posts — r/wallstreetbets, r/stocks, r/investing
The agent does not use tool-calling; the data is in the prompt from The agent does not use tool-calling; the data is in the prompt from
turn 0. The LLM produces the sentiment report in a single invocation. turn 0. Output uses the structured-output pattern (json_schema for
OpenAI/xAI, response_schema for Gemini, tool-use for Anthropic), falling
back to free-text generation for providers that lack native support, so
the sentiment header (band + score + confidence) is deterministic across
runs and providers instead of free-form per-model prose.
See: https://github.com/TauricResearch/TradingAgents/issues/557 See: https://github.com/TauricResearch/TradingAgents/issues/557
See: https://github.com/TauricResearch/TradingAgents/issues/796
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.schemas import SentimentReport, render_sentiment_report
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
get_instrument_context_from_state, get_instrument_context_from_state,
get_language_instruction, get_language_instruction,
get_news, get_news,
) )
from tradingagents.agents.utils.structured import (
bind_structured,
invoke_structured_or_freetext,
)
from tradingagents.dataflows.reddit import fetch_reddit_posts from tradingagents.dataflows.reddit import fetch_reddit_posts
from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages from tradingagents.dataflows.stocktwits import fetch_stocktwits_messages
@@ -39,9 +51,11 @@ def create_sentiment_analyst(llm):
"""Create a sentiment analyst node for the trading graph. """Create a sentiment analyst node for the trading graph.
Pre-fetches news + StockTwits + Reddit data, injects them into the Pre-fetches news + StockTwits + Reddit data, injects them into the
prompt as structured blocks, and produces a sentiment report in a prompt as structured blocks, and produces a deterministic sentiment
single LLM call. report via structured output (with a free-text fallback for providers
that do not support it).
""" """
structured_llm = bind_structured(llm, SentimentReport, "Sentiment Analyst")
def sentiment_analyst_node(state): def sentiment_analyst_node(state):
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
@@ -83,14 +97,22 @@ def create_sentiment_analyst(llm):
prompt = prompt.partial(current_date=end_date) prompt = prompt.partial(current_date=end_date)
prompt = prompt.partial(instrument_context=instrument_context) prompt = prompt.partial(instrument_context=instrument_context)
# No bind_tools — the data is already in the prompt; a single LLM # Format the template into a concrete message list so the structured
# call produces the report directly. # and free-text paths receive the same input. No bind_tools — the
chain = prompt | llm # data is already in the prompt.
result = chain.invoke(state["messages"]) formatted_messages = prompt.format_messages(messages=state["messages"])
report_text = invoke_structured_or_freetext(
structured_llm,
llm,
formatted_messages,
render_sentiment_report,
"Sentiment Analyst",
)
return { return {
"messages": [result], "messages": [AIMessage(content=report_text)],
"sentiment_report": result.content, "sentiment_report": report_text,
} }
return sentiment_analyst_node return sentiment_analyst_node
@@ -143,21 +165,20 @@ Community discussion. Engagement signal via upvote score and comment count. Subr
5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment. 5. **Identify recurring narrative themes.** What topic keeps coming up across sources? That's the dominant narrative driving current sentiment.
6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this caveat explicitly. If the sources are silent on a given subreddit, say so. 6. **Be honest about data limits.** If StockTwits returned only a handful of messages, or one or more sources returned an "<unavailable>" placeholder, the sentiment read is less robust — flag this explicitly in the `confidence` field and the narrative. If the sources are silent on a given subreddit, say so.
7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc. 7. **Identify catalysts and risks** that emerge across sources — news of upcoming earnings, product launches, competitive threats, macro headlines, etc.
8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call. 8. **Past sentiment is not predictive.** Frame your conclusions as signal for the trader to weigh alongside fundamentals and technicals, not as a price call.
## Output ## Output fields
Produce a sentiment report covering, in order: Fill the following fields:
1. **Overall sentiment direction** — Bullish / Bearish / Neutral / Mixed — with a brief confidence note based on data quality and sample size. - **overall_band**: Exactly one of Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. Use Mixed when sources point in clearly different directions; Neutral only when all sources are genuinely silent.
2. **Source-by-source breakdown** — what each of news / StockTwits / Reddit is telling you, with specific evidence (cite message counts, ratios, notable posts). - **overall_score**: A number from 0 (maximally bearish) to 10 (maximally bullish); 5 is neutral. Keep it consistent with overall_band.
3. **Divergences, alignments, and key narratives** across sources. - **confidence**: low / medium / high, based on data quality and sample size.
4. **Catalysts and risks** surfaced by the data. - **narrative**: Full source-by-source breakdown, divergences, dominant narrative themes, catalysts and risks, and a markdown summary table of key sentiment signals (direction, source, supporting evidence).
5. **Markdown table** at the end summarizing key sentiment signals, their direction, source, and supporting evidence.
{get_language_instruction()}""" {get_language_instruction()}"""

View File

@@ -19,7 +19,7 @@ so that:
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Optional from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -226,3 +226,92 @@ def render_pm_decision(decision: PortfolioDecision) -> str:
if decision.time_horizon: if decision.time_horizon:
parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"]) parts.extend(["", f"**Time Horizon**: {decision.time_horizon}"])
return "\n".join(parts) return "\n".join(parts)
# ---------------------------------------------------------------------------
# Sentiment Analyst
# ---------------------------------------------------------------------------
class SentimentBand(str, Enum):
    """Discrete sentiment direction produced by the Sentiment Analyst.

    Six tiers keep the signal granular enough to be actionable while remaining
    small enough for every provider to map reliably from its JSON output.

    Each member's value is the human-readable label that
    ``render_sentiment_report`` places in the report header. The ``str``
    mix-in makes every member a ``str`` instance as well.
    """

    # Members run from most bullish to most bearish, with the two
    # non-directional bands (Neutral, Mixed) in the middle.
    BULLISH = "Bullish"
    MILDLY_BULLISH = "Mildly Bullish"
    # Neutral: sources are silent or non-committal.
    NEUTRAL = "Neutral"
    # Mixed: sources point in clearly different directions.
    MIXED = "Mixed"
    MILDLY_BEARISH = "Mildly Bearish"
    BEARISH = "Bearish"
class SentimentReport(BaseModel):
    """Structured sentiment report produced by the Sentiment Analyst.

    Replaces the previous free-form prose output so downstream consumers
    (dashboards, audit logs, PDF renderers, other agents) can read
    ``overall_band`` and ``overall_score`` without maintaining fragile regex
    fallbacks that drift with every model release. ``narrative`` preserves the
    rich source-by-source analysis; ``render_sentiment_report`` prepends a
    deterministic header so the saved report stays human-readable.

    The ``description`` strings below are written as instructions to the
    model filling in the schema, so their wording matters.
    """

    overall_band: SentimentBand = Field(
        description=(
            "Overall sentiment direction. Exactly one of: "
            "Bullish / Mildly Bullish / Neutral / Mixed / Mildly Bearish / Bearish. "
            "Use Mixed when sources point in clearly different directions. "
            "Use Neutral only when all sources are genuinely silent or non-committal."
        ),
    )
    # Only the 0-10 bounds are validated (ge/le); the per-band ranges in the
    # description are guidance and are not checked against overall_band.
    overall_score: float = Field(
        ge=0.0,
        le=10.0,
        description=(
            "Numeric sentiment intensity on a 0-10 scale. "
            "0 = maximally bearish, 5 = neutral, 10 = maximally bullish. "
            "Guideline for consistency with overall_band: "
            "Bullish ~6.5-10, Mildly Bullish ~5.5-6.4, Neutral/Mixed ~4.5-5.5, "
            "Mildly Bearish ~3.5-4.4, Bearish ~0-3.4. "
            "Only the 0-10 bounds are enforced."
        ),
    )
    # Lowercase literals; render_sentiment_report capitalises them for display.
    confidence: Literal["low", "medium", "high"] = Field(
        description=(
            "Confidence in the assessment based on data quality and sample size. "
            "Use 'low' when one or more sources returned a placeholder or fewer "
            "than 5 data points; 'medium' when data is present but sparse; "
            "'high' when all three sources returned substantive data."
        ),
    )
    narrative: str = Field(
        description=(
            "Full sentiment report covering, in order: "
            "(1) source-by-source breakdown with specific evidence (cite message "
            "counts, ratios, notable posts); "
            "(2) cross-source divergences and alignments; "
            "(3) dominant narrative themes; "
            "(4) catalysts and risks surfaced by the data; "
            "(5) a markdown table summarising key sentiment signals, their "
            "direction, source, and supporting evidence."
        ),
    )
def render_sentiment_report(report: SentimentReport) -> str:
    """Render a ``SentimentReport`` as markdown: header, blank line, narrative.

    The first line carries the bold band label and the score formatted to one
    decimal place. The second line carries the capitalised confidence level.
    The narrative follows unchanged after one blank line.

    Args:
        report: The structured report to render.

    Returns:
        The markdown text of the report.
    """
    band_label = report.overall_band.value
    score_text = f"{report.overall_score:.1f}"
    sentiment_line = f"**Overall Sentiment:** **{band_label}** (Score: {score_text}/10)"
    confidence_line = f"**Confidence:** {report.confidence.capitalize()}"
    header = sentiment_line + "\n" + confidence_line
    # One empty line separates the fixed header from the free-form narrative.
    return header + "\n\n" + report.narrative