"""Tests for deterministic instrument-identity resolution (#814) and the context-anchored message placeholder (#888).""" import unittest from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage from tradingagents.agents.utils.agent_utils import ( build_instrument_context, create_msg_delete, get_instrument_context_from_state, resolve_instrument_identity, ) @pytest.mark.unit class ResolveInstrumentIdentityTests(unittest.TestCase): def setUp(self): resolve_instrument_identity.cache_clear() def test_resolves_company_metadata_from_yfinance(self): with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock: mock.return_value.info = { "longName": "TOTO LTD.", "shortName": "TOTO", "sector": "Industrials", "industry": "Building Products & Equipment", "exchange": "PNK", "quoteType": "EQUITY", } identity = resolve_instrument_identity("totdy") mock.assert_called_once_with("TOTDY") self.assertEqual(identity["company_name"], "TOTO LTD.") self.assertEqual(identity["sector"], "Industrials") self.assertEqual(identity["industry"], "Building Products & Equipment") self.assertEqual(identity["exchange"], "PNK") def test_falls_back_to_short_name(self): with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock: mock.return_value.info = {"shortName": "TOTO", "sector": "Industrials"} identity = resolve_instrument_identity("TOTDY") self.assertEqual(identity["company_name"], "TOTO") def test_skips_placeholder_values(self): with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock: mock.return_value.info = {"longName": " ", "sector": "None", "industry": "n/a"} identity = resolve_instrument_identity("TOTDY") self.assertEqual(identity, {}) def test_fails_open_on_exception(self): with patch( "tradingagents.agents.utils.agent_utils.yf.Ticker", side_effect=RuntimeError("rate limited"), ): self.assertEqual(resolve_instrument_identity("TOTDY"), {}) def test_result_is_cached(self): with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock: mock.return_value.info = {"longName": "TOTO LTD."} first = resolve_instrument_identity("TOTDY") second = resolve_instrument_identity("TOTDY") mock.assert_called_once() # second call served from cache self.assertEqual(first, second) @pytest.mark.unit class BuildInstrumentContextTests(unittest.TestCase): def test_mentions_exact_symbol_without_identity(self): context = build_instrument_context("7203.T") self.assertIn("7203.T", context) self.assertIn("exchange suffix", context) self.assertNotIn("Resolved identity", context) def test_injects_resolved_identity(self): context = build_instrument_context( "TOTDY", "stock", { "company_name": "TOTO LTD.", "sector": "Industrials", "industry": "Building Products & Equipment", "exchange": "PNK", }, ) self.assertIn("Company: TOTO LTD.", context) self.assertIn("Industrials / Building Products & Equipment", context) self.assertIn("Exchange: PNK", context) self.assertIn("Do not substitute a different company", context) def test_crypto_uses_name_label_and_keeps_hint(self): context = build_instrument_context( "BTC-USD", "crypto", {"company_name": "Bitcoin USD"} ) self.assertIn("Name: Bitcoin USD", context) self.assertIn("crypto asset rather than a company", context) @pytest.mark.unit class GetInstrumentContextFromStateTests(unittest.TestCase): def test_prefers_precomputed_context(self): state = {"company_of_interest": "TOTDY", "instrument_context": "PRECOMPUTED"} self.assertEqual(get_instrument_context_from_state(state), "PRECOMPUTED") def test_fallback_is_network_free_ticker_only(self): # No instrument_context and no yfinance call — must not hit the network. with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock: context = get_instrument_context_from_state( {"company_of_interest": "NVDA", "asset_type": "stock"} ) mock.assert_not_called() self.assertIn("NVDA", context) def test_fallback_respects_asset_type(self): context = get_instrument_context_from_state( {"company_of_interest": "BTC-USD", "asset_type": "crypto"} ) self.assertIn("crypto asset", context) @pytest.mark.unit class ContextAnchoredPlaceholderTests(unittest.TestCase): """#888 — the message-clear placeholder must not be a bare 'Continue'.""" def _run(self, state_extra): state = { "messages": [ HumanMessage(content="old", id="h1"), AIMessage(content="reply", id="a1"), ], **state_extra, } return create_msg_delete()(state) def test_placeholder_is_not_bare_continue(self): result = self._run( {"company_of_interest": "EC", "asset_type": "stock", "trade_date": "2026-05-28"} ) placeholder = result["messages"][-1] self.assertIsInstance(placeholder, HumanMessage) self.assertNotEqual(placeholder.content.strip(), "Continue") def test_placeholder_carries_resolved_identity(self): result = self._run( { "company_of_interest": "EC", "instrument_context": "The instrument to analyze is `EC`. Resolved identity: Company: Ecopetrol.", "trade_date": "2026-05-28", } ) content = result["messages"][-1].content self.assertIn("Ecopetrol", content) self.assertIn("2026-05-28", content) def test_old_messages_are_removed(self): result = self._run({"company_of_interest": "EC", "trade_date": "2026-05-28"}) removals = [m for m in result["messages"] if isinstance(m, RemoveMessage)] humans = [m for m in result["messages"] if isinstance(m, HumanMessage)] self.assertEqual(len(removals), 2) self.assertEqual(len(humans), 1) def test_safe_defaults_when_state_minimal(self): result = create_msg_delete()({"messages": [], "company_of_interest": "EC"}) placeholder = result["messages"][-1] self.assertNotEqual(placeholder.content.strip(), "Continue") self.assertIn("EC", placeholder.content) if __name__ == "__main__": unittest.main()