Files
tradingagents/tests/test_structured_agents.py
Yijia-Xiao e80636fc0e feat(sentiment): structured output for the Sentiment Analyst
The analyst emitted free-form prose, so its sentiment header varied by
provider and run and downstream consumers needed drifting regex. Extend
the structured-output pattern the trio already uses: a SentimentReport
schema (band + 0-10 score + confidence + narrative) rendered to a
deterministic header, with a free-text fallback for providers that lack
native structured output.

#796
2026-05-31 01:45:25 +00:00

361 lines
14 KiB
Python

"""Tests for structured-output agents (Trader, Research Manager, Sentiment Analyst).
The Portfolio Manager has its own coverage in tests/test_memory_log.py
(which exercises the full memory-log → PM injection cycle). This file
covers the parallel schemas, render functions, and graceful-fallback
behavior we added for the Trader, Research Manager, and Sentiment Analyst
so they share the same deterministic output shape.
"""
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
from tradingagents.agents.analysts.sentiment_analyst import create_sentiment_analyst
from tradingagents.agents.managers.research_manager import create_research_manager
from tradingagents.agents.schemas import (
PortfolioRating,
ResearchPlan,
SentimentBand,
SentimentReport,
TraderAction,
TraderProposal,
render_research_plan,
render_sentiment_report,
render_trader_proposal,
)
from tradingagents.agents.trader.trader import create_trader
# ---------------------------------------------------------------------------
# Render functions
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRenderTraderProposal:
def test_minimal_required_fields(self):
p = TraderProposal(action=TraderAction.HOLD, reasoning="Balanced setup; no edge.")
md = render_trader_proposal(p)
assert "**Action**: Hold" in md
assert "**Reasoning**: Balanced setup; no edge." in md
# The trailing FINAL TRANSACTION PROPOSAL line is preserved for the
# analyst stop-signal text and any external code that greps for it.
assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in md
def test_optional_fields_included_when_present(self):
p = TraderProposal(
action=TraderAction.BUY,
reasoning="Strong technicals + fundamentals.",
entry_price=189.5,
stop_loss=178.0,
position_sizing="6% of portfolio",
)
md = render_trader_proposal(p)
assert "**Action**: Buy" in md
assert "**Entry Price**: 189.5" in md
assert "**Stop Loss**: 178.0" in md
assert "**Position Sizing**: 6% of portfolio" in md
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in md
def test_optional_fields_omitted_when_absent(self):
p = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.")
md = render_trader_proposal(p)
assert "Entry Price" not in md
assert "Stop Loss" not in md
assert "Position Sizing" not in md
assert "FINAL TRANSACTION PROPOSAL: **SELL**" in md
@pytest.mark.unit
class TestRenderResearchPlan:
def test_required_fields(self):
p = ResearchPlan(
recommendation=PortfolioRating.OVERWEIGHT,
rationale="Bull case carried; tailwinds intact.",
strategic_actions="Build position over two weeks; cap at 5%.",
)
md = render_research_plan(p)
assert "**Recommendation**: Overweight" in md
assert "**Rationale**: Bull case carried" in md
assert "**Strategic Actions**: Build position" in md
def test_all_5_tier_ratings_render(self):
for rating in PortfolioRating:
p = ResearchPlan(
recommendation=rating,
rationale="r",
strategic_actions="s",
)
md = render_research_plan(p)
assert f"**Recommendation**: {rating.value}" in md
# ---------------------------------------------------------------------------
# Trader agent: structured happy path + fallback
# ---------------------------------------------------------------------------
def _make_trader_state():
return {
"company_of_interest": "NVDA",
"investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...",
}
def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None):
"""Build a MagicMock LLM whose with_structured_output binding captures the
prompt and returns a real TraderProposal so render_trader_proposal works.
"""
if proposal is None:
proposal = TraderProposal(
action=TraderAction.BUY,
reasoning="Strong setup.",
)
structured = MagicMock()
structured.invoke.side_effect = lambda prompt: (
captured.__setitem__("prompt", prompt) or proposal
)
llm = MagicMock()
llm.with_structured_output.return_value = structured
return llm
@pytest.mark.unit
class TestTraderAgent:
def test_structured_path_produces_rendered_markdown(self):
captured = {}
proposal = TraderProposal(
action=TraderAction.BUY,
reasoning="AI capex cycle intact; institutional flows constructive.",
entry_price=189.5,
stop_loss=178.0,
position_sizing="6% of portfolio",
)
llm = _structured_trader_llm(captured, proposal)
trader = create_trader(llm)
result = trader(_make_trader_state())
plan = result["trader_investment_plan"]
assert "**Action**: Buy" in plan
assert "**Entry Price**: 189.5" in plan
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in plan
# The same rendered markdown is also added to messages for downstream agents.
assert plan in result["messages"][0].content
def test_prompt_includes_investment_plan(self):
captured = {}
llm = _structured_trader_llm(captured)
trader = create_trader(llm)
trader(_make_trader_state())
# The investment plan is in the user message of the captured prompt.
prompt = captured["prompt"]
assert any("Proposed Investment Plan" in m["content"] for m in prompt)
def test_falls_back_to_freetext_when_structured_unavailable(self):
plain_response = (
"**Action**: Sell\n\nGuidance cut hits margins.\n\n"
"FINAL TRANSACTION PROPOSAL: **SELL**"
)
llm = MagicMock()
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
llm.invoke.return_value = MagicMock(content=plain_response)
trader = create_trader(llm)
result = trader(_make_trader_state())
assert result["trader_investment_plan"] == plain_response
# ---------------------------------------------------------------------------
# Research Manager agent: structured happy path + fallback
# ---------------------------------------------------------------------------
def _make_rm_state():
return {
"company_of_interest": "NVDA",
"investment_debate_state": {
"history": "Bull and bear arguments here.",
"bull_history": "Bull says...",
"bear_history": "Bear says...",
"current_response": "",
"judge_decision": "",
"count": 1,
},
}
def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None):
if plan is None:
plan = ResearchPlan(
recommendation=PortfolioRating.HOLD,
rationale="Balanced view across both sides.",
strategic_actions="Hold current position; reassess after earnings.",
)
structured = MagicMock()
structured.invoke.side_effect = lambda prompt: (
captured.__setitem__("prompt", prompt) or plan
)
llm = MagicMock()
llm.with_structured_output.return_value = structured
return llm
@pytest.mark.unit
class TestResearchManagerAgent:
def test_structured_path_produces_rendered_markdown(self):
captured = {}
plan = ResearchPlan(
recommendation=PortfolioRating.OVERWEIGHT,
rationale="Bull case is stronger; AI tailwind intact.",
strategic_actions="Build position gradually over two weeks.",
)
llm = _structured_rm_llm(captured, plan)
rm = create_research_manager(llm)
result = rm(_make_rm_state())
ip = result["investment_plan"]
assert "**Recommendation**: Overweight" in ip
assert "**Rationale**: Bull case" in ip
assert "**Strategic Actions**: Build position" in ip
def test_prompt_uses_5_tier_rating_scale(self):
"""The RM prompt must list all five tiers so the schema enum matches user expectations."""
captured = {}
llm = _structured_rm_llm(captured)
rm = create_research_manager(llm)
rm(_make_rm_state())
prompt = captured["prompt"]
for tier in ("Buy", "Overweight", "Hold", "Underweight", "Sell"):
assert f"**{tier}**" in prompt, f"missing {tier} in prompt"
def test_falls_back_to_freetext_when_structured_unavailable(self):
plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..."
llm = MagicMock()
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
llm.invoke.return_value = MagicMock(content=plain_response)
rm = create_research_manager(llm)
result = rm(_make_rm_state())
assert result["investment_plan"] == plain_response
# ---------------------------------------------------------------------------
# Sentiment Analyst: schema, render, structured happy path + fallback
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRenderSentimentReport:
def test_header_contains_band_and_score(self):
report = SentimentReport(
overall_band=SentimentBand.BULLISH,
overall_score=7.2,
confidence="high",
narrative="Source breakdown here.",
)
md = render_sentiment_report(report)
assert "**Overall Sentiment:** **Bullish**" in md
assert "(Score: 7.2/10)" in md
def test_header_contains_confidence(self):
report = SentimentReport(
overall_band=SentimentBand.NEUTRAL,
overall_score=5.0,
confidence="low",
narrative="Limited data.",
)
assert "**Confidence:** Low" in render_sentiment_report(report)
def test_narrative_preserved_in_output(self):
narrative = "## Breakdown\n\nStockTwits: 70% bullish.\n\n| Signal | Direction |\n|---|---|\n| News | Neutral |"
report = SentimentReport(
overall_band=SentimentBand.MILDLY_BULLISH,
overall_score=6.0,
confidence="medium",
narrative=narrative,
)
assert narrative in render_sentiment_report(report)
def test_all_six_bands_render(self):
for band in SentimentBand:
report = SentimentReport(
overall_band=band, overall_score=5.0,
confidence="medium", narrative="n",
)
assert band.value in render_sentiment_report(report)
def test_score_out_of_range_rejected(self):
with pytest.raises(ValidationError):
SentimentReport(
overall_band=SentimentBand.BULLISH, overall_score=11.0,
confidence="high", narrative="n",
)
def _make_sentiment_state():
return {
"company_of_interest": "NVDA",
"trade_date": "2026-01-15",
"asset_type": "stock",
"messages": [],
}
def _structured_sentiment_llm(captured: dict, report: SentimentReport | None = None):
"""MagicMock LLM whose structured binding captures the prompt and returns
a real SentimentReport so render_sentiment_report works."""
if report is None:
report = SentimentReport(
overall_band=SentimentBand.BULLISH, overall_score=7.5,
confidence="high",
narrative="StockTwits 75% bullish. News constructive. Reddit upbeat.",
)
structured = MagicMock()
structured.invoke.side_effect = lambda prompt: (
captured.__setitem__("prompt", prompt) or report
)
llm = MagicMock()
llm.with_structured_output.return_value = structured
return llm
@pytest.mark.unit
class TestSentimentAnalystAgent:
def test_structured_path_produces_rendered_markdown(self):
captured = {}
report = SentimentReport(
overall_band=SentimentBand.MILDLY_BEARISH, overall_score=4.0,
confidence="medium", narrative="Mixed signals across sources.",
)
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured, report))
sr = analyst(_make_sentiment_state())["sentiment_report"]
assert "**Overall Sentiment:** **Mildly Bearish**" in sr
assert "(Score: 4.0/10)" in sr
assert "Mixed signals across sources." in sr
def test_sentiment_report_also_in_messages(self):
captured = {}
analyst = create_sentiment_analyst(_structured_sentiment_llm(captured))
result = analyst(_make_sentiment_state())
assert len(result["messages"]) == 1
assert result["sentiment_report"] == result["messages"][0].content
def test_prompt_contains_ticker(self):
captured = {}
create_sentiment_analyst(_structured_sentiment_llm(captured))(_make_sentiment_state())
assert any("NVDA" in str(m) for m in captured["prompt"])
def test_falls_back_to_freetext_when_structured_unavailable(self):
plain = "**Overall Sentiment:** **Bearish** (Score: 3.0/10)\n**Confidence:** Low\n\nLimited data."
llm = MagicMock()
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
llm.invoke.return_value = MagicMock(content=plain)
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain
def test_falls_back_to_freetext_when_structured_call_fails(self):
plain = "Fallback free-text sentiment."
structured = MagicMock()
structured.invoke.side_effect = ValueError("bad JSON from model")
llm = MagicMock()
llm.with_structured_output.return_value = structured
llm.invoke.return_value = MagicMock(content=plain)
assert create_sentiment_analyst(llm)(_make_sentiment_state())["sentiment_report"] == plain