diff --git a/tests/test_memory_log.py b/tests/test_memory_log.py
index 5d7f7f844..c468f9998 100644
--- a/tests/test_memory_log.py
+++ b/tests/test_memory_log.py
@@ -535,6 +535,93 @@ class TestDeferredReflection:
         assert raw is not None and alpha is not None and days is not None
         assert days == 2
 
+    # TradingAgentsGraph._resolve_benchmark — picks index for alpha calc
+
+    def test_resolve_benchmark_explicit_override(self):
+        """config['benchmark_ticker'] wins for every ticker."""
+        mock_graph = MagicMock(spec=TradingAgentsGraph)
+        mock_graph.config = {
+            "benchmark_ticker": "QQQ",
+            "benchmark_map": {"": "SPY", ".T": "^N225"},
+        }
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "QQQ"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "QQQ"
+
+    def test_resolve_benchmark_suffix_map(self):
+        """Known suffixes route to their regional index."""
+        mock_graph = MagicMock(spec=TradingAgentsGraph)
+        mock_graph.config = {
+            "benchmark_ticker": None,
+            "benchmark_map": {
+                ".T": "^N225", ".HK": "^HSI", ".NS": "^NSEI",
+                ".L": "^FTSE", ".TO": "^GSPTSE", ".AX": "^AXJO",
+                ".BO": "^BSESN", "": "SPY",
+            },
+        }
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.T") == "^N225"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "0700.HK") == "^HSI"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "RELIANCE.NS") == "^NSEI"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AZN.L") == "^FTSE"
+
+    def test_resolve_benchmark_us_ticker_defaults_to_spy(self):
+        """US tickers (no dotted suffix) take the empty-suffix entry."""
+        mock_graph = MagicMock(spec=TradingAgentsGraph)
+        mock_graph.config = {
+            "benchmark_ticker": None,
+            "benchmark_map": {"": "SPY", ".T": "^N225"},
+        }
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "NVDA") == "SPY"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "AAPL") == "SPY"
+
+    def test_resolve_benchmark_unknown_suffix_falls_back(self):
+        """Unrecognised suffix (BRK.B, FAKE.XX) falls back to SPY."""
+        mock_graph = MagicMock(spec=TradingAgentsGraph)
+        mock_graph.config = {
+            "benchmark_ticker": None,
+            "benchmark_map": {"": "SPY", ".T": "^N225"},
+        }
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "FAKE.XX") == "SPY"
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "BRK.B") == "SPY"
+
+    def test_resolve_benchmark_case_insensitive(self):
+        """Suffix matching is case-insensitive so 7203.t resolves like 7203.T."""
+        mock_graph = MagicMock(spec=TradingAgentsGraph)
+        mock_graph.config = {
+            "benchmark_ticker": None,
+            "benchmark_map": {".T": "^N225", "": "SPY"},
+        }
+        assert TradingAgentsGraph._resolve_benchmark(mock_graph, "7203.t") == "^N225"
+
+    def test_reflector_includes_benchmark_in_label(self):
+        """benchmark_name appears in the prompt label, not 'SPY' hardcoded."""
+        mock_llm = MagicMock()
+        mock_llm.invoke.return_value.content = "Directionally correct."
+        reflector = Reflector(mock_llm)
+        reflector.reflect_on_final_decision(
+            final_decision=DECISION_BUY,
+            raw_return=0.05,
+            alpha_return=0.02,
+            benchmark_name="^N225",
+        )
+        messages = mock_llm.invoke.call_args[0][0]
+        human_content = next(content for role, content in messages if role == "human")
+        assert "Alpha vs ^N225:" in human_content
+        assert "Alpha vs SPY:" not in human_content
+
+    def test_reflector_defaults_to_spy_for_unupdated_callers(self):
+        """Default benchmark_name keeps the SPY label for legacy callers."""
+        mock_llm = MagicMock()
+        mock_llm.invoke.return_value.content = "ok"
+        reflector = Reflector(mock_llm)
+        reflector.reflect_on_final_decision(
+            final_decision=DECISION_BUY,
+            raw_return=0.05,
+            alpha_return=0.02,
+        )
+        messages = mock_llm.invoke.call_args[0][0]
+        human_content = next(content for role, content in messages if role == "human")
+        assert "Alpha vs SPY:" in human_content
+
     # TradingAgentsGraph._resolve_pending_entries
 
     def test_resolve_skips_other_tickers(self, tmp_path):
diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py
index fe5a6f755..91faa1b02 100644
--- a/tradingagents/default_config.py
+++ b/tradingagents/default_config.py
@@ -16,6 +16,7 @@ _ENV_OVERRIDES = {
     "TRADINGAGENTS_MAX_DEBATE_ROUNDS": "max_debate_rounds",
     "TRADINGAGENTS_MAX_RISK_ROUNDS": "max_risk_discuss_rounds",
     "TRADINGAGENTS_CHECKPOINT_ENABLED": "checkpoint_enabled",
+    "TRADINGAGENTS_BENCHMARK_TICKER": "benchmark_ticker",
 }
 
 
@@ -100,4 +101,21 @@ DEFAULT_CONFIG = _apply_env_overrides({
     "tool_vendors": {
         # Example: "get_stock_data": "alpha_vantage",  # Override category default
     },
+    # Benchmark for alpha calculation in the reflection layer.
+    # ``benchmark_ticker`` (when set) overrides the suffix map for all
+    # tickers; leave it None to use ``benchmark_map`` for auto-detection
+    # based on the ticker's exchange suffix. SPY remains the US default
+    # so the reflection label keeps reading "Alpha vs SPY" for US tickers
+    # while non-US tickers get their regional index automatically.
+    "benchmark_ticker": None,
+    "benchmark_map": {
+        ".NS": "^NSEI",  # NSE India (Nifty 50)
+        ".BO": "^BSESN",  # BSE India (Sensex)
+        ".T": "^N225",  # Tokyo (Nikkei 225)
+        ".HK": "^HSI",  # Hong Kong (Hang Seng)
+        ".L": "^FTSE",  # London (FTSE 100)
+        ".TO": "^GSPTSE",  # Toronto (TSX Composite)
+        ".AX": "^AXJO",  # Australia (ASX 200)
+        "": "SPY",  # default for US-listed tickers (no suffix)
+    },
 })
diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py
index 813114428..0685941fe 100644
--- a/tradingagents/graph/reflection.py
+++ b/tradingagents/graph/reflection.py
@@ -33,11 +33,15 @@ class Reflector:
         final_decision: str,
         raw_return: float,
         alpha_return: float,
+        benchmark_name: str = "SPY",
     ) -> str:
         """Single reflection call on the final trade decision with outcome context.
 
         Used by Phase B deferred reflection. The final_trade_decision already
         synthesises all analyst insights, so no separate market context is needed.
+        ``benchmark_name`` is the label used for the alpha line (e.g. ``"SPY"``
+        for US tickers, ``"^N225"`` for ``.T`` listings); defaults to SPY for
+        callers that haven't been updated to thread the benchmark through.
         """
         messages = [
             ("system", self.log_reflection_prompt),
@@ -45,7 +49,7 @@ class Reflector:
                 "human",
                 (
                     f"Raw return: {raw_return:+.1%}\n"
-                    f"Alpha vs SPY: {alpha_return:+.1%}\n\n"
+                    f"Alpha vs {benchmark_name}: {alpha_return:+.1%}\n\n"
                     f"Final Decision:\n{final_decision}"
                 ),
             ),
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 949dbf654..c0d8ecdd9 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -190,14 +190,37 @@ class TradingAgentsGraph:
             ),
         }
 
