fix(graph): resolve instrument identity to stop wrong-company hallucination

Agents had no ground-truth ticker→company mapping, so the market analyst
could pattern-match a price chart to the wrong company (e.g. TOTDY read as
"TotalEnergies"), and every downstream agent inherited the bad framing.

Resolve identity once at run start via a cached, fail-open yfinance lookup
and inject company/sector/exchange into the shared instrument context that
all twelve agents consume, with an explicit do-not-substitute instruction.
Resolution runs on both the propagate() and CLI entry points.

Also replaces the bare "Continue" message-clear placeholder, which some
OpenAI-compatible providers interpreted as the user task, with a
context-anchored placeholder carrying the resolved identity and date.

#814 #888
This commit is contained in:
Yijia-Xiao
2026-05-30 23:56:32 +00:00
parent 61522e103e
commit d7b40a2a5c
18 changed files with 390 additions and 44 deletions

View File

@@ -0,0 +1,170 @@
"""Tests for deterministic instrument-identity resolution (#814) and the
context-anchored message placeholder (#888)."""
import unittest
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
create_msg_delete,
get_instrument_context_from_state,
resolve_instrument_identity,
)
@pytest.mark.unit
class ResolveInstrumentIdentityTests(unittest.TestCase):
def setUp(self):
resolve_instrument_identity.cache_clear()
def test_resolves_company_metadata_from_yfinance(self):
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
mock.return_value.info = {
"longName": "TOTO LTD.",
"shortName": "TOTO",
"sector": "Industrials",
"industry": "Building Products & Equipment",
"exchange": "PNK",
"quoteType": "EQUITY",
}
identity = resolve_instrument_identity("totdy")
mock.assert_called_once_with("TOTDY")
self.assertEqual(identity["company_name"], "TOTO LTD.")
self.assertEqual(identity["sector"], "Industrials")
self.assertEqual(identity["industry"], "Building Products & Equipment")
self.assertEqual(identity["exchange"], "PNK")
def test_falls_back_to_short_name(self):
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
mock.return_value.info = {"shortName": "TOTO", "sector": "Industrials"}
identity = resolve_instrument_identity("TOTDY")
self.assertEqual(identity["company_name"], "TOTO")
def test_skips_placeholder_values(self):
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
mock.return_value.info = {"longName": " ", "sector": "None", "industry": "n/a"}
identity = resolve_instrument_identity("TOTDY")
self.assertEqual(identity, {})
def test_fails_open_on_exception(self):
with patch(
"tradingagents.agents.utils.agent_utils.yf.Ticker",
side_effect=RuntimeError("rate limited"),
):
self.assertEqual(resolve_instrument_identity("TOTDY"), {})
def test_result_is_cached(self):
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
mock.return_value.info = {"longName": "TOTO LTD."}
first = resolve_instrument_identity("TOTDY")
second = resolve_instrument_identity("TOTDY")
mock.assert_called_once() # second call served from cache
self.assertEqual(first, second)
@pytest.mark.unit
class BuildInstrumentContextTests(unittest.TestCase):
def test_mentions_exact_symbol_without_identity(self):
context = build_instrument_context("7203.T")
self.assertIn("7203.T", context)
self.assertIn("exchange suffix", context)
self.assertNotIn("Resolved identity", context)
def test_injects_resolved_identity(self):
context = build_instrument_context(
"TOTDY", "stock",
{
"company_name": "TOTO LTD.",
"sector": "Industrials",
"industry": "Building Products & Equipment",
"exchange": "PNK",
},
)
self.assertIn("Company: TOTO LTD.", context)
self.assertIn("Industrials / Building Products & Equipment", context)
self.assertIn("Exchange: PNK", context)
self.assertIn("Do not substitute a different company", context)
def test_crypto_uses_name_label_and_keeps_hint(self):
context = build_instrument_context(
"BTC-USD", "crypto", {"company_name": "Bitcoin USD"}
)
self.assertIn("Name: Bitcoin USD", context)
self.assertIn("crypto asset rather than a company", context)
@pytest.mark.unit
class GetInstrumentContextFromStateTests(unittest.TestCase):
def test_prefers_precomputed_context(self):
state = {"company_of_interest": "TOTDY", "instrument_context": "PRECOMPUTED"}
self.assertEqual(get_instrument_context_from_state(state), "PRECOMPUTED")
def test_fallback_is_network_free_ticker_only(self):
# No instrument_context and no yfinance call — must not hit the network.
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
context = get_instrument_context_from_state(
{"company_of_interest": "NVDA", "asset_type": "stock"}
)
mock.assert_not_called()
self.assertIn("NVDA", context)
def test_fallback_respects_asset_type(self):
context = get_instrument_context_from_state(
{"company_of_interest": "BTC-USD", "asset_type": "crypto"}
)
self.assertIn("crypto asset", context)
@pytest.mark.unit
class ContextAnchoredPlaceholderTests(unittest.TestCase):
"""#888 — the message-clear placeholder must not be a bare 'Continue'."""
def _run(self, state_extra):
state = {
"messages": [
HumanMessage(content="old", id="h1"),
AIMessage(content="reply", id="a1"),
],
**state_extra,
}
return create_msg_delete()(state)
def test_placeholder_is_not_bare_continue(self):
result = self._run(
{"company_of_interest": "EC", "asset_type": "stock", "trade_date": "2026-05-28"}
)
placeholder = result["messages"][-1]
self.assertIsInstance(placeholder, HumanMessage)
self.assertNotEqual(placeholder.content.strip(), "Continue")
def test_placeholder_carries_resolved_identity(self):
result = self._run(
{
"company_of_interest": "EC",
"instrument_context": "The instrument to analyze is `EC`. Resolved identity: Company: Ecopetrol.",
"trade_date": "2026-05-28",
}
)
content = result["messages"][-1].content
self.assertIn("Ecopetrol", content)
self.assertIn("2026-05-28", content)
def test_old_messages_are_removed(self):
result = self._run({"company_of_interest": "EC", "trade_date": "2026-05-28"})
removals = [m for m in result["messages"] if isinstance(m, RemoveMessage)]
humans = [m for m in result["messages"] if isinstance(m, HumanMessage)]
self.assertEqual(len(removals), 2)
self.assertEqual(len(humans), 1)
def test_safe_defaults_when_state_minimal(self):
result = create_msg_delete()({"messages": [], "company_of_interest": "EC"})
placeholder = result["messages"][-1]
self.assertNotEqual(placeholder.content.strip(), "Continue")
self.assertIn("EC", placeholder.content)
if __name__ == "__main__":
unittest.main()