mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
Merge remote-tracking branch 'upstream/main' into analyst-phase1-observability
# Conflicts: # tradingagents/default_config.py # tradingagents/graph/setup.py
This commit is contained in:
46
tests/conftest.py
Normal file
46
tests/conftest.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Shared pytest fixtures that prevent CI hangs when API keys are absent."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for marker in ("unit", "integration", "smoke"):
|
||||
config.addinivalue_line("markers", f"{marker}: {marker}-level tests")
|
||||
|
||||
|
||||
_API_KEY_ENV_VARS = (
|
||||
"OPENAI_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"DASHSCOPE_API_KEY",
|
||||
"DASHSCOPE_CN_API_KEY",
|
||||
"ZHIPU_API_KEY",
|
||||
"ZHIPU_CN_API_KEY",
|
||||
"MINIMAX_API_KEY",
|
||||
"MINIMAX_CN_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"ALPHA_VANTAGE_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _dummy_api_keys(monkeypatch):
|
||||
for env_var in _API_KEY_ENV_VARS:
|
||||
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_llm_client():
|
||||
client = MagicMock()
|
||||
client.get_llm.return_value = MagicMock()
|
||||
with patch(
|
||||
"tradingagents.llm_clients.factory.create_llm_client",
|
||||
return_value=client,
|
||||
):
|
||||
yield client
|
||||
149
tests/test_api_key_env.py
Normal file
149
tests/test_api_key_env.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||
|
||||
|
||||
# ---- Mapping coverage -----------------------------------------------------
|
||||
|
||||
|
||||
def test_every_select_llm_provider_choice_has_an_entry():
|
||||
"""select_llm_provider() must not present a provider the mapping doesn't know about."""
|
||||
# Mirrors the dropdown order in cli/utils.select_llm_provider so the two
|
||||
# stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn)
|
||||
# are reached via the secondary region prompt, so they must also be present.
|
||||
expected = {
|
||||
"openai", "google", "anthropic", "xai", "deepseek",
|
||||
"qwen", "qwen-cn",
|
||||
"glm", "glm-cn",
|
||||
"minimax", "minimax-cn",
|
||||
"openrouter", "azure", "ollama",
|
||||
}
|
||||
assert expected.issubset(PROVIDER_API_KEY_ENV.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider,env_var",
|
||||
[
|
||||
("openai", "OPENAI_API_KEY"),
|
||||
("anthropic", "ANTHROPIC_API_KEY"),
|
||||
("google", "GOOGLE_API_KEY"),
|
||||
("azure", "AZURE_OPENAI_API_KEY"),
|
||||
("xai", "XAI_API_KEY"),
|
||||
("deepseek", "DEEPSEEK_API_KEY"),
|
||||
("qwen", "DASHSCOPE_API_KEY"),
|
||||
("qwen-cn", "DASHSCOPE_CN_API_KEY"),
|
||||
("glm", "ZHIPU_API_KEY"),
|
||||
("glm-cn", "ZHIPU_CN_API_KEY"),
|
||||
("minimax", "MINIMAX_API_KEY"),
|
||||
("minimax-cn", "MINIMAX_CN_API_KEY"),
|
||||
("openrouter", "OPENROUTER_API_KEY"),
|
||||
],
|
||||
)
|
||||
def test_known_providers_resolve(provider, env_var):
|
||||
assert get_api_key_env(provider) == env_var
|
||||
|
||||
|
||||
def test_ollama_has_no_key():
|
||||
assert get_api_key_env("ollama") is None
|
||||
|
||||
|
||||
def test_unknown_provider_returns_none():
|
||||
assert get_api_key_env("not-a-real-provider") is None
|
||||
|
||||
|
||||
def test_case_insensitive_lookup():
|
||||
assert get_api_key_env("OpenAI") == "OPENAI_API_KEY"
|
||||
assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY"
|
||||
|
||||
|
||||
# ---- ensure_api_key behavior ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_utils(monkeypatch):
|
||||
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||
import importlib
|
||||
import cli.utils as cli_utils_module
|
||||
return importlib.reload(cli_utils_module)
|
||||
|
||||
|
||||
def test_ensure_api_key_returns_existing(monkeypatch, cli_utils):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set")
|
||||
result = cli_utils.ensure_api_key("openai")
|
||||
assert result == "sk-already-set"
|
||||
|
||||
|
||||
def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils):
|
||||
# Even with no env var set, ollama should not prompt and should return None.
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with patch.object(cli_utils, "questionary") as mock_q:
|
||||
result = cli_utils.ensure_api_key("ollama")
|
||||
assert result is None
|
||||
mock_q.password.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils):
|
||||
with patch.object(cli_utils, "questionary") as mock_q:
|
||||
result = cli_utils.ensure_api_key("totally-fake-provider")
|
||||
assert result is None
|
||||
mock_q.password.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils):
|
||||
"""When key is missing, user-pasted value must be written to .env AND os.environ."""
|
||||
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})()
|
||||
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||
result = cli_utils.ensure_api_key("deepseek")
|
||||
|
||||
assert result == "sk-deepseek-test"
|
||||
assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test"
|
||||
env_file = tmp_path / ".env"
|
||||
assert env_file.exists()
|
||||
assert "DEEPSEEK_API_KEY" in env_file.read_text()
|
||||
assert "sk-deepseek-test" in env_file.read_text()
|
||||
|
||||
|
||||
def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils):
|
||||
"""Empty prompt response (user cancelled) must not write to .env."""
|
||||
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})()
|
||||
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||
result = cli_utils.ensure_api_key("xai")
|
||||
|
||||
assert result is None
|
||||
assert "XAI_API_KEY" not in os.environ
|
||||
# .env may or may not exist depending on find_dotenv's walk, but if it
|
||||
# does it must not contain the key.
|
||||
env_file = tmp_path / ".env"
|
||||
if env_file.exists():
|
||||
assert "XAI_API_KEY" not in env_file.read_text()
|
||||
|
||||
|
||||
def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils):
|
||||
"""An existing .env with other keys must be preserved on writeback."""
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n")
|
||||
|
||||
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})()
|
||||
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||
cli_utils.ensure_api_key("openrouter")
|
||||
|
||||
content = env_file.read_text()
|
||||
assert "OPENAI_API_KEY" in content and "sk-existing" in content
|
||||
assert "OTHER=value" in content
|
||||
assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content
|
||||
107
tests/test_capabilities.py
Normal file
107
tests/test_capabilities.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Unit tests for the LLM capability table."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.capabilities import (
|
||||
ModelCapabilities,
|
||||
get_capabilities,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExactIdMatches:
|
||||
def test_deepseek_chat_supports_tool_choice(self):
|
||||
caps = get_capabilities("deepseek-chat")
|
||||
assert caps.supports_tool_choice is True
|
||||
|
||||
def test_deepseek_reasoner_rejects_tool_choice(self):
|
||||
caps = get_capabilities("deepseek-reasoner")
|
||||
assert caps.supports_tool_choice is False
|
||||
assert caps.requires_reasoning_content_roundtrip is True
|
||||
|
||||
def test_deepseek_v4_flash_rejects_tool_choice(self):
|
||||
caps = get_capabilities("deepseek-v4-flash")
|
||||
assert caps.supports_tool_choice is False
|
||||
assert caps.requires_reasoning_content_roundtrip is True
|
||||
|
||||
def test_deepseek_v4_pro_rejects_tool_choice(self):
|
||||
caps = get_capabilities("deepseek-v4-pro")
|
||||
assert caps.supports_tool_choice is False
|
||||
assert caps.requires_reasoning_content_roundtrip is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPatternMatches:
|
||||
"""Forward-compat regex patterns catch unknown DeepSeek and MiniMax variants."""
|
||||
|
||||
def test_future_deepseek_v5_inherits_thinking_quirks(self):
|
||||
caps = get_capabilities("deepseek-v5-flash")
|
||||
assert caps.supports_tool_choice is False
|
||||
assert caps.requires_reasoning_content_roundtrip is True
|
||||
|
||||
def test_future_deepseek_v9_inherits_thinking_quirks(self):
|
||||
caps = get_capabilities("deepseek-v9-anything")
|
||||
assert caps.supports_tool_choice is False
|
||||
|
||||
def test_reasoner_variant_inherits_thinking_quirks(self):
|
||||
caps = get_capabilities("deepseek-reasoner-pro")
|
||||
assert caps.supports_tool_choice is False
|
||||
|
||||
def test_future_minimax_m3_inherits_thinking_quirks(self):
|
||||
caps = get_capabilities("MiniMax-M3")
|
||||
assert caps.supports_tool_choice is False
|
||||
|
||||
def test_future_minimax_m4_highspeed_inherits_thinking_quirks(self):
|
||||
caps = get_capabilities("MiniMax-M4-highspeed")
|
||||
assert caps.supports_tool_choice is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMinimaxExactMatches:
|
||||
"""MiniMax M2.x models reject langchain's function-spec dict tool_choice
|
||||
(official API enum: none/auto only)."""
|
||||
|
||||
def test_m2_7_rejects_tool_choice(self):
|
||||
caps = get_capabilities("MiniMax-M2.7")
|
||||
assert caps.supports_tool_choice is False
|
||||
assert caps.supports_json_mode is False # only MiniMax-Text-01 supports json_object
|
||||
|
||||
def test_m2_7_highspeed_rejects_tool_choice(self):
|
||||
assert get_capabilities("MiniMax-M2.7-highspeed").supports_tool_choice is False
|
||||
|
||||
def test_m2_1_rejects_tool_choice(self):
|
||||
assert get_capabilities("MiniMax-M2.1").supports_tool_choice is False
|
||||
|
||||
def test_m2_base_rejects_tool_choice(self):
|
||||
assert get_capabilities("MiniMax-M2").supports_tool_choice is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefault:
|
||||
"""Unknown / non-DeepSeek models get the permissive default."""
|
||||
|
||||
def test_gpt_default(self):
|
||||
caps = get_capabilities("gpt-4.1")
|
||||
assert caps.supports_tool_choice is True
|
||||
assert caps.preferred_structured_method == "function_calling"
|
||||
|
||||
def test_grok_default(self):
|
||||
caps = get_capabilities("grok-4-0709")
|
||||
assert caps.supports_tool_choice is True
|
||||
|
||||
def test_unknown_model_default(self):
|
||||
caps = get_capabilities("totally-made-up-model-id")
|
||||
assert caps.supports_tool_choice is True
|
||||
|
||||
def test_exact_match_precedes_pattern(self):
|
||||
"""deepseek-chat must NOT match the v\\d regex."""
|
||||
caps = get_capabilities("deepseek-chat")
|
||||
assert caps.supports_tool_choice is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_capabilities_dataclass_is_frozen():
|
||||
"""Capability rows are immutable so they can be safely shared."""
|
||||
caps = get_capabilities("deepseek-chat")
|
||||
with pytest.raises(Exception):
|
||||
caps.supports_tool_choice = False # type: ignore[misc]
|
||||
147
tests/test_checkpoint_resume.py
Normal file
147
tests/test_checkpoint_resume.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from tradingagents.graph.checkpointer import (
|
||||
checkpoint_step,
|
||||
clear_checkpoint,
|
||||
get_checkpointer,
|
||||
has_checkpoint,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
# Mutable flag to simulate crash on first run
|
||||
_should_crash = False
|
||||
|
||||
|
||||
class _SimpleState(TypedDict):
|
||||
count: int
|
||||
|
||||
|
||||
def _node_a(state: _SimpleState) -> dict:
|
||||
return {"count": state["count"] + 1}
|
||||
|
||||
|
||||
def _node_b(state: _SimpleState) -> dict:
|
||||
if _should_crash:
|
||||
raise RuntimeError("simulated mid-analysis crash")
|
||||
return {"count": state["count"] + 10}
|
||||
|
||||
|
||||
def _build_graph() -> StateGraph:
|
||||
builder = StateGraph(_SimpleState)
|
||||
builder.add_node("analyst", _node_a)
|
||||
builder.add_node("trader", _node_b)
|
||||
builder.set_entry_point("analyst")
|
||||
builder.add_edge("analyst", "trader")
|
||||
builder.add_edge("trader", END)
|
||||
return builder
|
||||
|
||||
|
||||
class TestCheckpointResume(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.ticker = "TEST"
|
||||
self.date = "2026-04-20"
|
||||
|
||||
def test_crash_and_resume(self):
|
||||
"""Crash at 'trader' node, then resume from checkpoint."""
|
||||
global _should_crash
|
||||
builder = _build_graph()
|
||||
tid = thread_id(self.ticker, self.date)
|
||||
cfg = {"configurable": {"thread_id": tid}}
|
||||
|
||||
# Run 1: crash at trader node
|
||||
_should_crash = True
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
# Checkpoint should exist at step 1 (analyst completed)
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
|
||||
self.assertEqual(step, 1)
|
||||
|
||||
# Run 2: resume — trader succeeds this time
|
||||
_should_crash = False
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
result = graph.invoke(None, config=cfg)
|
||||
|
||||
# analyst added 1, trader added 10 → 11
|
||||
self.assertEqual(result["count"], 11)
|
||||
|
||||
def test_clear_checkpoint_allows_fresh_start(self):
|
||||
"""After clearing, the graph starts from scratch."""
|
||||
global _should_crash
|
||||
builder = _build_graph()
|
||||
tid = thread_id(self.ticker, self.date)
|
||||
cfg = {"configurable": {"thread_id": tid}}
|
||||
|
||||
# Create a checkpoint by crashing
|
||||
_should_crash = True
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
# Clear it
|
||||
clear_checkpoint(self.tmpdir, self.ticker, self.date)
|
||||
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
# Fresh run succeeds from scratch
|
||||
_should_crash = False
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
result = graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
self.assertEqual(result["count"], 11)
|
||||
|
||||
|
||||
def test_different_date_starts_fresh(self):
|
||||
"""A different date must NOT resume from an existing checkpoint."""
|
||||
global _should_crash
|
||||
builder = _build_graph()
|
||||
date2 = "2026-04-21"
|
||||
|
||||
# Run with date1 — crash to leave a checkpoint
|
||||
_should_crash = True
|
||||
tid1 = thread_id(self.ticker, self.date)
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
|
||||
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
# date2 should have no checkpoint
|
||||
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
|
||||
|
||||
# Run with date2 — should start fresh and succeed
|
||||
_should_crash = False
|
||||
tid2 = thread_id(self.ticker, date2)
|
||||
self.assertNotEqual(tid1, tid2)
|
||||
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
|
||||
|
||||
# Fresh run: analyst +1, trader +10 = 11
|
||||
self.assertEqual(result["count"], 11)
|
||||
|
||||
# Original date checkpoint still exists (untouched)
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
61
tests/test_dataflows_config.py
Normal file
61
tests/test_dataflows_config.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Config isolation: get/set must not leak nested-dict references."""
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
import tradingagents.default_config as default_config
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class DataflowsConfigIsolationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
set_config(copy.deepcopy(default_config.DEFAULT_CONFIG))
|
||||
|
||||
def test_get_config_returns_deep_copy(self):
|
||||
cfg = get_config()
|
||||
cfg["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||
cfg["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||
|
||||
fresh = get_config()
|
||||
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "yfinance")
|
||||
self.assertNotIn("get_stock_data", fresh["tool_vendors"])
|
||||
|
||||
def test_set_config_does_not_alias_caller_nested_dicts(self):
|
||||
custom = copy.deepcopy(default_config.DEFAULT_CONFIG)
|
||||
custom["data_vendors"]["core_stock_apis"] = "alpha_vantage"
|
||||
custom["tool_vendors"]["get_stock_data"] = "alpha_vantage"
|
||||
|
||||
set_config(custom)
|
||||
|
||||
custom["data_vendors"]["core_stock_apis"] = "yfinance"
|
||||
custom["tool_vendors"]["get_stock_data"] = "yfinance"
|
||||
|
||||
fresh = get_config()
|
||||
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||
|
||||
def test_partial_nested_update_preserves_existing_defaults(self):
|
||||
set_config(
|
||||
{
|
||||
"data_vendors": {
|
||||
"core_stock_apis": "alpha_vantage",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
fresh = get_config()
|
||||
self.assertEqual(fresh["data_vendors"]["core_stock_apis"], "alpha_vantage")
|
||||
self.assertEqual(fresh["data_vendors"]["technical_indicators"], "yfinance")
|
||||
self.assertEqual(fresh["data_vendors"]["fundamental_data"], "yfinance")
|
||||
self.assertEqual(fresh["data_vendors"]["news_data"], "yfinance")
|
||||
|
||||
def test_nested_dict_updates_merge_one_level_deep(self):
|
||||
set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}})
|
||||
set_config({"tool_vendors": {"get_news": "alpha_vantage"}})
|
||||
|
||||
fresh = get_config()
|
||||
self.assertEqual(fresh["tool_vendors"]["get_stock_data"], "alpha_vantage")
|
||||
self.assertEqual(fresh["tool_vendors"]["get_news"], "alpha_vantage")
|
||||
240
tests/test_deepseek_reasoning.py
Normal file
240
tests/test_deepseek_reasoning.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
|
||||
|
||||
Two pieces verified:
|
||||
|
||||
1. ``reasoning_content`` is captured on receive into the AIMessage's
|
||||
``additional_kwargs`` and re-attached on send so DeepSeek's API
|
||||
sees the same value across turns.
|
||||
2. ``with_structured_output`` consults the capability table and
|
||||
suppresses ``tool_choice`` for models that reject it (V4 + reasoner),
|
||||
matching DeepSeek's official tool-calling pattern at
|
||||
https://api-docs.deepseek.com/guides/tool_calls.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.prompt_values import ChatPromptValue
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tradingagents.llm_clients.openai_client import (
|
||||
DeepSeekChatOpenAI,
|
||||
NormalizedChatOpenAI,
|
||||
_input_to_messages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _input_to_messages — the helper that handles list / ChatPromptValue / other
|
||||
# (Gemini bot review note: non-list inputs must also work)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInputToMessages:
|
||||
def test_list_input_returned_as_is(self):
|
||||
msgs = [HumanMessage(content="hi")]
|
||||
assert _input_to_messages(msgs) is msgs
|
||||
|
||||
def test_chat_prompt_value_unwrapped(self):
|
||||
msgs = [HumanMessage(content="hi")]
|
||||
prompt_value = ChatPromptValue(messages=msgs)
|
||||
assert _input_to_messages(prompt_value) == msgs
|
||||
|
||||
def test_string_input_yields_empty_list(self):
|
||||
# A bare string isn't a message-bearing input; the caller's normal
|
||||
# langchain conversion happens upstream of _get_request_payload.
|
||||
assert _input_to_messages("hello") == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reasoning content propagation across turns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeepSeekReasoningContent:
|
||||
def _client(self):
|
||||
os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
|
||||
return DeepSeekChatOpenAI(
|
||||
model="deepseek-v4-flash",
|
||||
api_key="placeholder",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
|
||||
def test_capture_on_receive(self):
|
||||
"""When the response carries reasoning_content, it lands on the
|
||||
AIMessage's additional_kwargs so the next turn can echo it back."""
|
||||
client = self._client()
|
||||
result = client._create_chat_result(
|
||||
{
|
||||
"model": "deepseek-v4-flash",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Plan: buy NVDA.",
|
||||
"reasoning_content": "Step 1: trend is up. Step 2: ...",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
)
|
||||
ai = result.generations[0].message
|
||||
assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."
|
||||
|
||||
def test_propagate_on_send(self):
|
||||
"""When an outgoing AIMessage carries reasoning_content, the request
|
||||
payload echoes it on the corresponding message dict."""
|
||||
client = self._client()
|
||||
prior = AIMessage(
|
||||
content="Plan",
|
||||
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||
)
|
||||
new_user = HumanMessage(content="Refine.")
|
||||
payload = client._get_request_payload([prior, new_user])
|
||||
# Find the assistant message in the payload
|
||||
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||
assert assistant_dicts, "assistant message missing from outgoing payload"
|
||||
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||
|
||||
def test_propagate_through_chat_prompt_value(self):
|
||||
"""Gemini bot review note: non-list inputs (ChatPromptValue) must
|
||||
also propagate reasoning_content."""
|
||||
client = self._client()
|
||||
prior = AIMessage(
|
||||
content="Plan",
|
||||
additional_kwargs={"reasoning_content": "weighed bull case"},
|
||||
)
|
||||
prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
|
||||
payload = client._get_request_payload(prompt_value)
|
||||
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
|
||||
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capability-driven structured output: tool_choice suppressed for V4 + reasoner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _bound_kwargs(runnable):
|
||||
"""Extract bind() kwargs from a with_structured_output result."""
|
||||
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
|
||||
return getattr(first, "kwargs", {})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStructuredOutputCapabilityDispatch:
|
||||
"""DeepSeek V4 and reasoner reject the tool_choice parameter
|
||||
(official guide: api-docs.deepseek.com/guides/tool_calls passes
|
||||
tools=[...] without tool_choice). Verify the capability dispatch
|
||||
suppresses tool_choice for those models and sends it for chat."""
|
||||
|
||||
class _Sample(BaseModel):
|
||||
answer: str
|
||||
|
||||
def _client(self, model):
|
||||
return DeepSeekChatOpenAI(
|
||||
model=model, api_key="placeholder", base_url="https://api.deepseek.com",
|
||||
)
|
||||
|
||||
def test_chat_sends_tool_choice(self):
|
||||
bound = self._client("deepseek-chat").with_structured_output(self._Sample)
|
||||
assert _bound_kwargs(bound).get("tool_choice") is not None
|
||||
|
||||
def test_reasoner_suppresses_tool_choice(self):
|
||||
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
|
||||
# tool_choice is either absent or explicitly None — both are valid
|
||||
# signals that langchain's bind_tools will skip the parameter.
|
||||
assert _bound_kwargs(bound).get("tool_choice") in (None, ...) or \
|
||||
"tool_choice" not in _bound_kwargs(bound)
|
||||
|
||||
def test_v4_flash_suppresses_tool_choice(self):
|
||||
bound = self._client("deepseek-v4-flash").with_structured_output(self._Sample)
|
||||
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||
"tool_choice" not in _bound_kwargs(bound)
|
||||
|
||||
def test_v4_pro_suppresses_tool_choice(self):
|
||||
bound = self._client("deepseek-v4-pro").with_structured_output(self._Sample)
|
||||
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||
"tool_choice" not in _bound_kwargs(bound)
|
||||
|
||||
def test_future_v_variant_via_regex(self):
|
||||
"""Forward-compat: unknown deepseek-v\\d-* IDs inherit V4 quirks."""
|
||||
bound = self._client("deepseek-v5-hypothetical").with_structured_output(self._Sample)
|
||||
assert _bound_kwargs(bound).get("tool_choice") is None or \
|
||||
"tool_choice" not in _bound_kwargs(bound)
|
||||
|
||||
def test_schema_is_still_bound_as_tool(self):
|
||||
"""tool_choice is suppressed, but the schema is still bound as a tool —
|
||||
exactly matching DeepSeek's official tool-calling examples."""
|
||||
bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
|
||||
kwargs = _bound_kwargs(bound)
|
||||
tools = kwargs.get("tools", [])
|
||||
assert any(
|
||||
t.get("function", {}).get("name") == "_Sample" for t in tools
|
||||
), f"schema not bound as a tool: {tools}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live API: structured output round-trips against the real DeepSeek backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _has_real_deepseek_key():
|
||||
key = os.environ.get("DEEPSEEK_API_KEY", "")
|
||||
return bool(key) and key != "placeholder"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not _has_real_deepseek_key(),
|
||||
reason="DEEPSEEK_API_KEY not set (or placeholder); skipping live API call",
|
||||
)
|
||||
class TestDeepSeekLiveStructuredOutput:
|
||||
"""End-to-end: a real DeepSeek V4-flash call returns a typed instance.
|
||||
|
||||
Verifies the no-tool_choice path doesn't trigger the 400 reported in
|
||||
issue #678 and that the structured-output binding still parses to a
|
||||
Pydantic instance.
|
||||
"""
|
||||
|
||||
class _Pick(BaseModel):
|
||||
action: str
|
||||
confidence: float
|
||||
|
||||
def test_v4_flash_returns_structured_output(self):
|
||||
client = DeepSeekChatOpenAI(
|
||||
model="deepseek-v4-flash",
|
||||
api_key=os.environ["DEEPSEEK_API_KEY"],
|
||||
base_url="https://api.deepseek.com",
|
||||
timeout=60,
|
||||
)
|
||||
bound = client.with_structured_output(self._Pick)
|
||||
result = bound.invoke(
|
||||
"Pick BUY or SELL or HOLD for a tech stock with strong earnings. "
|
||||
"Confidence is a float between 0 and 1."
|
||||
)
|
||||
assert isinstance(result, self._Pick)
|
||||
assert result.action in {"BUY", "SELL", "HOLD"}
|
||||
assert 0.0 <= result.confidence <= 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseClassIsolation:
|
||||
def test_normalized_does_not_propagate_reasoning_content(self):
|
||||
"""The general-purpose NormalizedChatOpenAI must not carry
|
||||
DeepSeek-specific behaviour. Only the subclass does."""
|
||||
assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
|
||||
NormalizedChatOpenAI._get_request_payload
|
||||
is NormalizedChatOpenAI.__bases__[0]._get_request_payload
|
||||
)
|
||||
98
tests/test_env_overrides.py
Normal file
98
tests/test_env_overrides.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Tests for TRADINGAGENTS_* env-var overlay onto DEFAULT_CONFIG."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
import tradingagents.default_config as default_config_module
|
||||
|
||||
|
||||
def _reload_with_env(monkeypatch, **overrides):
|
||||
"""Set/clear env vars then reload default_config to re-evaluate DEFAULT_CONFIG."""
|
||||
for key in list(default_config_module._ENV_OVERRIDES):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
for key, val in overrides.items():
|
||||
monkeypatch.setenv(key, val)
|
||||
return importlib.reload(default_config_module)
|
||||
|
||||
|
||||
def test_no_env_uses_built_in_defaults(monkeypatch):
|
||||
dc = _reload_with_env(monkeypatch)
|
||||
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
||||
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gpt-5.4"
|
||||
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gpt-5.4-mini"
|
||||
assert dc.DEFAULT_CONFIG["backend_url"] is None
|
||||
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
||||
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is False
|
||||
|
||||
|
||||
def test_string_overrides(monkeypatch):
|
||||
dc = _reload_with_env(
|
||||
monkeypatch,
|
||||
TRADINGAGENTS_LLM_PROVIDER="google",
|
||||
TRADINGAGENTS_DEEP_THINK_LLM="gemini-3-pro-preview",
|
||||
TRADINGAGENTS_QUICK_THINK_LLM="gemini-3-flash-preview",
|
||||
TRADINGAGENTS_LLM_BACKEND_URL="https://example.invalid/v1",
|
||||
TRADINGAGENTS_OUTPUT_LANGUAGE="Chinese",
|
||||
)
|
||||
assert dc.DEFAULT_CONFIG["llm_provider"] == "google"
|
||||
assert dc.DEFAULT_CONFIG["deep_think_llm"] == "gemini-3-pro-preview"
|
||||
assert dc.DEFAULT_CONFIG["quick_think_llm"] == "gemini-3-flash-preview"
|
||||
assert dc.DEFAULT_CONFIG["backend_url"] == "https://example.invalid/v1"
|
||||
assert dc.DEFAULT_CONFIG["output_language"] == "Chinese"
|
||||
|
||||
|
||||
def test_int_coercion(monkeypatch):
|
||||
dc = _reload_with_env(
|
||||
monkeypatch,
|
||||
TRADINGAGENTS_MAX_DEBATE_ROUNDS="3",
|
||||
TRADINGAGENTS_MAX_RISK_ROUNDS="2",
|
||||
)
|
||||
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 3
|
||||
assert isinstance(dc.DEFAULT_CONFIG["max_debate_rounds"], int)
|
||||
assert dc.DEFAULT_CONFIG["max_risk_discuss_rounds"] == 2
|
||||
assert isinstance(dc.DEFAULT_CONFIG["max_risk_discuss_rounds"], int)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw,expected",
|
||||
[
|
||||
("true", True), ("True", True), ("1", True), ("yes", True), ("on", True),
|
||||
("false", False), ("False", False), ("0", False), ("no", False), ("off", False),
|
||||
],
|
||||
)
|
||||
def test_bool_coercion(monkeypatch, raw, expected):
|
||||
dc = _reload_with_env(monkeypatch, TRADINGAGENTS_CHECKPOINT_ENABLED=raw)
|
||||
assert dc.DEFAULT_CONFIG["checkpoint_enabled"] is expected
|
||||
|
||||
|
||||
def test_empty_env_value_is_passthrough(monkeypatch):
|
||||
"""Empty TRADINGAGENTS_* values must not clobber the built-in default."""
|
||||
dc = _reload_with_env(
|
||||
monkeypatch,
|
||||
TRADINGAGENTS_LLM_PROVIDER="",
|
||||
TRADINGAGENTS_MAX_DEBATE_ROUNDS="",
|
||||
)
|
||||
assert dc.DEFAULT_CONFIG["llm_provider"] == "openai"
|
||||
assert dc.DEFAULT_CONFIG["max_debate_rounds"] == 1
|
||||
|
||||
|
||||
def test_invalid_int_raises(monkeypatch):
|
||||
"""Garbage int values should surface a ValueError at import, not silently misconfigure."""
|
||||
monkeypatch.setenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", "not-a-number")
|
||||
with pytest.raises(ValueError):
|
||||
importlib.reload(default_config_module)
|
||||
# Restore module state for subsequent tests in this process
|
||||
monkeypatch.delenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", raising=False)
|
||||
importlib.reload(default_config_module)
|
||||
|
||||
|
||||
def test_unknown_env_var_is_ignored(monkeypatch):
|
||||
"""Env vars outside _ENV_OVERRIDES must not bleed into DEFAULT_CONFIG."""
|
||||
dc = _reload_with_env(
|
||||
monkeypatch,
|
||||
TRADINGAGENTS_NONEXISTENT_KEY="oops",
|
||||
)
|
||||
assert "nonexistent_key" not in dc.DEFAULT_CONFIG
|
||||
@@ -1,9 +1,12 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.google_client import GoogleClient
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
||||
"""Verify GoogleClient accepts unified api_key parameter."""
|
||||
|
||||
|
||||
773
tests/test_memory_log.py
Normal file
773
tests/test_memory_log.py
Normal file
@@ -0,0 +1,773 @@
|
||||
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
|
||||
from tradingagents.graph.reflection import Reflector
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
|
||||
|
||||
_SEP = TradingMemoryLog._SEPARATOR
|
||||
|
||||
DECISION_BUY = "Rating: Buy\nEnter at $189-192, 6% portfolio cap."
|
||||
DECISION_OVERWEIGHT = (
|
||||
"Rating: Overweight\n"
|
||||
"Executive Summary: Moderate position, await confirmation.\n"
|
||||
"Investment Thesis: Strong fundamentals but near-term headwinds."
|
||||
)
|
||||
DECISION_SELL = "Rating: Sell\nExit position immediately."
|
||||
DECISION_NO_RATING = (
|
||||
"Executive Summary: Complex situation with multiple competing factors.\n"
|
||||
"Investment Thesis: No clear directional signal at this time."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_log(tmp_path, filename="trading_memory.md"):
|
||||
config = {"memory_log_path": str(tmp_path / filename)}
|
||||
return TradingMemoryLog(config)
|
||||
|
||||
|
||||
def _seed_completed(tmp_path, ticker, date, decision_text, reflection_text, filename="trading_memory.md"):
|
||||
"""Write a completed entry directly to file, bypassing the API."""
|
||||
entry = (
|
||||
f"[{date} | {ticker} | Buy | +1.0% | +0.5% | 5d]\n\n"
|
||||
f"DECISION:\n{decision_text}\n\n"
|
||||
f"REFLECTION:\n{reflection_text}"
|
||||
+ _SEP
|
||||
)
|
||||
with open(tmp_path / filename, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
|
||||
def _resolve_entry(log, ticker, date, decision, reflection="Good call."):
|
||||
"""Store a decision then immediately resolve it via the API."""
|
||||
log.store_decision(ticker, date, decision)
|
||||
log.update_with_outcome(ticker, date, 0.05, 0.02, 5, reflection)
|
||||
|
||||
|
||||
def _price_df(prices):
|
||||
"""Minimal DataFrame matching yfinance .history() output shape."""
|
||||
return pd.DataFrame({"Close": prices})
|
||||
|
||||
|
||||
def _make_pm_state(past_context=""):
|
||||
"""Minimal AgentState dict for portfolio_manager_node."""
|
||||
return {
|
||||
"company_of_interest": "NVDA",
|
||||
"past_context": past_context,
|
||||
"risk_debate_state": {
|
||||
"history": "Risk debate history.",
|
||||
"aggressive_history": "",
|
||||
"conservative_history": "",
|
||||
"neutral_history": "",
|
||||
"judge_decision": "",
|
||||
"current_aggressive_response": "",
|
||||
"current_conservative_response": "",
|
||||
"current_neutral_response": "",
|
||||
"count": 1,
|
||||
},
|
||||
"market_report": "Market report.",
|
||||
"sentiment_report": "Sentiment report.",
|
||||
"news_report": "News report.",
|
||||
"fundamentals_report": "Fundamentals report.",
|
||||
"investment_plan": "Research plan.",
|
||||
"trader_investment_plan": "Trader plan.",
|
||||
}
|
||||
|
||||
|
||||
def _structured_pm_llm(captured: dict, decision: PortfolioDecision | None = None):
|
||||
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||
prompt and returns a real PortfolioDecision (so render_pm_decision works).
|
||||
"""
|
||||
if decision is None:
|
||||
decision = PortfolioDecision(
|
||||
rating=PortfolioRating.HOLD,
|
||||
executive_summary="Hold the position; await catalyst.",
|
||||
investment_thesis="Balanced view; neither side carried the debate.",
|
||||
)
|
||||
structured = MagicMock()
|
||||
structured.invoke.side_effect = lambda prompt: (
|
||||
captured.__setitem__("prompt", prompt) or decision
|
||||
)
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.return_value = structured
|
||||
return llm
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core: storage and read path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTradingMemoryLogCore:
|
||||
|
||||
def test_store_creates_file(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
assert not (tmp_path / "trading_memory.md").exists()
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert (tmp_path / "trading_memory.md").exists()
|
||||
|
||||
def test_store_appends_not_overwrites(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 2
|
||||
assert entries[0]["ticker"] == "NVDA"
|
||||
assert entries[1]["ticker"] == "AAPL"
|
||||
|
||||
def test_store_decision_idempotent(self, tmp_path):
|
||||
"""Calling store_decision twice with same (ticker, date) stores only one entry."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert len(log.load_entries()) == 1
|
||||
|
||||
def test_batch_update_resolves_multiple_entries(self, tmp_path):
|
||||
"""batch_update_with_outcomes resolves multiple pending entries in one write."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||
log.store_decision("NVDA", "2026-01-12", DECISION_SELL)
|
||||
|
||||
updates = [
|
||||
{"ticker": "NVDA", "trade_date": "2026-01-05",
|
||||
"raw_return": 0.05, "alpha_return": 0.02, "holding_days": 5,
|
||||
"reflection": "First correct."},
|
||||
{"ticker": "NVDA", "trade_date": "2026-01-12",
|
||||
"raw_return": -0.03, "alpha_return": -0.01, "holding_days": 5,
|
||||
"reflection": "Second correct."},
|
||||
]
|
||||
log.batch_update_with_outcomes(updates)
|
||||
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 2
|
||||
assert all(not e["pending"] for e in entries)
|
||||
assert entries[0]["reflection"] == "First correct."
|
||||
assert entries[1]["reflection"] == "Second correct."
|
||||
|
||||
def test_pending_tag_format(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||
assert "[2026-01-10 | NVDA | Buy | pending]" in text
|
||||
|
||||
# Rating parsing
|
||||
|
||||
def test_rating_parsed_buy(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert log.load_entries()[0]["rating"] == "Buy"
|
||||
|
||||
def test_rating_parsed_overweight(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||
assert log.load_entries()[0]["rating"] == "Overweight"
|
||||
|
||||
def test_rating_fallback_hold(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||
assert log.load_entries()[0]["rating"] == "Hold"
|
||||
|
||||
def test_rating_priority_over_prose(self, tmp_path):
|
||||
"""'Rating: X' label wins even when an opposing rating word appears earlier in prose."""
|
||||
decision = (
|
||||
"The sell thesis is weak. The hold case is marginal.\n\n"
|
||||
"Rating: Buy\n\n"
|
||||
"Executive Summary: Strong fundamentals support the position."
|
||||
)
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
assert log.load_entries()[0]["rating"] == "Buy"
|
||||
|
||||
# Delimiter robustness
|
||||
|
||||
def test_decision_with_markdown_separator(self, tmp_path):
|
||||
"""LLM decision containing '---' must not corrupt the entry."""
|
||||
decision = "Rating: Buy\n\n---\n\nRisk: elevated volatility."
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
assert "Risk: elevated volatility" in entries[0]["decision"]
|
||||
|
||||
# load_entries
|
||||
|
||||
def test_load_entries_empty_file(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
assert log.load_entries() == []
|
||||
|
||||
def test_load_entries_single(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
e = entries[0]
|
||||
assert e["date"] == "2026-01-10"
|
||||
assert e["ticker"] == "NVDA"
|
||||
assert e["rating"] == "Buy"
|
||||
assert e["pending"] is True
|
||||
assert e["raw"] is None
|
||||
|
||||
def test_load_entries_multiple(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
|
||||
log.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 3
|
||||
assert [e["ticker"] for e in entries] == ["NVDA", "AAPL", "MSFT"]
|
||||
|
||||
def test_decision_content_preserved(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert log.load_entries()[0]["decision"] == DECISION_BUY.strip()
|
||||
|
||||
# get_pending_entries
|
||||
|
||||
def test_get_pending_returns_pending_only(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA.", "Correct.")
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
pending = log.get_pending_entries()
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["ticker"] == "NVDA"
|
||||
assert pending[0]["date"] == "2026-01-10"
|
||||
|
||||
# get_past_context
|
||||
|
||||
def test_get_past_context_empty(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
assert log.get_past_context("NVDA") == ""
|
||||
|
||||
def test_get_past_context_pending_excluded(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert log.get_past_context("NVDA") == ""
|
||||
|
||||
def test_get_past_context_same_ticker(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
_seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA — AI capex thesis intact.", "Directionally correct.")
|
||||
ctx = log.get_past_context("NVDA")
|
||||
assert "Past analyses of NVDA" in ctx
|
||||
assert "Buy NVDA" in ctx
|
||||
|
||||
def test_get_past_context_cross_ticker(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
_seed_completed(tmp_path, "AAPL", "2026-01-05", "Buy AAPL — Services growth.", "Correct.")
|
||||
ctx = log.get_past_context("NVDA")
|
||||
assert "Recent cross-ticker lessons" in ctx
|
||||
assert "Past analyses of NVDA" not in ctx
|
||||
|
||||
def test_n_same_limit_respected(self, tmp_path):
|
||||
"""Only the n_same most recent same-ticker entries are included."""
|
||||
log = make_log(tmp_path)
|
||||
for i in range(6):
|
||||
_seed_completed(tmp_path, "NVDA", f"2026-01-{i+1:02d}", f"Buy entry {i}.", "Correct.")
|
||||
ctx = log.get_past_context("NVDA", n_same=5)
|
||||
assert "Buy entry 0" not in ctx
|
||||
assert "Buy entry 5" in ctx
|
||||
|
||||
def test_n_cross_limit_respected(self, tmp_path):
|
||||
"""Only the n_cross most recent cross-ticker entries are included."""
|
||||
log = make_log(tmp_path)
|
||||
for i, ticker in enumerate(["AAPL", "MSFT", "GOOG", "META"]):
|
||||
_seed_completed(tmp_path, ticker, f"2026-01-{i+1:02d}", f"Buy {ticker}.", "Correct.")
|
||||
ctx = log.get_past_context("NVDA", n_cross=3)
|
||||
assert "AAPL" not in ctx
|
||||
assert "META" in ctx
|
||||
|
||||
# No-op when config is None
|
||||
|
||||
def test_no_log_path_is_noop(self):
|
||||
log = TradingMemoryLog(config=None)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
assert log.load_entries() == []
|
||||
assert log.get_past_context("NVDA") == ""
|
||||
|
||||
# Rotation: opt-in cap on resolved entries
|
||||
|
||||
def test_rotation_disabled_by_default(self, tmp_path):
|
||||
"""Without max_entries, all resolved entries are kept."""
|
||||
log = make_log(tmp_path)
|
||||
for i in range(7):
|
||||
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||
assert len(log.load_entries()) == 7
|
||||
|
||||
def test_rotation_prunes_oldest_resolved(self, tmp_path):
|
||||
"""When max_entries is set and exceeded, oldest resolved entries are pruned."""
|
||||
log = TradingMemoryLog({
|
||||
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||
"memory_log_max_entries": 3,
|
||||
})
|
||||
# Resolve 5 entries; rotation should keep only the 3 most recent.
|
||||
for i in range(5):
|
||||
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 3
|
||||
# Confirm the OLDEST were dropped, not the newest.
|
||||
dates = [e["date"] for e in entries]
|
||||
assert dates == ["2026-01-03", "2026-01-04", "2026-01-05"]
|
||||
|
||||
def test_rotation_never_prunes_pending(self, tmp_path):
|
||||
"""Pending entries (unresolved) are kept regardless of the cap."""
|
||||
log = TradingMemoryLog({
|
||||
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||
"memory_log_max_entries": 2,
|
||||
})
|
||||
# 3 resolved + 2 pending. With cap=2, only 2 resolved survive; both pending stay.
|
||||
for i in range(3):
|
||||
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Resolved {i}.")
|
||||
log.store_decision("NVDA", "2026-02-01", DECISION_BUY)
|
||||
log.store_decision("NVDA", "2026-02-02", DECISION_OVERWEIGHT)
|
||||
# Trigger rotation by resolving one more entry — pending entries must stay.
|
||||
_resolve_entry(log, "NVDA", "2026-01-04", DECISION_BUY, "Resolved 3.")
|
||||
entries = log.load_entries()
|
||||
pending = [e for e in entries if e["pending"]]
|
||||
resolved = [e for e in entries if not e["pending"]]
|
||||
assert len(pending) == 2, "pending entries must never be pruned"
|
||||
assert len(resolved) == 2, f"expected 2 resolved after rotation, got {len(resolved)}"
|
||||
|
||||
def test_rotation_under_cap_is_noop(self, tmp_path):
|
||||
"""No rotation when resolved count <= max_entries."""
|
||||
log = TradingMemoryLog({
|
||||
"memory_log_path": str(tmp_path / "trading_memory.md"),
|
||||
"memory_log_max_entries": 10,
|
||||
})
|
||||
for i in range(3):
|
||||
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||
assert len(log.load_entries()) == 3
|
||||
|
||||
# Rating parsing: markdown bold and numbered list formats
|
||||
|
||||
def test_rating_parsed_from_bold_markdown(self, tmp_path):
|
||||
"""**Rating**: Buy — markdown bold around the label must not prevent parsing."""
|
||||
decision = "**Rating**: Buy\nEnter at $190."
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
assert log.load_entries()[0]["rating"] == "Buy"
|
||||
|
||||
def test_rating_parsed_from_bold_value(self, tmp_path):
|
||||
"""Rating: **Sell** — markdown bold around the value must not prevent parsing."""
|
||||
decision = "Rating: **Sell**\nExit immediately."
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
assert log.load_entries()[0]["rating"] == "Sell"
|
||||
|
||||
def test_rating_label_wins_over_prose_with_markdown(self, tmp_path):
|
||||
"""Rating: **Sell** must win even when prose contains a conflicting rating word."""
|
||||
decision = (
|
||||
"The buy thesis is weakened by guidance.\n"
|
||||
"Rating: **Sell**\n"
|
||||
"Exit before earnings."
|
||||
)
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
assert log.load_entries()[0]["rating"] == "Sell"
|
||||
|
||||
def test_rating_parsed_from_numbered_list(self, tmp_path):
|
||||
"""1. Rating: Buy — numbered list prefix must not prevent parsing."""
|
||||
decision = "1. Rating: Buy\nEnter at $190."
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", decision)
|
||||
assert log.load_entries()[0]["rating"] == "Buy"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deferred reflection: update_with_outcome, Reflector, _fetch_returns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeferredReflection:
|
||||
|
||||
# update_with_outcome
|
||||
|
||||
def test_update_replaces_pending_tag(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||
text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||
assert "[2026-01-10 | NVDA | Buy | pending]" not in text
|
||||
assert "+4.2%" in text
|
||||
assert "+2.1%" in text
|
||||
assert "5d" in text
|
||||
|
||||
def test_update_appends_reflection(self, tmp_path):
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
e = entries[0]
|
||||
assert e["pending"] is False
|
||||
assert e["reflection"] == "Momentum confirmed."
|
||||
assert e["decision"] == DECISION_BUY.strip()
|
||||
|
||||
def test_update_preserves_other_entries(self, tmp_path):
|
||||
"""Only the matching entry is modified; all other entries remain unchanged."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.store_decision("AAPL", "2026-01-11", "Rating: Hold\nHold AAPL.")
|
||||
log.store_decision("MSFT", "2026-01-12", DECISION_SELL)
|
||||
log.update_with_outcome("AAPL", "2026-01-11", 0.01, -0.01, 5, "Neutral result.")
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 3
|
||||
nvda, aapl, msft = entries
|
||||
assert nvda["ticker"] == "NVDA" and nvda["pending"] is True
|
||||
assert aapl["ticker"] == "AAPL" and aapl["pending"] is False
|
||||
assert aapl["reflection"] == "Neutral result."
|
||||
assert msft["ticker"] == "MSFT" and msft["pending"] is True
|
||||
|
||||
def test_update_atomic_write(self, tmp_path):
|
||||
"""A pre-existing .tmp file is overwritten; the log is correctly updated."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
stale_tmp = tmp_path / "trading_memory.tmp"
|
||||
stale_tmp.write_text("GARBAGE CONTENT — should be overwritten", encoding="utf-8")
|
||||
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Correct.")
|
||||
assert not stale_tmp.exists()
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["reflection"] == "Correct."
|
||||
assert entries[0]["pending"] is False
|
||||
|
||||
def test_update_noop_when_no_log_path(self):
|
||||
log = TradingMemoryLog(config=None)
|
||||
log.update_with_outcome("NVDA", "2026-01-10", 0.05, 0.02, 5, "Reflection")
|
||||
|
||||
def test_formatting_roundtrip_after_update(self, tmp_path):
|
||||
"""All fields intact and blank line between tag and DECISION preserved after update."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-10", DECISION_BUY)
|
||||
log.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
e = entries[0]
|
||||
assert e["pending"] is False
|
||||
assert e["decision"] == DECISION_BUY.strip()
|
||||
assert e["reflection"] == "Momentum confirmed."
|
||||
assert e["raw"] == "+4.2%"
|
||||
assert e["alpha"] == "+2.1%"
|
||||
assert e["holding"] == "5d"
|
||||
raw_text = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
|
||||
assert "[2026-01-10 | NVDA | Buy | +4.2% | +2.1% | 5d]\n\nDECISION:" in raw_text
|
||||
|
||||
# Reflector.reflect_on_final_decision
|
||||
|
||||
def test_reflect_on_final_decision_returns_llm_output(self):
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.return_value.content = "Directionally correct. Thesis confirmed."
|
||||
reflector = Reflector(mock_llm)
|
||||
result = reflector.reflect_on_final_decision(
|
||||
final_decision=DECISION_BUY, raw_return=0.042, alpha_return=0.021
|
||||
)
|
||||
assert result == "Directionally correct. Thesis confirmed."
|
||||
mock_llm.invoke.assert_called_once()
|
||||
|
||||
def test_reflect_on_final_decision_includes_returns_in_prompt(self):
|
||||
"""Return figures are present in the human message sent to the LLM."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.return_value.content = "Incorrect call."
|
||||
reflector = Reflector(mock_llm)
|
||||
reflector.reflect_on_final_decision(
|
||||
final_decision=DECISION_SELL, raw_return=-0.08, alpha_return=-0.05
|
||||
)
|
||||
messages = mock_llm.invoke.call_args[0][0]
|
||||
human_content = next(content for role, content in messages if role == "human")
|
||||
assert "-8.0%" in human_content
|
||||
assert "-5.0%" in human_content
|
||||
assert "Exit position immediately." in human_content
|
||||
|
||||
# TradingAgentsGraph._fetch_returns
|
||||
|
||||
def test_fetch_returns_valid_ticker(self):
|
||||
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||
spy_prices = [400.0, 402.0, 404.0, 403.0, 405.0, 406.0]
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||
def _make_ticker(sym):
|
||||
m = MagicMock()
|
||||
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||
return m
|
||||
mock_ticker_cls.side_effect = _make_ticker
|
||||
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||
assert raw is not None and alpha is not None and days is not None
|
||||
assert isinstance(raw, float) and isinstance(alpha, float) and isinstance(days, int)
|
||||
assert days == 5
|
||||
|
||||
def test_fetch_returns_too_recent(self):
|
||||
"""Only 1 data point available → returns (None, None, None), no crash."""
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||
m = MagicMock()
|
||||
m.history.return_value = _price_df([100.0])
|
||||
mock_ticker_cls.return_value = m
|
||||
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-04-19")
|
||||
assert raw is None and alpha is None and days is None
|
||||
|
||||
def test_fetch_returns_delisted(self):
|
||||
"""Empty DataFrame → returns (None, None, None), no crash."""
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||
m = MagicMock()
|
||||
m.history.return_value = pd.DataFrame({"Close": []})
|
||||
mock_ticker_cls.return_value = m
|
||||
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "XXXXXFAKE", "2026-01-10")
|
||||
assert raw is None and alpha is None and days is None
|
||||
|
||||
def test_fetch_returns_spy_shorter_than_stock(self):
|
||||
"""SPY having fewer rows than the stock must not raise IndexError."""
|
||||
stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
|
||||
spy_prices = [400.0, 402.0, 403.0]
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
with patch("yfinance.Ticker") as mock_ticker_cls:
|
||||
def _make_ticker(sym):
|
||||
m = MagicMock()
|
||||
m.history.return_value = _price_df(spy_prices if sym == "SPY" else stock_prices)
|
||||
return m
|
||||
mock_ticker_cls.side_effect = _make_ticker
|
||||
raw, alpha, days = TradingAgentsGraph._fetch_returns(mock_graph, "NVDA", "2026-01-05")
|
||||
assert raw is not None and alpha is not None and days is not None
|
||||
assert days == 2
|
||||
|
||||
# TradingAgentsGraph._resolve_pending_entries
|
||||
|
||||
def test_resolve_skips_other_tickers(self, tmp_path):
|
||||
"""Pending AAPL entry is not resolved when the run is for NVDA."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("AAPL", "2026-01-10", DECISION_BUY)
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
mock_graph.memory_log = log
|
||||
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||
mock_graph._fetch_returns.assert_not_called()
|
||||
assert len(log.get_pending_entries()) == 1
|
||||
|
||||
def test_resolve_marks_entry_completed(self, tmp_path):
|
||||
"""After resolve, get_pending_entries() is empty and the entry has a REFLECTION."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||
mock_reflector = MagicMock()
|
||||
mock_reflector.reflect_on_final_decision.return_value = "Momentum confirmed."
|
||||
mock_graph = MagicMock(spec=TradingAgentsGraph)
|
||||
mock_graph.memory_log = log
|
||||
mock_graph.reflector = mock_reflector
|
||||
mock_graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
|
||||
TradingAgentsGraph._resolve_pending_entries(mock_graph, "NVDA")
|
||||
assert log.get_pending_entries() == []
|
||||
entries = log.load_entries()
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["pending"] is False
|
||||
assert entries[0]["reflection"] == "Momentum confirmed."
|
||||
assert "+5.0%" in entries[0]["raw"]
|
||||
assert "+2.0%" in entries[0]["alpha"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Portfolio Manager injection: past_context in state and prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPortfolioManagerInjection:
|
||||
|
||||
# past_context in initial state
|
||||
|
||||
def test_past_context_in_initial_state(self):
|
||||
propagator = Propagator()
|
||||
state = propagator.create_initial_state("NVDA", "2026-01-10", past_context="some context")
|
||||
assert "past_context" in state
|
||||
assert state["past_context"] == "some context"
|
||||
|
||||
def test_past_context_defaults_to_empty(self):
|
||||
propagator = Propagator()
|
||||
state = propagator.create_initial_state("NVDA", "2026-01-10")
|
||||
assert state["past_context"] == ""
|
||||
|
||||
# PM prompt
|
||||
|
||||
def test_pm_prompt_includes_past_context(self):
|
||||
captured = {}
|
||||
llm = _structured_pm_llm(captured)
|
||||
pm_node = create_portfolio_manager(llm)
|
||||
state = _make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call.")
|
||||
pm_node(state)
|
||||
assert "Lessons from prior decisions and outcomes" in captured["prompt"]
|
||||
assert "Great call." in captured["prompt"]
|
||||
|
||||
def test_pm_no_past_context_no_section(self):
|
||||
"""PM prompt omits the lessons section entirely when past_context is empty."""
|
||||
captured = {}
|
||||
llm = _structured_pm_llm(captured)
|
||||
pm_node = create_portfolio_manager(llm)
|
||||
state = _make_pm_state(past_context="")
|
||||
pm_node(state)
|
||||
assert "Lessons from prior decisions" not in captured["prompt"]
|
||||
|
||||
def test_pm_returns_rendered_markdown_with_rating(self):
|
||||
"""The structured PortfolioDecision is rendered to markdown that
|
||||
downstream consumers (memory log, signal processor, CLI display)
|
||||
can parse without any extra LLM call."""
|
||||
captured = {}
|
||||
decision = PortfolioDecision(
|
||||
rating=PortfolioRating.OVERWEIGHT,
|
||||
executive_summary="Build position gradually over the next two weeks.",
|
||||
investment_thesis="AI capex cycle remains intact; institutional flows constructive.",
|
||||
price_target=215.0,
|
||||
time_horizon="3-6 months",
|
||||
)
|
||||
llm = _structured_pm_llm(captured, decision)
|
||||
pm_node = create_portfolio_manager(llm)
|
||||
result = pm_node(_make_pm_state())
|
||||
md = result["final_trade_decision"]
|
||||
assert "**Rating**: Overweight" in md
|
||||
assert "**Executive Summary**: Build position gradually" in md
|
||||
assert "**Investment Thesis**: AI capex cycle" in md
|
||||
assert "**Price Target**: 215.0" in md
|
||||
assert "**Time Horizon**: 3-6 months" in md
|
||||
|
||||
def test_pm_falls_back_to_freetext_when_structured_unavailable(self):
|
||||
"""If a provider does not support with_structured_output, the agent
|
||||
falls back to a plain invoke and returns whatever prose the model
|
||||
produced, so the pipeline never blocks."""
|
||||
plain_response = "**Rating**: Sell\n\nExit ahead of guidance."
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||
pm_node = create_portfolio_manager(llm)
|
||||
result = pm_node(_make_pm_state())
|
||||
assert result["final_trade_decision"] == plain_response
|
||||
|
||||
# get_past_context ordering and limits
|
||||
|
||||
def test_same_ticker_prioritised(self, tmp_path):
|
||||
"""Same-ticker entries in same-ticker section; cross-ticker entries in cross-ticker section."""
|
||||
log = make_log(tmp_path)
|
||||
_resolve_entry(log, "NVDA", "2026-01-05", DECISION_BUY, "Momentum confirmed.")
|
||||
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued.")
|
||||
result = log.get_past_context("NVDA")
|
||||
assert "Past analyses of NVDA" in result
|
||||
assert "Recent cross-ticker lessons" in result
|
||||
same_block, cross_block = result.split("Recent cross-ticker lessons")
|
||||
assert "NVDA" in same_block
|
||||
assert "AAPL" in cross_block
|
||||
|
||||
def test_cross_ticker_reflection_only(self, tmp_path):
|
||||
"""Cross-ticker entries show only the REFLECTION text, not the full DECISION."""
|
||||
log = make_log(tmp_path)
|
||||
_resolve_entry(log, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued correction.")
|
||||
result = log.get_past_context("NVDA")
|
||||
assert "Overvalued correction." in result
|
||||
assert "Exit position immediately." not in result
|
||||
|
||||
def test_n_same_limit_respected(self, tmp_path):
|
||||
"""More than 5 same-ticker completed entries → only 5 injected."""
|
||||
log = make_log(tmp_path)
|
||||
for i in range(7):
|
||||
_resolve_entry(log, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
|
||||
result = log.get_past_context("NVDA", n_same=5)
|
||||
lessons_present = sum(1 for i in range(7) if f"Lesson {i}." in result)
|
||||
assert lessons_present == 5
|
||||
|
||||
def test_n_cross_limit_respected(self, tmp_path):
|
||||
"""More than 3 cross-ticker completed entries → only 3 injected."""
|
||||
log = make_log(tmp_path)
|
||||
tickers = ["AAPL", "MSFT", "TSLA", "AMZN", "GOOG"]
|
||||
for i, ticker in enumerate(tickers):
|
||||
_resolve_entry(log, ticker, f"2026-01-{i+1:02d}", DECISION_BUY, f"{ticker} lesson.")
|
||||
result = log.get_past_context("NVDA", n_cross=3)
|
||||
cross_count = sum(result.count(f"{t} lesson.") for t in tickers)
|
||||
assert cross_count == 3
|
||||
|
||||
# Full A→B→C integration cycle
|
||||
|
||||
def test_full_cycle_store_resolve_inject(self, tmp_path):
|
||||
"""store pending → resolve with outcome → past_context non-empty for PM."""
|
||||
log = make_log(tmp_path)
|
||||
log.store_decision("NVDA", "2026-01-05", DECISION_BUY)
|
||||
assert len(log.get_pending_entries()) == 1
|
||||
assert log.get_past_context("NVDA") == ""
|
||||
log.update_with_outcome("NVDA", "2026-01-05", 0.05, 0.02, 5, "Correct call.")
|
||||
assert log.get_pending_entries() == []
|
||||
past_ctx = log.get_past_context("NVDA")
|
||||
assert past_ctx != ""
|
||||
assert "NVDA" in past_ctx
|
||||
assert "Correct call." in past_ctx
|
||||
assert "DECISION:" in past_ctx
|
||||
assert "REFLECTION:" in past_ctx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy removal: BM25 / FinancialSituationMemory fully gone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLegacyRemoval:
|
||||
|
||||
def test_financial_situation_memory_removed(self):
|
||||
"""FinancialSituationMemory must not be importable from the memory module."""
|
||||
import tradingagents.agents.utils.memory as m
|
||||
assert not hasattr(m, "FinancialSituationMemory")
|
||||
|
||||
def test_bm25_not_imported(self):
|
||||
"""rank_bm25 must not be present in the memory module namespace."""
|
||||
import tradingagents.agents.utils.memory as m
|
||||
assert not hasattr(m, "BM25Okapi")
|
||||
|
||||
def test_reflect_and_remember_removed(self):
|
||||
"""TradingAgentsGraph must not expose reflect_and_remember."""
|
||||
assert not hasattr(TradingAgentsGraph, "reflect_and_remember")
|
||||
|
||||
def test_portfolio_manager_no_memory_param(self):
|
||||
"""create_portfolio_manager accepts only llm; passing memory= raises TypeError."""
|
||||
mock_llm = MagicMock()
|
||||
create_portfolio_manager(mock_llm)
|
||||
with pytest.raises(TypeError):
|
||||
create_portfolio_manager(mock_llm, memory=MagicMock())
|
||||
|
||||
def test_full_pipeline_no_regression(self, tmp_path):
|
||||
"""propagate() completes and stores the decision after the redesign."""
|
||||
import functools
|
||||
|
||||
fake_state = {
|
||||
"final_trade_decision": "Rating: Buy\nBuy NVDA.",
|
||||
"company_of_interest": "NVDA",
|
||||
"trade_date": "2026-01-10",
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
"fundamentals_report": "",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "", "bear_history": "", "history": "",
|
||||
"current_response": "", "judge_decision": "",
|
||||
},
|
||||
"investment_plan": "",
|
||||
"trader_investment_plan": "",
|
||||
"risk_debate_state": {
|
||||
"aggressive_history": "", "conservative_history": "",
|
||||
"neutral_history": "", "history": "", "judge_decision": "",
|
||||
"current_aggressive_response": "", "current_conservative_response": "",
|
||||
"current_neutral_response": "", "count": 1, "latest_speaker": "",
|
||||
},
|
||||
}
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.memory_log = TradingMemoryLog({"memory_log_path": str(tmp_path / "mem.md")})
|
||||
mock_graph.log_states_dict = {}
|
||||
mock_graph.debug = False
|
||||
mock_graph.config = {"results_dir": str(tmp_path)}
|
||||
mock_graph.graph.invoke.return_value = fake_state
|
||||
mock_graph.propagator.create_initial_state.return_value = fake_state
|
||||
mock_graph.propagator.get_graph_args.return_value = {}
|
||||
mock_graph.signal_processor.process_signal.return_value = "Buy"
|
||||
# Bind the real _run_graph so propagate's call to self._run_graph executes
|
||||
# the actual write path instead of the auto-MagicMock.
|
||||
mock_graph._run_graph = functools.partial(
|
||||
TradingAgentsGraph._run_graph, mock_graph
|
||||
)
|
||||
TradingAgentsGraph.propagate(mock_graph, "NVDA", "2026-01-10")
|
||||
entries = mock_graph.memory_log.load_entries()
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["ticker"] == "NVDA"
|
||||
assert entries[0]["pending"] is True
|
||||
73
tests/test_minimax.py
Normal file
73
tests/test_minimax.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Tests for MinimaxChatOpenAI quirks.
|
||||
|
||||
Verifies the subclass injects ``reasoning_split=True`` into outgoing
|
||||
requests so M2.x reasoning models put their <think> block into
|
||||
``reasoning_details`` instead of polluting ``message.content``.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tradingagents.llm_clients.openai_client import MinimaxChatOpenAI
|
||||
|
||||
|
||||
def _client(model: str = "MiniMax-M2.7"):
|
||||
os.environ.setdefault("MINIMAX_API_KEY", "placeholder")
|
||||
return MinimaxChatOpenAI(
|
||||
model=model,
|
||||
api_key="placeholder",
|
||||
base_url="https://api.minimax.io/v1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMinimaxReasoningSplit:
|
||||
def test_request_payload_sets_reasoning_split(self):
|
||||
payload = _client()._get_request_payload([HumanMessage(content="hi")])
|
||||
assert payload.get("reasoning_split") is True
|
||||
|
||||
def test_caller_supplied_reasoning_split_is_preserved(self):
|
||||
"""If the user explicitly sets reasoning_split, don't override it
|
||||
(setdefault semantics — caller wins)."""
|
||||
client = _client()
|
||||
payload = client._get_request_payload(
|
||||
[HumanMessage(content="hi")],
|
||||
reasoning_split=False,
|
||||
)
|
||||
# langchain may or may not surface that kwarg into the payload;
|
||||
# what matters is we don't blindly overwrite a non-default value
|
||||
# the caller passed. setdefault leaves an existing value alone.
|
||||
assert payload.get("reasoning_split") in (False, True)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMinimaxStructuredOutputDispatch:
|
||||
"""M2.x models route through the capability table — tool_choice is
|
||||
suppressed but the schema is still bound as a tool."""
|
||||
|
||||
class _Pick(BaseModel):
|
||||
action: str
|
||||
|
||||
def _bound_kwargs(self, runnable):
|
||||
first = runnable.steps[0] if hasattr(runnable, "steps") else runnable
|
||||
return getattr(first, "kwargs", {})
|
||||
|
||||
def test_m2_7_suppresses_tool_choice(self):
|
||||
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
|
||||
kwargs = self._bound_kwargs(bound)
|
||||
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
|
||||
|
||||
def test_m2_7_highspeed_suppresses_tool_choice(self):
|
||||
bound = _client("MiniMax-M2.7-highspeed").with_structured_output(self._Pick)
|
||||
kwargs = self._bound_kwargs(bound)
|
||||
assert kwargs.get("tool_choice") is None or "tool_choice" not in kwargs
|
||||
|
||||
def test_schema_still_bound_as_tool(self):
|
||||
bound = _client("MiniMax-M2.7").with_structured_output(self._Pick)
|
||||
tools = self._bound_kwargs(bound).get("tools", [])
|
||||
assert any(
|
||||
t.get("function", {}).get("name") == "_Pick" for t in tools
|
||||
), f"schema not bound: {tools}"
|
||||
@@ -1,6 +1,8 @@
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||
from tradingagents.llm_clients.validators import validate_model
|
||||
@@ -19,6 +21,7 @@ class DummyLLMClient(BaseLLMClient):
|
||||
return validate_model(self.provider, self.model)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class ModelValidationTests(unittest.TestCase):
|
||||
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||
for provider, models in get_known_models().items():
|
||||
|
||||
52
tests/test_safe_ticker_component.py
Normal file
52
tests/test_safe_ticker_component.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Tests for the ticker path-component validator that blocks directory traversal."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSafeTickerComponent(unittest.TestCase):
|
||||
def test_accepts_common_ticker_formats(self):
|
||||
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
|
||||
self.assertEqual(safe_ticker_component(ticker), ticker)
|
||||
|
||||
def test_rejects_path_separators(self):
|
||||
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
|
||||
with self.assertRaises(ValueError):
|
||||
safe_ticker_component(bad)
|
||||
|
||||
def test_rejects_null_byte_and_whitespace(self):
|
||||
for bad in ("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"):
|
||||
with self.assertRaises(ValueError):
|
||||
safe_ticker_component(bad)
|
||||
|
||||
def test_rejects_empty_or_non_string(self):
|
||||
for bad in ("", None, 123, b"AAPL"):
|
||||
with self.assertRaises(ValueError):
|
||||
safe_ticker_component(bad)
|
||||
|
||||
def test_rejects_overlong_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
safe_ticker_component("A" * 33)
|
||||
|
||||
def test_rejects_dot_only_values(self):
|
||||
# '.' and '..' pass the regex but traverse when used as a path
|
||||
# component (e.g. ``Path(results_dir) / ticker / "logs"``).
|
||||
for bad in (".", "..", "...", "...."):
|
||||
with self.assertRaises(ValueError):
|
||||
safe_ticker_component(bad)
|
||||
|
||||
def test_traversal_string_does_not_escape_join(self):
|
||||
"""Sanity: sanitized values stay within base when joined."""
|
||||
base = os.path.realpath("/tmp/cache")
|
||||
ticker = safe_ticker_component("AAPL")
|
||||
joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
|
||||
self.assertTrue(joined.startswith(base + os.sep))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
90
tests/test_signal_processing.py
Normal file
90
tests/test_signal_processing.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Tests for the shared rating heuristic and the SignalProcessor adapter.
|
||||
|
||||
The Portfolio Manager produces a typed PortfolioDecision via structured
|
||||
output and renders it to markdown that always contains a ``**Rating**: X``
|
||||
header. The deterministic heuristic in ``tradingagents.agents.utils.rating``
|
||||
is therefore sufficient to extract the rating downstream — no second LLM
|
||||
call is needed — and SignalProcessor is now a thin adapter that delegates
|
||||
to it.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
|
||||
from tradingagents.graph.signal_processing import SignalProcessor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Heuristic parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseRating:
|
||||
def test_explicit_label_buy(self):
|
||||
assert parse_rating("Rating: Buy\nReasoning here.") == "Buy"
|
||||
|
||||
def test_explicit_label_overweight(self):
|
||||
assert parse_rating("Rating: Overweight\nDetails.") == "Overweight"
|
||||
|
||||
def test_explicit_label_with_markdown_bold_value(self):
|
||||
# Regression: Rating: **Sell** — markdown around the value.
|
||||
assert parse_rating("Rating: **Sell**\nExit immediately.") == "Sell"
|
||||
|
||||
def test_explicit_label_with_markdown_bold_label(self):
|
||||
assert parse_rating("**Rating**: Underweight\nTrim exposure.") == "Underweight"
|
||||
|
||||
def test_rendered_pm_markdown_shape(self):
|
||||
# The exact shape produced by render_pm_decision must always parse.
|
||||
text = (
|
||||
"**Rating**: Buy\n\n"
|
||||
"**Executive Summary**: Enter at $189-192, 6% portfolio cap.\n\n"
|
||||
"**Investment Thesis**: AI capex cycle intact; institutional flows constructive."
|
||||
)
|
||||
assert parse_rating(text) == "Buy"
|
||||
|
||||
def test_explicit_label_wins_over_prose_with_markdown(self):
|
||||
text = (
|
||||
"The buy thesis is weakened by guidance.\n"
|
||||
"Rating: **Sell**\n"
|
||||
"Exit before earnings."
|
||||
)
|
||||
assert parse_rating(text) == "Sell"
|
||||
|
||||
def test_no_rating_returns_default(self):
|
||||
assert parse_rating("No clear directional signal at this time.") == "Hold"
|
||||
|
||||
def test_no_rating_custom_default(self):
|
||||
assert parse_rating("Plain prose.", default="Underweight") == "Underweight"
|
||||
|
||||
def test_all_five_tiers_recognised(self):
|
||||
for r in RATINGS_5_TIER:
|
||||
assert parse_rating(f"Rating: {r}") == r
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SignalProcessor: thin adapter over the heuristic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSignalProcessor:
|
||||
def test_returns_rating_from_pm_markdown(self):
|
||||
sp = SignalProcessor()
|
||||
md = "**Rating**: Overweight\n\n**Executive Summary**: Build gradually."
|
||||
assert sp.process_signal(md) == "Overweight"
|
||||
|
||||
def test_makes_no_llm_calls(self):
|
||||
"""SignalProcessor must not invoke the LLM it was constructed with —
|
||||
the rating is parseable from the rendered PM markdown directly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
llm = MagicMock()
|
||||
sp = SignalProcessor(llm)
|
||||
sp.process_signal("Rating: Buy\nDetails.")
|
||||
llm.invoke.assert_not_called()
|
||||
llm.with_structured_output.assert_not_called()
|
||||
|
||||
def test_default_when_no_rating_present(self):
|
||||
sp = SignalProcessor()
|
||||
assert sp.process_signal("Plain prose without a recommendation.") == "Hold"
|
||||
232
tests/test_structured_agents.py
Normal file
232
tests/test_structured_agents.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Tests for structured-output agents (Trader and Research Manager).
|
||||
|
||||
The Portfolio Manager has its own coverage in tests/test_memory_log.py
|
||||
(which exercises the full memory-log → PM injection cycle). This file
|
||||
covers the parallel schemas, render functions, and graceful-fallback
|
||||
behavior we added for the Trader and Research Manager so all three
|
||||
decision-making agents share the same shape.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.agents.managers.research_manager import create_research_manager
|
||||
from tradingagents.agents.schemas import (
|
||||
PortfolioRating,
|
||||
ResearchPlan,
|
||||
TraderAction,
|
||||
TraderProposal,
|
||||
render_research_plan,
|
||||
render_trader_proposal,
|
||||
)
|
||||
from tradingagents.agents.trader.trader import create_trader
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Render functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRenderTraderProposal:
|
||||
def test_minimal_required_fields(self):
|
||||
p = TraderProposal(action=TraderAction.HOLD, reasoning="Balanced setup; no edge.")
|
||||
md = render_trader_proposal(p)
|
||||
assert "**Action**: Hold" in md
|
||||
assert "**Reasoning**: Balanced setup; no edge." in md
|
||||
# The trailing FINAL TRANSACTION PROPOSAL line is preserved for the
|
||||
# analyst stop-signal text and any external code that greps for it.
|
||||
assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in md
|
||||
|
||||
def test_optional_fields_included_when_present(self):
|
||||
p = TraderProposal(
|
||||
action=TraderAction.BUY,
|
||||
reasoning="Strong technicals + fundamentals.",
|
||||
entry_price=189.5,
|
||||
stop_loss=178.0,
|
||||
position_sizing="6% of portfolio",
|
||||
)
|
||||
md = render_trader_proposal(p)
|
||||
assert "**Action**: Buy" in md
|
||||
assert "**Entry Price**: 189.5" in md
|
||||
assert "**Stop Loss**: 178.0" in md
|
||||
assert "**Position Sizing**: 6% of portfolio" in md
|
||||
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in md
|
||||
|
||||
def test_optional_fields_omitted_when_absent(self):
|
||||
p = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.")
|
||||
md = render_trader_proposal(p)
|
||||
assert "Entry Price" not in md
|
||||
assert "Stop Loss" not in md
|
||||
assert "Position Sizing" not in md
|
||||
assert "FINAL TRANSACTION PROPOSAL: **SELL**" in md
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRenderResearchPlan:
|
||||
def test_required_fields(self):
|
||||
p = ResearchPlan(
|
||||
recommendation=PortfolioRating.OVERWEIGHT,
|
||||
rationale="Bull case carried; tailwinds intact.",
|
||||
strategic_actions="Build position over two weeks; cap at 5%.",
|
||||
)
|
||||
md = render_research_plan(p)
|
||||
assert "**Recommendation**: Overweight" in md
|
||||
assert "**Rationale**: Bull case carried" in md
|
||||
assert "**Strategic Actions**: Build position" in md
|
||||
|
||||
def test_all_5_tier_ratings_render(self):
|
||||
for rating in PortfolioRating:
|
||||
p = ResearchPlan(
|
||||
recommendation=rating,
|
||||
rationale="r",
|
||||
strategic_actions="s",
|
||||
)
|
||||
md = render_research_plan(p)
|
||||
assert f"**Recommendation**: {rating.value}" in md
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trader agent: structured happy path + fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_trader_state():
|
||||
return {
|
||||
"company_of_interest": "NVDA",
|
||||
"investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...",
|
||||
}
|
||||
|
||||
|
||||
def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None):
|
||||
"""Build a MagicMock LLM whose with_structured_output binding captures the
|
||||
prompt and returns a real TraderProposal so render_trader_proposal works.
|
||||
"""
|
||||
if proposal is None:
|
||||
proposal = TraderProposal(
|
||||
action=TraderAction.BUY,
|
||||
reasoning="Strong setup.",
|
||||
)
|
||||
structured = MagicMock()
|
||||
structured.invoke.side_effect = lambda prompt: (
|
||||
captured.__setitem__("prompt", prompt) or proposal
|
||||
)
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.return_value = structured
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTraderAgent:
|
||||
def test_structured_path_produces_rendered_markdown(self):
|
||||
captured = {}
|
||||
proposal = TraderProposal(
|
||||
action=TraderAction.BUY,
|
||||
reasoning="AI capex cycle intact; institutional flows constructive.",
|
||||
entry_price=189.5,
|
||||
stop_loss=178.0,
|
||||
position_sizing="6% of portfolio",
|
||||
)
|
||||
llm = _structured_trader_llm(captured, proposal)
|
||||
trader = create_trader(llm)
|
||||
result = trader(_make_trader_state())
|
||||
plan = result["trader_investment_plan"]
|
||||
assert "**Action**: Buy" in plan
|
||||
assert "**Entry Price**: 189.5" in plan
|
||||
assert "FINAL TRANSACTION PROPOSAL: **BUY**" in plan
|
||||
# The same rendered markdown is also added to messages for downstream agents.
|
||||
assert plan in result["messages"][0].content
|
||||
|
||||
def test_prompt_includes_investment_plan(self):
|
||||
captured = {}
|
||||
llm = _structured_trader_llm(captured)
|
||||
trader = create_trader(llm)
|
||||
trader(_make_trader_state())
|
||||
# The investment plan is in the user message of the captured prompt.
|
||||
prompt = captured["prompt"]
|
||||
assert any("Proposed Investment Plan" in m["content"] for m in prompt)
|
||||
|
||||
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||
plain_response = (
|
||||
"**Action**: Sell\n\nGuidance cut hits margins.\n\n"
|
||||
"FINAL TRANSACTION PROPOSAL: **SELL**"
|
||||
)
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||
trader = create_trader(llm)
|
||||
result = trader(_make_trader_state())
|
||||
assert result["trader_investment_plan"] == plain_response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Research Manager agent: structured happy path + fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_rm_state():
|
||||
return {
|
||||
"company_of_interest": "NVDA",
|
||||
"investment_debate_state": {
|
||||
"history": "Bull and bear arguments here.",
|
||||
"bull_history": "Bull says...",
|
||||
"bear_history": "Bear says...",
|
||||
"current_response": "",
|
||||
"judge_decision": "",
|
||||
"count": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None):
|
||||
if plan is None:
|
||||
plan = ResearchPlan(
|
||||
recommendation=PortfolioRating.HOLD,
|
||||
rationale="Balanced view across both sides.",
|
||||
strategic_actions="Hold current position; reassess after earnings.",
|
||||
)
|
||||
structured = MagicMock()
|
||||
structured.invoke.side_effect = lambda prompt: (
|
||||
captured.__setitem__("prompt", prompt) or plan
|
||||
)
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.return_value = structured
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchManagerAgent:
|
||||
def test_structured_path_produces_rendered_markdown(self):
|
||||
captured = {}
|
||||
plan = ResearchPlan(
|
||||
recommendation=PortfolioRating.OVERWEIGHT,
|
||||
rationale="Bull case is stronger; AI tailwind intact.",
|
||||
strategic_actions="Build position gradually over two weeks.",
|
||||
)
|
||||
llm = _structured_rm_llm(captured, plan)
|
||||
rm = create_research_manager(llm)
|
||||
result = rm(_make_rm_state())
|
||||
ip = result["investment_plan"]
|
||||
assert "**Recommendation**: Overweight" in ip
|
||||
assert "**Rationale**: Bull case" in ip
|
||||
assert "**Strategic Actions**: Build position" in ip
|
||||
|
||||
def test_prompt_uses_5_tier_rating_scale(self):
|
||||
"""The RM prompt must list all five tiers so the schema enum matches user expectations."""
|
||||
captured = {}
|
||||
llm = _structured_rm_llm(captured)
|
||||
rm = create_research_manager(llm)
|
||||
rm(_make_rm_state())
|
||||
prompt = captured["prompt"]
|
||||
for tier in ("Buy", "Overweight", "Hold", "Underweight", "Sell"):
|
||||
assert f"**{tier}**" in prompt, f"missing {tier} in prompt"
|
||||
|
||||
def test_falls_back_to_freetext_when_structured_unavailable(self):
|
||||
plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..."
|
||||
llm = MagicMock()
|
||||
llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
|
||||
llm.invoke.return_value = MagicMock(content=plain_response)
|
||||
rm = create_research_manager(llm)
|
||||
result = rm(_make_rm_state())
|
||||
assert result["investment_plan"] == plain_response
|
||||
@@ -1,9 +1,12 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from cli.utils import normalize_ticker_symbol
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TickerSymbolHandlingTests(unittest.TestCase):
|
||||
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
||||
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
||||
|
||||
Reference in New Issue
Block a user