mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
Agents had no ground-truth ticker→company mapping, so the market analyst could pattern-match a price chart to the wrong company (e.g. TOTDY read as "TotalEnergies"), and every downstream agent inherited the bad framing. Resolve identity once at run start via a cached, fail-open yfinance lookup and inject company/sector/exchange into the shared instrument context that all twelve agents consume, with an explicit do-not-substitute instruction. Resolution runs on both the propagate() and CLI entry points. Also replaces the bare "Continue" message-clear placeholder, which some OpenAI-compatible providers interpreted as the user task, with a context-anchored placeholder carrying the resolved identity and date. #814 #888
171 lines
6.7 KiB
Python
171 lines
6.7 KiB
Python
"""Tests for deterministic instrument-identity resolution (#814) and the
|
|
context-anchored message placeholder (#888)."""
|
|
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
|
|
|
from tradingagents.agents.utils.agent_utils import (
|
|
build_instrument_context,
|
|
create_msg_delete,
|
|
get_instrument_context_from_state,
|
|
resolve_instrument_identity,
|
|
)
|
|
|
|
|
|
@pytest.mark.unit
class ResolveInstrumentIdentityTests(unittest.TestCase):
    """Tests for ``resolve_instrument_identity`` (#814).

    Every test patches ``yf.Ticker`` inside ``agent_utils`` so no network
    request is made. The tests pin ticker normalisation, the long/short name
    fallback, placeholder filtering, fail-open behaviour and memoisation.
    """

    def setUp(self):
        # The resolver is memoised: start from an empty cache so an entry left
        # behind by an earlier test cannot hide the mocked yfinance lookup.
        resolve_instrument_identity.cache_clear()
        # Clear it again after each test so the fake identities created here
        # (e.g. "TOTDY" -> "TOTO LTD.") never leak into other tests that
        # share the same cache in this process.
        self.addCleanup(resolve_instrument_identity.cache_clear)

    def test_resolves_company_metadata_from_yfinance(self):
        """The ticker is upper-cased and yfinance metadata is mapped over."""
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            mock.return_value.info = {
                "longName": "TOTO LTD.",
                "shortName": "TOTO",
                "sector": "Industrials",
                "industry": "Building Products & Equipment",
                "exchange": "PNK",
                "quoteType": "EQUITY",
            }
            identity = resolve_instrument_identity("totdy")
        mock.assert_called_once_with("TOTDY")
        self.assertEqual(identity["company_name"], "TOTO LTD.")
        self.assertEqual(identity["sector"], "Industrials")
        self.assertEqual(identity["industry"], "Building Products & Equipment")
        self.assertEqual(identity["exchange"], "PNK")

    def test_falls_back_to_short_name(self):
        """``shortName`` is used as the company name when ``longName`` is absent."""
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            mock.return_value.info = {"shortName": "TOTO", "sector": "Industrials"}
            identity = resolve_instrument_identity("TOTDY")
        self.assertEqual(identity["company_name"], "TOTO")

    def test_skips_placeholder_values(self):
        """Blank and placeholder strings ("None", "n/a") yield no fields."""
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            mock.return_value.info = {"longName": " ", "sector": "None", "industry": "n/a"}
            identity = resolve_instrument_identity("TOTDY")
        self.assertEqual(identity, {})

    def test_fails_open_on_exception(self):
        """A yfinance error produces an empty identity instead of raising."""
        with patch(
            "tradingagents.agents.utils.agent_utils.yf.Ticker",
            side_effect=RuntimeError("rate limited"),
        ):
            self.assertEqual(resolve_instrument_identity("TOTDY"), {})

    def test_result_is_cached(self):
        """Repeated lookups for the same ticker hit yfinance only once."""
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            mock.return_value.info = {"longName": "TOTO LTD."}
            first = resolve_instrument_identity("TOTDY")
            second = resolve_instrument_identity("TOTDY")
        mock.assert_called_once()  # second call served from cache
        self.assertEqual(first, second)
|
|
|
|
|
|
@pytest.mark.unit
class BuildInstrumentContextTests(unittest.TestCase):
    """Tests for the prompt fragment produced by ``build_instrument_context``."""

    def test_mentions_exact_symbol_without_identity(self):
        """Without an identity, the raw symbol appears but no resolved-identity section."""
        text = build_instrument_context("7203.T")
        for needle in ("7203.T", "exchange suffix"):
            self.assertIn(needle, text)
        self.assertNotIn("Resolved identity", text)

    def test_injects_resolved_identity(self):
        """Resolved company, sector/industry and exchange are rendered, plus the guard."""
        toto = dict(
            company_name="TOTO LTD.",
            sector="Industrials",
            industry="Building Products & Equipment",
            exchange="PNK",
        )
        text = build_instrument_context("TOTDY", "stock", toto)
        expected_fragments = (
            "Company: TOTO LTD.",
            "Industrials / Building Products & Equipment",
            "Exchange: PNK",
            "Do not substitute a different company",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, text)

    def test_crypto_uses_name_label_and_keeps_hint(self):
        """Crypto assets get a "Name:" label and keep the not-a-company hint."""
        text = build_instrument_context(
            "BTC-USD", "crypto", dict(company_name="Bitcoin USD")
        )
        for fragment in ("Name: Bitcoin USD", "crypto asset rather than a company"):
            self.assertIn(fragment, text)