+    def _resolve_benchmark(self, ticker: str) -> str:
+        """Pick the benchmark ticker for alpha calculation against ``ticker``.
+
+        ``config["benchmark_ticker"]`` overrides everything when set; otherwise
+        the suffix map is matched case-insensitively against the end of the
+        ticker (e.g. ``.T`` for Tokyo), longest suffix first so the result
+        never depends on the map's insertion order. Tickers without a dotted
+        suffix, and unrecognised suffixes (including US tickers with dots like
+        ``BRK.B``), fall back to the empty-suffix entry, or to SPY when the
+        map has no such entry.
+        """
+        explicit = self.config.get("benchmark_ticker")
+        if explicit:
+            return explicit
+        benchmark_map = self.config.get("benchmark_map", {})
+        ticker_upper = ticker.upper()
+        for suffix in sorted(benchmark_map, key=len, reverse=True):
+            if suffix and ticker_upper.endswith(suffix.upper()):
+                return benchmark_map[suffix]
+        return benchmark_map.get("", "SPY")
+
     def _fetch_returns(
-        self, ticker: str, trade_date: str, holding_days: int = 5
+        self, ticker: str, trade_date: str, holding_days: int = 5,
+        benchmark: str = "SPY",
     ) -> Tuple[Optional[float], Optional[float], Optional[int]]:
         """Fetch raw and alpha return for ticker over holding_days from trade_date.
 
-        Returns (raw_return, alpha_return, actual_holding_days) or
-        (None, None, None) if price data is unavailable (too recent, delisted,
-        or network error).
+        ``benchmark`` is the index used as the alpha baseline (resolved by the
+        caller via ``_resolve_benchmark``). Returns ``(raw_return, alpha_return,
+        actual_holding_days)`` or ``(None, None, None)`` if price data is
+        unavailable (too recent, delisted, or network error).
         """
         try:
             start = datetime.strptime(trade_date, "%Y-%m-%d")
@@ -205,26 +228,26 @@ class TradingAgentsGraph:
             end_str = end.strftime("%Y-%m-%d")
 
             stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
-            spy = yf.Ticker("SPY").history(start=trade_date, end=end_str)
+            bench = yf.Ticker(benchmark).history(start=trade_date, end=end_str)
 
-            if len(stock) < 2 or len(spy) < 2:
+            if len(stock) < 2 or len(bench) < 2:
                 return None, None, None
 
-            actual_days = min(holding_days, len(stock) - 1, len(spy) - 1)
+            actual_days = min(holding_days, len(stock) - 1, len(bench) - 1)
             raw = float(
                 (stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
                 / stock["Close"].iloc[0]
             )
-            spy_ret = float(
-                (spy["Close"].iloc[actual_days] - spy["Close"].iloc[0])
-                / spy["Close"].iloc[0]
+            bench_ret = float(
+                (bench["Close"].iloc[actual_days] - bench["Close"].iloc[0])
+                / bench["Close"].iloc[0]
             )
-            alpha = raw - spy_ret
+            alpha = raw - bench_ret
             return raw, alpha, actual_days
         except Exception as e:
             logger.warning(
-                "Could not resolve outcome for %s on %s (will retry next run): %s",
-                ticker, trade_date, e,
+                "Could not resolve outcome for %s on %s vs %s (will retry next run): %s",
+                ticker, trade_date, benchmark, e,
             )
             return None, None, None
 
@@ -242,15 +265,19 @@ class TradingAgentsGraph:
         if not pending:
             return
 
+        benchmark = self._resolve_benchmark(ticker)
         updates = []
         for entry in pending:
-            raw, alpha, days = self._fetch_returns(ticker, entry["date"])
+            raw, alpha, days = self._fetch_returns(
+                ticker, entry["date"], benchmark=benchmark,
+            )
             if raw is None:
                 continue  # price not available yet — try again next run
             reflection = self.reflector.reflect_on_final_decision(
                 final_decision=entry.get("decision", ""),
                 raw_return=raw,
                 alpha_return=alpha,
+                benchmark_name=benchmark,
             )
             updates.append({
                 "ticker": ticker,