mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
fix(graph): resolve instrument identity to stop wrong-company hallucination
Agents had no ground-truth ticker→company mapping, so the market analyst could pattern-match a price chart to the wrong company (e.g. TOTDY read as "TotalEnergies"), and every downstream agent inherited the bad framing. Resolve identity once at run start via a cached, fail-open yfinance lookup and inject company/sector/exchange into the shared instrument context that all twelve agents consume, with an explicit do-not-substitute instruction. Resolution runs on both the propagate() and CLI entry points. Also replaces the bare "Continue" message-clear placeholder, which some OpenAI-compatible providers interpreted as the user task, with a context-anchored placeholder carrying the resolved identity and date. #814 #888
This commit is contained in:
170
tests/test_instrument_identity.py
Normal file
170
tests/test_instrument_identity.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Tests for deterministic instrument-identity resolution (#814) and the
|
||||
context-anchored message placeholder (#888)."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
create_msg_delete,
|
||||
get_instrument_context_from_state,
|
||||
resolve_instrument_identity,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class ResolveInstrumentIdentityTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
resolve_instrument_identity.cache_clear()
|
||||
|
||||
def test_resolves_company_metadata_from_yfinance(self):
|
||||
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||
mock.return_value.info = {
|
||||
"longName": "TOTO LTD.",
|
||||
"shortName": "TOTO",
|
||||
"sector": "Industrials",
|
||||
"industry": "Building Products & Equipment",
|
||||
"exchange": "PNK",
|
||||
"quoteType": "EQUITY",
|
||||
}
|
||||
identity = resolve_instrument_identity("totdy")
|
||||
mock.assert_called_once_with("TOTDY")
|
||||
self.assertEqual(identity["company_name"], "TOTO LTD.")
|
||||
self.assertEqual(identity["sector"], "Industrials")
|
||||
self.assertEqual(identity["industry"], "Building Products & Equipment")
|
||||
self.assertEqual(identity["exchange"], "PNK")
|
||||
|
||||
def test_falls_back_to_short_name(self):
|
||||
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||
mock.return_value.info = {"shortName": "TOTO", "sector": "Industrials"}
|
||||
identity = resolve_instrument_identity("TOTDY")
|
||||
self.assertEqual(identity["company_name"], "TOTO")
|
||||
|
||||
def test_skips_placeholder_values(self):
|
||||
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||
mock.return_value.info = {"longName": " ", "sector": "None", "industry": "n/a"}
|
||||
identity = resolve_instrument_identity("TOTDY")
|
||||
self.assertEqual(identity, {})
|
||||
|
||||
def test_fails_open_on_exception(self):
|
||||
with patch(
|
||||
"tradingagents.agents.utils.agent_utils.yf.Ticker",
|
||||
side_effect=RuntimeError("rate limited"),
|
||||
):
|
||||
self.assertEqual(resolve_instrument_identity("TOTDY"), {})
|
||||
|
||||
def test_result_is_cached(self):
|
||||
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||
mock.return_value.info = {"longName": "TOTO LTD."}
|
||||
first = resolve_instrument_identity("TOTDY")
|
||||
second = resolve_instrument_identity("TOTDY")
|
||||
mock.assert_called_once() # second call served from cache
|
||||
self.assertEqual(first, second)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class BuildInstrumentContextTests(unittest.TestCase):
|
||||
def test_mentions_exact_symbol_without_identity(self):
|
||||
context = build_instrument_context("7203.T")
|
||||
self.assertIn("7203.T", context)
|
||||
self.assertIn("exchange suffix", context)
|
||||
self.assertNotIn("Resolved identity", context)
|
||||
|
||||
def test_injects_resolved_identity(self):
|
||||
context = build_instrument_context(
|
||||
"TOTDY", "stock",
|
||||
{
|
||||
"company_name": "TOTO LTD.",
|
||||
"sector": "Industrials",
|
||||
"industry": "Building Products & Equipment",
|
||||
"exchange": "PNK",
|
||||
},
|
||||
)
|
||||
self.assertIn("Company: TOTO LTD.", context)
|
||||
self.assertIn("Industrials / Building Products & Equipment", context)
|
||||
self.assertIn("Exchange: PNK", context)
|
||||
self.assertIn("Do not substitute a different company", context)
|
||||
|
||||
def test_crypto_uses_name_label_and_keeps_hint(self):
|
||||
context = build_instrument_context(
|
||||
"BTC-USD", "crypto", {"company_name": "Bitcoin USD"}
|
||||
)
|
||||
self.assertIn("Name: Bitcoin USD", context)
|
||||
self.assertIn("crypto asset rather than a company", context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class GetInstrumentContextFromStateTests(unittest.TestCase):
|
||||
def test_prefers_precomputed_context(self):
|
||||
state = {"company_of_interest": "TOTDY", "instrument_context": "PRECOMPUTED"}
|
||||
self.assertEqual(get_instrument_context_from_state(state), "PRECOMPUTED")
|
||||
|
||||
def test_fallback_is_network_free_ticker_only(self):
|
||||
# No instrument_context and no yfinance call — must not hit the network.
|
||||
with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
|
||||
context = get_instrument_context_from_state(
|
||||
{"company_of_interest": "NVDA", "asset_type": "stock"}
|
||||
)
|
||||
mock.assert_not_called()
|
||||
self.assertIn("NVDA", context)
|
||||
|
||||
def test_fallback_respects_asset_type(self):
|
||||
context = get_instrument_context_from_state(
|
||||
{"company_of_interest": "BTC-USD", "asset_type": "crypto"}
|
||||
)
|
||||
self.assertIn("crypto asset", context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class ContextAnchoredPlaceholderTests(unittest.TestCase):
|
||||
"""#888 — the message-clear placeholder must not be a bare 'Continue'."""
|
||||
|
||||
def _run(self, state_extra):
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="old", id="h1"),
|
||||
AIMessage(content="reply", id="a1"),
|
||||
],
|
||||
**state_extra,
|
||||
}
|
||||
return create_msg_delete()(state)
|
||||
|
||||
def test_placeholder_is_not_bare_continue(self):
|
||||
result = self._run(
|
||||
{"company_of_interest": "EC", "asset_type": "stock", "trade_date": "2026-05-28"}
|
||||
)
|
||||
placeholder = result["messages"][-1]
|
||||
self.assertIsInstance(placeholder, HumanMessage)
|
||||
self.assertNotEqual(placeholder.content.strip(), "Continue")
|
||||
|
||||
def test_placeholder_carries_resolved_identity(self):
|
||||
result = self._run(
|
||||
{
|
||||
"company_of_interest": "EC",
|
||||
"instrument_context": "The instrument to analyze is `EC`. Resolved identity: Company: Ecopetrol.",
|
||||
"trade_date": "2026-05-28",
|
||||
}
|
||||
)
|
||||
content = result["messages"][-1].content
|
||||
self.assertIn("Ecopetrol", content)
|
||||
self.assertIn("2026-05-28", content)
|
||||
|
||||
def test_old_messages_are_removed(self):
|
||||
result = self._run({"company_of_interest": "EC", "trade_date": "2026-05-28"})
|
||||
removals = [m for m in result["messages"] if isinstance(m, RemoveMessage)]
|
||||
humans = [m for m in result["messages"] if isinstance(m, HumanMessage)]
|
||||
self.assertEqual(len(removals), 2)
|
||||
self.assertEqual(len(humans), 1)
|
||||
|
||||
def test_safe_defaults_when_state_minimal(self):
|
||||
result = create_msg_delete()({"messages": [], "company_of_interest": "EC"})
|
||||
placeholder = result["messages"][-1]
|
||||
self.assertNotEqual(placeholder.content.strip(), "Continue")
|
||||
self.assertIn("EC", placeholder.content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user