|
|
|
|
|
|
@pytest.mark.unit
class GetInstrumentContextFromStateTests(unittest.TestCase):
    """Tests for ``get_instrument_context_from_state``."""

    def setUp(self):
        # Start from an empty resolver cache. If the fallback ever routed
        # through the memoised resolver, a cached entry would satisfy it
        # without touching yf.Ticker, and the no-network assertions below
        # would pass for the wrong reason.
        resolve_instrument_identity.cache_clear()

    def test_prefers_precomputed_context(self):
        """A precomputed ``instrument_context`` in state is returned verbatim."""
        state = {"company_of_interest": "TOTDY", "instrument_context": "PRECOMPUTED"}
        self.assertEqual(get_instrument_context_from_state(state), "PRECOMPUTED")

    def test_fallback_is_network_free_ticker_only(self):
        """Without ``instrument_context``, the fallback names the ticker and never calls yfinance."""
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            context = get_instrument_context_from_state(
                {"company_of_interest": "NVDA", "asset_type": "stock"}
            )
        mock.assert_not_called()
        self.assertIn("NVDA", context)

    def test_fallback_respects_asset_type(self):
        """The fallback honours ``asset_type`` and also stays network-free."""
        # Patch yf.Ticker so a regression can never make a real network
        # request from the unit suite.
        with patch("tradingagents.agents.utils.agent_utils.yf.Ticker") as mock:
            context = get_instrument_context_from_state(
                {"company_of_interest": "BTC-USD", "asset_type": "crypto"}
            )
        mock.assert_not_called()
        self.assertIn("crypto asset", context)
|
|
|
|
|
|
@pytest.mark.unit
class ContextAnchoredPlaceholderTests(unittest.TestCase):
    """#888 — the message-clear placeholder must not be a bare 'Continue'."""

    def _run(self, state_extra):
        """Run the message-clear node on a fixed two-message history.

        ``state_extra`` is merged into the state, so each test controls the
        ticker, asset type, trade date and any precomputed instrument context.
        The seeded messages carry ids ``"h1"`` and ``"a1"``. Returns whatever
        the node returns (its state update).
        """
        state = {
            "messages": [
                HumanMessage(content="old", id="h1"),
                AIMessage(content="reply", id="a1"),
            ],
            **state_extra,
        }
        return create_msg_delete()(state)

    def test_placeholder_is_not_bare_continue(self):
        """The trailing placeholder is a HumanMessage other than "Continue"."""
        result = self._run(
            {"company_of_interest": "EC", "asset_type": "stock", "trade_date": "2026-05-28"}
        )
        placeholder = result["messages"][-1]
        self.assertIsInstance(placeholder, HumanMessage)
        self.assertNotEqual(placeholder.content.strip(), "Continue")

    def test_placeholder_carries_resolved_identity(self):
        """The placeholder embeds the precomputed identity and the trade date."""
        result = self._run(
            {
                "company_of_interest": "EC",
                "instrument_context": "The instrument to analyze is `EC`. Resolved identity: Company: Ecopetrol.",
                "trade_date": "2026-05-28",
            }
        )
        content = result["messages"][-1].content
        self.assertIn("Ecopetrol", content)
        self.assertIn("2026-05-28", content)

    def test_old_messages_are_removed(self):
        """Every prior message is removed and exactly one placeholder is added."""
        result = self._run({"company_of_interest": "EC", "trade_date": "2026-05-28"})
        removals = [m for m in result["messages"] if isinstance(m, RemoveMessage)]
        humans = [m for m in result["messages"] if isinstance(m, HumanMessage)]
        self.assertEqual(len(removals), 2)
        # A count alone would pass even if the wrong messages were targeted;
        # pin the removals to the ids seeded by _run.
        self.assertEqual({m.id for m in removals}, {"h1", "a1"})
        self.assertEqual(len(humans), 1)

    def test_safe_defaults_when_state_minimal(self):
        """With only a ticker in state, the placeholder still names it."""
        result = create_msg_delete()({"messages": [], "company_of_interest": "EC"})
        placeholder = result["messages"][-1]
        self.assertNotEqual(placeholder.content.strip(), "Continue")
        self.assertIn("EC", placeholder.content)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly with the stdlib runner, in addition
    # to collecting it through pytest.
    unittest.main()
|