Merge remote-tracking branch 'upstream/main' into analyst-phase1-observability

# Conflicts:
#	tradingagents/default_config.py
#	tradingagents/graph/setup.py
This commit is contained in:
CadeYu
2026-05-11 16:44:00 +08:00
71 changed files with 5225 additions and 760 deletions

46
tests/conftest.py Normal file
View File

@@ -0,0 +1,46 @@
"""Shared pytest fixtures that prevent CI hangs when API keys are absent."""
import os
from unittest.mock import MagicMock, patch
import pytest
def pytest_configure(config):
    """Register the custom ``unit``, ``integration`` and ``smoke`` markers.

    One ``markers`` ini line is added per level.
    """
    levels = ("unit", "integration", "smoke")
    for level in levels:
        description = f"{level}: {level}-level tests"
        config.addinivalue_line("markers", description)
# Environment variables that the autouse ``_dummy_api_keys`` fixture fills
# with a placeholder whenever they are not already set.
_API_KEY_ENV_VARS = (
    "OPENAI_API_KEY",
    "GOOGLE_API_KEY",
    "ANTHROPIC_API_KEY",
    "XAI_API_KEY",
    "DEEPSEEK_API_KEY",
    "DASHSCOPE_API_KEY",
    "DASHSCOPE_CN_API_KEY",
    "ZHIPU_API_KEY",
    "ZHIPU_CN_API_KEY",
    "MINIMAX_API_KEY",
    "MINIMAX_CN_API_KEY",
    "OPENROUTER_API_KEY",
    "AZURE_OPENAI_API_KEY",
    "ALPHA_VANTAGE_API_KEY",
)
@pytest.fixture(autouse=True)
def _dummy_api_keys(monkeypatch):
    """Give every listed API-key variable a value for the duration of a test.

    A variable that is already set keeps its current value; an unset one
    receives the string ``"placeholder"``.
    """
    for name in _API_KEY_ENV_VARS:
        current = os.environ.get(name, "placeholder")
        monkeypatch.setenv(name, current)
@pytest.fixture()
def mock_llm_client():
    """Yield a MagicMock client while ``create_llm_client`` is patched to return it.

    The client's ``get_llm()`` returns another MagicMock. The patch on
    ``tradingagents.llm_clients.factory.create_llm_client`` is active only
    while the fixture is in use.
    """
    fake_client = MagicMock()
    fake_client.get_llm.return_value = MagicMock()
    factory_patch = patch(
        "tradingagents.llm_clients.factory.create_llm_client",
        return_value=fake_client,
    )
    with factory_patch:
        yield fake_client

149
tests/test_api_key_env.py Normal file
View File

@@ -0,0 +1,149 @@
"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper."""
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
# ---- Mapping coverage -----------------------------------------------------
def test_every_select_llm_provider_choice_has_an_entry():
    """select_llm_provider() must not present a provider the mapping doesn't know about."""
    # Mirrors the choices offered by cli/utils.select_llm_provider so the two
    # stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn)
    # are reached via the secondary region prompt, so they must also be present.
    required_providers = {
        "openai",
        "google",
        "anthropic",
        "xai",
        "deepseek",
        "qwen",
        "qwen-cn",
        "glm",
        "glm-cn",
        "minimax",
        "minimax-cn",
        "openrouter",
        "azure",
        "ollama",
    }
    missing = required_providers.difference(PROVIDER_API_KEY_ENV.keys())
    assert not missing
@pytest.mark.parametrize(
    "provider,env_var",
    [
        ("openai", "OPENAI_API_KEY"),
        ("anthropic", "ANTHROPIC_API_KEY"),
        ("google", "GOOGLE_API_KEY"),
        ("azure", "AZURE_OPENAI_API_KEY"),
        ("xai", "XAI_API_KEY"),
        ("deepseek", "DEEPSEEK_API_KEY"),
        ("qwen", "DASHSCOPE_API_KEY"),
        ("qwen-cn", "DASHSCOPE_CN_API_KEY"),
        ("glm", "ZHIPU_API_KEY"),
        ("glm-cn", "ZHIPU_CN_API_KEY"),
        ("minimax", "MINIMAX_API_KEY"),
        ("minimax-cn", "MINIMAX_CN_API_KEY"),
        ("openrouter", "OPENROUTER_API_KEY"),
    ],
)
def test_known_providers_resolve(provider, env_var):
    """Each known provider id resolves to its expected env-var name."""
    resolved = get_api_key_env(provider)
    assert resolved == env_var
def test_ollama_has_no_key():
    """The ``ollama`` provider resolves to no env var."""
    resolved = get_api_key_env("ollama")
    assert resolved is None
def test_unknown_provider_returns_none():
    """A provider id that is not in the mapping resolves to None."""
    resolved = get_api_key_env("not-a-real-provider")
    assert resolved is None
def test_case_insensitive_lookup():
    """Provider ids are matched regardless of letter case."""
    mixed_case = get_api_key_env("OpenAI")
    upper_case = get_api_key_env("QWEN-CN")
    assert mixed_case == "OPENAI_API_KEY"
    assert upper_case == "DASHSCOPE_CN_API_KEY"
# ---- ensure_api_key behavior ---------------------------------------------
@pytest.fixture
def cli_utils(monkeypatch):
    """Reload ``cli.utils`` and hand the freshly reloaded module to the test."""
    import importlib

    import cli.utils as cli_utils_module

    reloaded = importlib.reload(cli_utils_module)
    return reloaded
def test_ensure_api_key_returns_existing(monkeypatch, cli_utils):
    """A key already present in the environment is returned as-is."""
    existing_key = "sk-already-set"
    monkeypatch.setenv("OPENAI_API_KEY", existing_key)
    assert cli_utils.ensure_api_key("openai") == "sk-already-set"
def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils):
    """ollama must neither prompt for a key nor return one."""
    # Even with no env var set, ollama should not prompt and should return None.
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with patch.object(cli_utils, "questionary") as questionary_mock:
        outcome = cli_utils.ensure_api_key("ollama")
    assert outcome is None
    questionary_mock.password.assert_not_called()
def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils):
    """An unrecognised provider returns None without opening a password prompt."""
    with patch.object(cli_utils, "questionary") as questionary_mock:
        outcome = cli_utils.ensure_api_key("totally-fake-provider")
    assert outcome is None
    questionary_mock.password.assert_not_called()
def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils):
    """When key is missing, user-pasted value must be written to .env AND os.environ."""
    monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
    monkeypatch.chdir(tmp_path)

    # Stand-in for the questionary prompt object: .ask() returns the pasted key.
    fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})()
    with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
        returned_key = cli_utils.ensure_api_key("deepseek")

    assert returned_key == "sk-deepseek-test"
    assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test"

    dotenv_path = tmp_path / ".env"
    assert dotenv_path.exists()
    assert "DEEPSEEK_API_KEY" in dotenv_path.read_text()
    assert "sk-deepseek-test" in dotenv_path.read_text()
def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils):
    """Empty prompt response (user cancelled) must not write to .env."""
    monkeypatch.delenv("XAI_API_KEY", raising=False)
    monkeypatch.chdir(tmp_path)

    # Stand-in prompt whose .ask() returns None, i.e. the user cancelled.
    cancelled_prompt = type("P", (), {"ask": staticmethod(lambda: None)})()
    with patch.object(cli_utils.questionary, "password", return_value=cancelled_prompt):
        outcome = cli_utils.ensure_api_key("xai")

    assert outcome is None
    assert "XAI_API_KEY" not in os.environ

    # .env may or may not exist depending on find_dotenv's walk, but if it
    # does it must not contain the key.
    dotenv_path = tmp_path / ".env"
    if dotenv_path.exists():
        assert "XAI_API_KEY" not in dotenv_path.read_text()
def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils):
    """An existing .env with other keys must be preserved on writeback."""
    monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
    monkeypatch.chdir(tmp_path)

    dotenv_path = tmp_path / ".env"
    dotenv_path.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n")

    pasted_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})()
    with patch.object(cli_utils.questionary, "password", return_value=pasted_prompt):
        cli_utils.ensure_api_key("openrouter")

    written = dotenv_path.read_text()
    # Pre-existing entries survive the writeback.
    assert "OPENAI_API_KEY" in written
    assert "sk-existing" in written
    assert "OTHER=value" in written
    # The new key is added alongside them.
    assert "OPENROUTER_API_KEY" in written
    assert "sk-openrouter-new" in written

107
tests/test_capabilities.py Normal file
View File

@@ -0,0 +1,107 @@
"""Unit tests for the LLM capability table."""
import pytest
from tradingagents.llm_clients.capabilities import (
ModelCapabilities,
get_capabilities,
)
@pytest.mark.unit
class TestExactIdMatches:
    """Capability lookups for specific, fully spelled-out DeepSeek model ids."""

    def test_deepseek_chat_supports_tool_choice(self):
        assert get_capabilities("deepseek-chat").supports_tool_choice is True

    def test_deepseek_reasoner_rejects_tool_choice(self):
        profile = get_capabilities("deepseek-reasoner")
        assert profile.supports_tool_choice is False
        assert profile.requires_reasoning_content_roundtrip is True

    def test_deepseek_v4_flash_rejects_tool_choice(self):
        profile = get_capabilities("deepseek-v4-flash")
        assert profile.supports_tool_choice is False
        assert profile.requires_reasoning_content_roundtrip is True

    def test_deepseek_v4_pro_rejects_tool_choice(self):
        profile = get_capabilities("deepseek-v4-pro")
        assert profile.supports_tool_choice is False
        assert profile.requires_reasoning_content_roundtrip is True
@pytest.mark.unit
class TestPatternMatches:
    """Forward-compat regex patterns catch unknown DeepSeek and MiniMax variants."""

    def test_future_deepseek_v5_inherits_thinking_quirks(self):
        profile = get_capabilities("deepseek-v5-flash")
        assert profile.supports_tool_choice is False
        assert profile.requires_reasoning_content_roundtrip is True

    def test_future_deepseek_v9_inherits_thinking_quirks(self):
        assert get_capabilities("deepseek-v9-anything").supports_tool_choice is False

    def test_reasoner_variant_inherits_thinking_quirks(self):
        assert get_capabilities("deepseek-reasoner-pro").supports_tool_choice is False

    def test_future_minimax_m3_inherits_thinking_quirks(self):
        assert get_capabilities("MiniMax-M3").supports_tool_choice is False

    def test_future_minimax_m4_highspeed_inherits_thinking_quirks(self):
        assert get_capabilities("MiniMax-M4-highspeed").supports_tool_choice is False
@pytest.mark.unit
class TestMinimaxExactMatches:
    """MiniMax M2.x models reject langchain's function-spec dict tool_choice
    (official API enum: none/auto only)."""

    def test_m2_7_rejects_tool_choice(self):
        profile = get_capabilities("MiniMax-M2.7")
        assert profile.supports_tool_choice is False
        # only MiniMax-Text-01 supports json_object
        assert profile.supports_json_mode is False

    def test_m2_7_highspeed_rejects_tool_choice(self):
        profile = get_capabilities("MiniMax-M2.7-highspeed")
        assert profile.supports_tool_choice is False

    def test_m2_1_rejects_tool_choice(self):
        profile = get_capabilities("MiniMax-M2.1")
        assert profile.supports_tool_choice is False

    def test_m2_base_rejects_tool_choice(self):
        profile = get_capabilities("MiniMax-M2")
        assert profile.supports_tool_choice is False
@pytest.mark.unit
class TestDefault:
    """Unknown / non-DeepSeek models get the permissive default."""

    def test_gpt_default(self):
        profile = get_capabilities("gpt-4.1")
        assert profile.supports_tool_choice is True
        assert profile.preferred_structured_method == "function_calling"

    def test_grok_default(self):
        assert get_capabilities("grok-4-0709").supports_tool_choice is True

    def test_unknown_model_default(self):
        assert get_capabilities("totally-made-up-model-id").supports_tool_choice is True

    def test_exact_match_precedes_pattern(self):
        """deepseek-chat must NOT match the v\\d regex."""
        assert get_capabilities("deepseek-chat").supports_tool_choice is True
@pytest.mark.unit
def test_capabilities_dataclass_is_frozen():
    """Capability rows are immutable so they can be safely shared."""
    caps = get_capabilities("deepseek-chat")
    # Assigning to a field of a frozen dataclass raises
    # dataclasses.FrozenInstanceError, a subclass of AttributeError. Catching
    # that (rather than bare Exception) keeps an unrelated failure such as a
    # NameError from passing as "frozen".
    with pytest.raises(AttributeError):
        caps.supports_tool_choice = False  # type: ignore[misc]
    # The rejected assignment must also leave the row unchanged.
    assert caps.supports_tool_choice is True

View File

@@ -0,0 +1,147 @@
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from tradingagents.graph.checkpointer import (
checkpoint_step,
clear_checkpoint,
get_checkpointer,
has_checkpoint,
thread_id,
)
# Mutable module-level flag read by _node_b: while it is True the "trader"
# node raises RuntimeError to simulate a crash on the first run. Tests flip
# it through ``global _should_crash``.
_should_crash = False
class _SimpleState(TypedDict):
    """Graph state for the checkpoint tests: a single integer counter."""

    # Incremented by each node as it runs.
    count: int
def _node_a(state: _SimpleState) -> dict:
    """Return a state update that raises ``count`` by 1."""
    bumped = state["count"] + 1
    return {"count": bumped}
def _node_b(state: _SimpleState) -> dict:
    """Return a state update that raises ``count`` by 10.

    Raises:
        RuntimeError: while the module-level ``_should_crash`` flag is set.
    """
    if not _should_crash:
        return {"count": state["count"] + 10}
    raise RuntimeError("simulated mid-analysis crash")
def _build_graph() -> StateGraph:
    """Build the uncompiled two-node graph: ``analyst`` -> ``trader`` -> END."""
    graph_builder = StateGraph(_SimpleState)

    # Nodes
    graph_builder.add_node("analyst", _node_a)
    graph_builder.add_node("trader", _node_b)

    # Wiring
    graph_builder.set_entry_point("analyst")
    graph_builder.add_edge("analyst", "trader")
    graph_builder.add_edge("trader", END)

    return graph_builder
class TestCheckpointResume(unittest.TestCase):
    """Crash/resume behaviour of the checkpointer helpers on a two-node graph.

    Every test gets its own temporary directory, which is removed again
    afterwards, and the module-level ``_should_crash`` flag is always reset
    so a failing test cannot leave later tests in "crash" mode.
    """

    def setUp(self):
        # Local import: shutil is needed only for cleanup.
        import shutil

        self.tmpdir = tempfile.mkdtemp()
        # Best-effort removal; a leftover file must not fail the test itself.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.addCleanup(self._reset_crash_flag)
        self.ticker = "TEST"
        self.date = "2026-04-20"

    @staticmethod
    def _reset_crash_flag():
        """Put the simulated-crash switch back to its default (off)."""
        global _should_crash
        _should_crash = False

    def test_crash_and_resume(self):
        """Crash at 'trader' node, then resume from checkpoint."""
        global _should_crash
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Run 1: crash at trader node
        _should_crash = True
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)

        # Checkpoint should exist at step 1 (analyst completed)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
        step = checkpoint_step(self.tmpdir, self.ticker, self.date)
        self.assertEqual(step, 1)

        # Run 2: resume — trader succeeds this time
        _should_crash = False
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            # Input None together with the same thread_id resumes the run.
            result = graph.invoke(None, config=cfg)

        # analyst added 1, trader added 10 → 11
        self.assertEqual(result["count"], 11)

    def test_clear_checkpoint_allows_fresh_start(self):
        """After clearing, the graph starts from scratch."""
        global _should_crash
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Create a checkpoint by crashing
        _should_crash = True
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Clear it
        clear_checkpoint(self.tmpdir, self.ticker, self.date)
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Fresh run succeeds from scratch
        _should_crash = False
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke({"count": 0}, config=cfg)
        self.assertEqual(result["count"], 11)

    def test_different_date_starts_fresh(self):
        """A different date must NOT resume from an existing checkpoint."""
        global _should_crash
        builder = _build_graph()
        date2 = "2026-04-21"

        # Run with date1 — crash to leave a checkpoint
        _should_crash = True
        tid1 = thread_id(self.ticker, self.date)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # date2 should have no checkpoint
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))

        # Run with date2 — should start fresh and succeed
        _should_crash = False
        tid2 = thread_id(self.ticker, date2)
        self.assertNotEqual(tid1, tid2)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})

        # Fresh run: analyst +1, trader +10 = 11
        self.assertEqual(result["count"], 11)

        # Original date checkpoint still exists (untouched)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Run the unittest suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,61 @@
"""Config isolation: get/set must not leak nested-dict references."""
import copy
import unittest
import pytest
import tradingagents.default_config as default_config
from tradingagents.dataflows.config import get_config, set_config
@pytest.mark.unit
class DataflowsConfigIsolationTests(unittest.TestCase):
    """get_config/set_config must hand out and store independent nested dicts."""

    def setUp(self):
        # Begin each test by applying a private copy of the defaults.
        baseline = copy.deepcopy(default_config.DEFAULT_CONFIG)
        set_config(baseline)

    def test_get_config_returns_deep_copy(self):
        """Mutating a returned config must not affect later get_config() calls."""
        leaked = get_config()
        leaked["data_vendors"]["core_stock_apis"] = "alpha_vantage"
        leaked["tool_vendors"]["get_stock_data"] = "alpha_vantage"

        current = get_config()

        self.assertEqual(current["data_vendors"]["core_stock_apis"], "yfinance")
        self.assertNotIn("get_stock_data", current["tool_vendors"])

    def test_set_config_does_not_alias_caller_nested_dicts(self):
        """Mutating the dict passed to set_config afterwards must not show through."""
        supplied = copy.deepcopy(default_config.DEFAULT_CONFIG)
        supplied["data_vendors"]["core_stock_apis"] = "alpha_vantage"
        supplied["tool_vendors"]["get_stock_data"] = "alpha_vantage"
        set_config(supplied)

        # Change the caller's copy after it has been handed over.
        supplied["data_vendors"]["core_stock_apis"] = "yfinance"
        supplied["tool_vendors"]["get_stock_data"] = "yfinance"

        current = get_config()

        self.assertEqual(current["data_vendors"]["core_stock_apis"], "alpha_vantage")
        self.assertEqual(current["tool_vendors"]["get_stock_data"], "alpha_vantage")

    def test_partial_nested_update_preserves_existing_defaults(self):
        """Setting one data_vendors key leaves the sibling keys at their defaults."""
        partial = {"data_vendors": {"core_stock_apis": "alpha_vantage"}}
        set_config(partial)

        vendors = get_config()["data_vendors"]

        self.assertEqual(vendors["core_stock_apis"], "alpha_vantage")
        self.assertEqual(vendors["technical_indicators"], "yfinance")
        self.assertEqual(vendors["fundamental_data"], "yfinance")
        self.assertEqual(vendors["news_data"], "yfinance")

    def test_nested_dict_updates_merge_one_level_deep(self):
        """Two successive tool_vendors updates both remain visible."""
        set_config({"tool_vendors": {"get_stock_data": "alpha_vantage"}})
        set_config({"tool_vendors": {"get_news": "alpha_vantage"}})

        tools = get_config()["tool_vendors"]

        self.assertEqual(tools["get_stock_data"], "alpha_vantage")
        self.assertEqual(tools["get_news"], "alpha_vantage")

View File

@@ -0,0 +1,240 @@
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
Two pieces verified:
1. ``reasoning_content`` is captured on receive into the AIMessage's
``additional_kwargs`` and re-attached on send so DeepSeek's API
sees the same value across turns.
2. ``with_structured_output`` consults the capability table and
suppresses ``tool_choice`` for models that reject it (V4 + reasoner),
matching DeepSeek's official tool-calling pattern at
https://api-docs.deepseek.com/guides/tool_calls.
"""
import os
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue
from pydantic import BaseModel
from tradingagents.llm_clients.openai_client import (
DeepSeekChatOpenAI,
NormalizedChatOpenAI,
_input_to_messages,
)
# ---------------------------------------------------------------------------
# _input_to_messages — the helper that handles list / ChatPromptValue / other
# (Gemini bot review note: non-list inputs must also work)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestInputToMessages:
    """_input_to_messages handles list, ChatPromptValue and other inputs."""

    def test_list_input_returned_as_is(self):
        """A list is returned as the very same object, not a copy."""
        original = [HumanMessage(content="hi")]
        returned = _input_to_messages(original)
        assert returned is original

    def test_chat_prompt_value_unwrapped(self):
        """A ChatPromptValue yields the messages it wraps."""
        wrapped = [HumanMessage(content="hi")]
        returned = _input_to_messages(ChatPromptValue(messages=wrapped))
        assert returned == wrapped

    def test_string_input_yields_empty_list(self):
        # A bare string isn't a message-bearing input; the caller's normal
        # langchain conversion happens upstream of _get_request_payload.
        returned = _input_to_messages("hello")
        assert returned == []
# ---------------------------------------------------------------------------
# Reasoning content propagation across turns
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeepSeekReasoningContent:
    """reasoning_content is captured on receive and echoed back on send."""

    def _client(self):
        """Build a ``deepseek-v4-flash`` client with a placeholder API key."""
        os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
        settings = {
            "model": "deepseek-v4-flash",
            "api_key": "placeholder",
            "base_url": "https://api.deepseek.com",
        }
        return DeepSeekChatOpenAI(**settings)

    def test_capture_on_receive(self):
        """When the response carries reasoning_content, it lands on the
        AIMessage's additional_kwargs so the next turn can echo it back."""
        assistant_message = {
            "role": "assistant",
            "content": "Plan: buy NVDA.",
            "reasoning_content": "Step 1: trend is up. Step 2: ...",
        }
        raw_response = {
            "model": "deepseek-v4-flash",
            "choices": [
                {
                    "index": 0,
                    "message": assistant_message,
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }

        chat_result = self._client()._create_chat_result(raw_response)

        received = chat_result.generations[0].message
        assert received.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."

    def test_propagate_on_send(self):
        """When an outgoing AIMessage carries reasoning_content, the request
        payload echoes it on the corresponding message dict."""
        client = self._client()
        earlier_turn = AIMessage(
            content="Plan",
            additional_kwargs={"reasoning_content": "weighed bull case"},
        )
        follow_up = HumanMessage(content="Refine.")

        payload = client._get_request_payload([earlier_turn, follow_up])

        # Find the assistant message in the payload
        assistant_entries = [
            entry for entry in payload["messages"] if entry.get("role") == "assistant"
        ]
        assert assistant_entries, "assistant message missing from outgoing payload"
        assert assistant_entries[0]["reasoning_content"] == "weighed bull case"

    def test_propagate_through_chat_prompt_value(self):
        """Gemini bot review note: non-list inputs (ChatPromptValue) must
        also propagate reasoning_content."""
        client = self._client()
        earlier_turn = AIMessage(
            content="Plan",
            additional_kwargs={"reasoning_content": "weighed bull case"},
        )
        conversation = ChatPromptValue(
            messages=[earlier_turn, HumanMessage(content="Refine.")]
        )

        payload = client._get_request_payload(conversation)

        assistant_entries = [
            entry for entry in payload["messages"] if entry.get("role") == "assistant"
        ]
        assert assistant_entries[0]["reasoning_content"] == "weighed bull case"
# ---------------------------------------------------------------------------
# Capability-driven structured output: tool_choice suppressed for V4 + reasoner
# ---------------------------------------------------------------------------
def _bound_kwargs(runnable):
    """Extract bind() kwargs from a with_structured_output result.

    Uses the first entry of ``runnable.steps`` when that attribute exists,
    otherwise ``runnable`` itself, and returns its ``kwargs`` attribute
    (an empty dict when there is none).
    """
    if hasattr(runnable, "steps"):
        head = runnable.steps[0]
    else:
        head = runnable
    return getattr(head, "kwargs", {})
@pytest.mark.unit
class TestStructuredOutputCapabilityDispatch:
    """DeepSeek V4 and reasoner reject the tool_choice parameter
    (official guide: api-docs.deepseek.com/guides/tool_calls passes
    tools=[...] without tool_choice). Verify the capability dispatch
    suppresses tool_choice for those models and sends it for chat."""

    class _Sample(BaseModel):
        answer: str

    def _client(self, model):
        """Build a DeepSeekChatOpenAI for ``model`` with a placeholder key."""
        return DeepSeekChatOpenAI(
            model=model, api_key="placeholder", base_url="https://api.deepseek.com",
        )

    def _tool_choice(self, model):
        """Return the ``tool_choice`` bound by with_structured_output for ``model``.

        ``dict.get`` yields None both when the key is absent and when it is
        explicitly None, so one ``is None`` check covers both valid ways of
        suppressing the parameter.
        """
        bound = self._client(model).with_structured_output(self._Sample)
        return _bound_kwargs(bound).get("tool_choice")

    def test_chat_sends_tool_choice(self):
        assert self._tool_choice("deepseek-chat") is not None

    def test_reasoner_suppresses_tool_choice(self):
        # tool_choice is either absent or explicitly None — both are valid
        # signals that langchain's bind_tools will skip the parameter.
        assert self._tool_choice("deepseek-reasoner") is None

    def test_v4_flash_suppresses_tool_choice(self):
        assert self._tool_choice("deepseek-v4-flash") is None

    def test_v4_pro_suppresses_tool_choice(self):
        assert self._tool_choice("deepseek-v4-pro") is None

    def test_future_v_variant_via_regex(self):
        """Forward-compat: unknown deepseek-v\\d-* IDs inherit V4 quirks."""
        assert self._tool_choice("deepseek-v5-hypothetical") is None

    def test_schema_is_still_bound_as_tool(self):
        """tool_choice is suppressed, but the schema is still bound as a tool —
        exactly matching DeepSeek's official tool-calling examples."""
        bound = self._client("deepseek-reasoner").with_structured_output(self._Sample)
        kwargs = _bound_kwargs(bound)
        tools = kwargs.get("tools", [])
        assert any(
            t.get("function", {}).get("name") == "_Sample" for t in tools
        ), f"schema not bound as a tool: {tools}"
# ---------------------------------------------------------------------------
# Live API: structured output round-trips against the real DeepSeek backend
# ---------------------------------------------------------------------------
def _has_real_deepseek_key():
    """Return True when DEEPSEEK_API_KEY is set to something other than
    the empty string or the literal ``"placeholder"``."""
    configured = os.environ.get("DEEPSEEK_API_KEY", "")
    if not configured:
        return False
    return configured != "placeholder"
@pytest.mark.integration
@pytest.mark.skipif(
    not _has_real_deepseek_key(),
    reason="DEEPSEEK_API_KEY not set (or placeholder); skipping live API call",
)
class TestDeepSeekLiveStructuredOutput:
    """End-to-end: a real DeepSeek V4-flash call returns a typed instance.

    Verifies the no-tool_choice path doesn't trigger the 400 reported in
    issue #678 and that the structured-output binding still parses to a
    Pydantic instance.
    """

    class _Pick(BaseModel):
        action: str
        confidence: float

    def test_v4_flash_returns_structured_output(self):
        """A live call parses into ``_Pick`` with a valid action and confidence."""
        question = (
            "Pick BUY or SELL or HOLD for a tech stock with strong earnings. "
            "Confidence is a float between 0 and 1."
        )
        live_client = DeepSeekChatOpenAI(
            model="deepseek-v4-flash",
            api_key=os.environ["DEEPSEEK_API_KEY"],
            base_url="https://api.deepseek.com",
            timeout=60,
        )
        structured = live_client.with_structured_output(self._Pick)

        answer = structured.invoke(question)

        assert isinstance(answer, self._Pick)
        assert answer.action in {"BUY", "SELL", "HOLD"}
        assert 0.0 <= answer.confidence <= 1.0
# ---------------------------------------------------------------------------
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseClassIsolation:
    """NormalizedChatOpenAI must not carry the DeepSeek-specific payload hook."""

    def test_normalized_does_not_propagate_reasoning_content(self):
        """The general-purpose NormalizedChatOpenAI must not carry
        DeepSeek-specific behaviour. Only the subclass does."""
        general = NormalizedChatOpenAI
        if hasattr(general, "_get_request_payload"):
            # When the attribute exists, it must be the very function
            # inherited from the first base class, not an override.
            first_base = general.__bases__[0]
            assert general._get_request_payload is first_base._get_request_payload

View File

@@ -0,0 +1,98 @@
"""Tests for TRADINGAGENTS_* env-var overlay onto DEFAULT_CONFIG."""
from __future__ import annotations
import importlib
import pytest
import tradingagents.default_config as default_config_module
def _reload_with_env(monkeypatch, **overrides):
    """Set/clear env vars then reload default_config to re-evaluate DEFAULT_CONFIG.

    Every variable named in ``_ENV_OVERRIDES`` is removed first, then the
    given ``overrides`` are applied, and the reloaded module is returned.
    """
    known_names = list(default_config_module._ENV_OVERRIDES)
    for name in known_names:
        monkeypatch.delenv(name, raising=False)
    for name, value in overrides.items():
        monkeypatch.setenv(name, value)
    reloaded = importlib.reload(default_config_module)
    return reloaded
def test_no_env_uses_built_in_defaults(monkeypatch):
    """With no TRADINGAGENTS_* overrides set, DEFAULT_CONFIG holds the built-in values."""
    config = _reload_with_env(monkeypatch).DEFAULT_CONFIG
    assert config["llm_provider"] == "openai"
    assert config["deep_think_llm"] == "gpt-5.4"
    assert config["quick_think_llm"] == "gpt-5.4-mini"
    assert config["backend_url"] is None
    assert config["max_debate_rounds"] == 1
    assert config["checkpoint_enabled"] is False
def test_string_overrides(monkeypatch):
    """String-valued env overrides land verbatim under their config keys."""
    env = {
        "TRADINGAGENTS_LLM_PROVIDER": "google",
        "TRADINGAGENTS_DEEP_THINK_LLM": "gemini-3-pro-preview",
        "TRADINGAGENTS_QUICK_THINK_LLM": "gemini-3-flash-preview",
        "TRADINGAGENTS_LLM_BACKEND_URL": "https://example.invalid/v1",
        "TRADINGAGENTS_OUTPUT_LANGUAGE": "Chinese",
    }
    config = _reload_with_env(monkeypatch, **env).DEFAULT_CONFIG
    assert config["llm_provider"] == "google"
    assert config["deep_think_llm"] == "gemini-3-pro-preview"
    assert config["quick_think_llm"] == "gemini-3-flash-preview"
    assert config["backend_url"] == "https://example.invalid/v1"
    assert config["output_language"] == "Chinese"
def test_int_coercion(monkeypatch):
    """Numeric env strings for the round limits arrive as real ints."""
    env = {
        "TRADINGAGENTS_MAX_DEBATE_ROUNDS": "3",
        "TRADINGAGENTS_MAX_RISK_ROUNDS": "2",
    }
    config = _reload_with_env(monkeypatch, **env).DEFAULT_CONFIG

    debate_rounds = config["max_debate_rounds"]
    assert debate_rounds == 3
    assert isinstance(debate_rounds, int)

    risk_rounds = config["max_risk_discuss_rounds"]
    assert risk_rounds == 2
    assert isinstance(risk_rounds, int)
@pytest.mark.parametrize(
    "raw,expected",
    [
        ("true", True),
        ("True", True),
        ("1", True),
        ("yes", True),
        ("on", True),
        ("false", False),
        ("False", False),
        ("0", False),
        ("no", False),
        ("off", False),
    ],
)
def test_bool_coercion(monkeypatch, raw, expected):
    """Each listed spelling of true/false is coerced to the matching bool."""
    reloaded = _reload_with_env(monkeypatch, TRADINGAGENTS_CHECKPOINT_ENABLED=raw)
    flag = reloaded.DEFAULT_CONFIG["checkpoint_enabled"]
    assert flag is expected
def test_empty_env_value_is_passthrough(monkeypatch):
    """Empty TRADINGAGENTS_* values must not clobber the built-in default."""
    blank_env = {
        "TRADINGAGENTS_LLM_PROVIDER": "",
        "TRADINGAGENTS_MAX_DEBATE_ROUNDS": "",
    }
    config = _reload_with_env(monkeypatch, **blank_env).DEFAULT_CONFIG
    assert config["llm_provider"] == "openai"
    assert config["max_debate_rounds"] == 1
def test_invalid_int_raises(monkeypatch):
    """Garbage int values should surface a ValueError at import, not silently misconfigure."""
    monkeypatch.setenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", "not-a-number")
    try:
        with pytest.raises(ValueError):
            importlib.reload(default_config_module)
    finally:
        # Restore module state for subsequent tests in this process. This
        # sits in ``finally`` so the module is reloaded from a clean
        # environment even when the expectation above fails.
        monkeypatch.delenv("TRADINGAGENTS_MAX_DEBATE_ROUNDS", raising=False)
        importlib.reload(default_config_module)
def test_unknown_env_var_is_ignored(monkeypatch):
    """Env vars outside _ENV_OVERRIDES must not bleed into DEFAULT_CONFIG."""
    stray_env = {"TRADINGAGENTS_NONEXISTENT_KEY": "oops"}
    config = _reload_with_env(monkeypatch, **stray_env).DEFAULT_CONFIG
    assert "nonexistent_key" not in config

View File

@@ -1,9 +1,12 @@
import unittest
from unittest.mock import patch
import pytest
from tradingagents.llm_clients.google_client import GoogleClient
@pytest.mark.unit
class TestGoogleApiKeyStandardization(unittest.TestCase):
"""Verify GoogleClient accepts unified api_key parameter."""

773
tests/test_memory_log.py Normal file
View File

@@ -0,0 +1,773 @@
"""Tests for TradingMemoryLog — storage, deferred reflection, PM injection, legacy removal."""
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.agents.schemas import PortfolioDecision, PortfolioRating
from tradingagents.graph.reflection import Reflector
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.graph.propagation import Propagator
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
# Short alias for TradingMemoryLog._SEPARATOR; _seed_completed appends it
# after each entry it writes by hand.
_SEP = TradingMemoryLog._SEPARATOR
# Sample decision texts shared by the tests below.
DECISION_BUY = "\n".join(
    (
        "Rating: Buy",
        "Enter at $189-192, 6% portfolio cap.",
    )
)
DECISION_OVERWEIGHT = "\n".join(
    (
        "Rating: Overweight",
        "Executive Summary: Moderate position, await confirmation.",
        "Investment Thesis: Strong fundamentals but near-term headwinds.",
    )
)
DECISION_SELL = "\n".join(
    (
        "Rating: Sell",
        "Exit position immediately.",
    )
)
# Deliberately carries no "Rating:" line.
DECISION_NO_RATING = "\n".join(
    (
        "Executive Summary: Complex situation with multiple competing factors.",
        "Investment Thesis: No clear directional signal at this time.",
    )
)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def make_log(tmp_path, filename="trading_memory.md"):
    """Create a TradingMemoryLog whose ``memory_log_path`` is ``tmp_path / filename``."""
    log_path = tmp_path / filename
    return TradingMemoryLog({"memory_log_path": str(log_path)})
def _seed_completed(tmp_path, ticker, date, decision_text, reflection_text, filename="trading_memory.md"):
    """Write a completed entry directly to file, bypassing the API.

    The entry is appended to ``tmp_path / filename`` and consists of a fixed
    header line, the decision text, the reflection text and the separator.
    """
    header = f"[{date} | {ticker} | Buy | +1.0% | +0.5% | 5d]\n\n"
    decision_section = f"DECISION:\n{decision_text}\n\n"
    reflection_section = f"REFLECTION:\n{reflection_text}"
    record = header + decision_section + reflection_section + _SEP

    target = tmp_path / filename
    with open(target, "a", encoding="utf-8") as handle:
        handle.write(record)
def _resolve_entry(log, ticker, date, decision, reflection="Good call."):
    """Store a decision then immediately resolve it via the API."""
    log.store_decision(ticker, date, decision)
    # Fixed outcome values handed to update_with_outcome for every seeded entry.
    outcome_args = (0.05, 0.02, 5)
    log.update_with_outcome(ticker, date, *outcome_args, reflection)
def _price_df(prices):
    """Minimal DataFrame matching yfinance .history() output shape.

    The frame has a single ``Close`` column holding ``prices``.
    """
    columns = {"Close": prices}
    return pd.DataFrame(columns)
def _make_pm_state(past_context=""):
    """Minimal AgentState dict for portfolio_manager_node.

    Only ``past_context`` varies; every other field is a fixed stub value.
    """
    risk_debate = {
        "history": "Risk debate history.",
        "aggressive_history": "",
        "conservative_history": "",
        "neutral_history": "",
        "judge_decision": "",
        "current_aggressive_response": "",
        "current_conservative_response": "",
        "current_neutral_response": "",
        "count": 1,
    }
    state = {
        "company_of_interest": "NVDA",
        "past_context": past_context,
        "risk_debate_state": risk_debate,
        "market_report": "Market report.",
        "sentiment_report": "Sentiment report.",
        "news_report": "News report.",
        "fundamentals_report": "Fundamentals report.",
        "investment_plan": "Research plan.",
        "trader_investment_plan": "Trader plan.",
    }
    return state
def _structured_pm_llm(captured: dict, decision: PortfolioDecision | None = None):
    """Return a mock LLM whose structured-output binding records its prompt.

    ``llm.with_structured_output(...)`` yields a mock whose ``invoke`` stores
    the prompt under ``captured["prompt"]`` and returns a real
    ``PortfolioDecision`` (a default HOLD decision when none is supplied), so
    render_pm_decision works on the result.
    """
    reply = decision
    if reply is None:
        reply = PortfolioDecision(
            rating=PortfolioRating.HOLD,
            executive_summary="Hold the position; await catalyst.",
            investment_thesis="Balanced view; neither side carried the debate.",
        )

    def _record_and_reply(prompt):
        # Remember what the node sent, then answer with the canned decision.
        captured["prompt"] = prompt
        return reply

    bound = MagicMock()
    bound.invoke.side_effect = _record_and_reply

    llm = MagicMock()
    llm.with_structured_output.return_value = bound
    return llm
# ---------------------------------------------------------------------------
# Core: storage and read path
# ---------------------------------------------------------------------------
class TestTradingMemoryLogCore:
    """Write path, parsing, context retrieval and rotation of the memory log."""

    # -- store_decision ----------------------------------------------------

    def test_store_creates_file(self, tmp_path):
        """The backing file only appears once a decision has been stored."""
        log_file = tmp_path / "trading_memory.md"
        memory = make_log(tmp_path)
        assert not log_file.exists()
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        assert log_file.exists()

    def test_store_appends_not_overwrites(self, tmp_path):
        """A second decision is added after the first instead of replacing it."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
        stored = memory.load_entries()
        assert len(stored) == 2
        assert stored[0]["ticker"] == "NVDA"
        assert stored[1]["ticker"] == "AAPL"

    def test_store_decision_idempotent(self, tmp_path):
        """Calling store_decision twice with same (ticker, date) stores only one entry."""
        memory = make_log(tmp_path)
        for _ in range(2):
            memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        assert len(memory.load_entries()) == 1

    def test_batch_update_resolves_multiple_entries(self, tmp_path):
        """batch_update_with_outcomes resolves multiple pending entries in one write."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-05", DECISION_BUY)
        memory.store_decision("NVDA", "2026-01-12", DECISION_SELL)
        first_update = {
            "ticker": "NVDA",
            "trade_date": "2026-01-05",
            "raw_return": 0.05,
            "alpha_return": 0.02,
            "holding_days": 5,
            "reflection": "First correct.",
        }
        second_update = {
            "ticker": "NVDA",
            "trade_date": "2026-01-12",
            "raw_return": -0.03,
            "alpha_return": -0.01,
            "holding_days": 5,
            "reflection": "Second correct.",
        }
        memory.batch_update_with_outcomes([first_update, second_update])
        stored = memory.load_entries()
        assert len(stored) == 2
        assert not any(item["pending"] for item in stored)
        assert stored[0]["reflection"] == "First correct."
        assert stored[1]["reflection"] == "Second correct."

    def test_pending_tag_format(self, tmp_path):
        """A freshly stored decision is written with a ``pending`` tag line."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        contents = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
        assert "[2026-01-10 | NVDA | Buy | pending]" in contents

    # -- Rating parsing ----------------------------------------------------

    def test_rating_parsed_buy(self, tmp_path):
        """DECISION_BUY is stored with rating ``Buy``."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Buy"

    def test_rating_parsed_overweight(self, tmp_path):
        """DECISION_OVERWEIGHT is stored with rating ``Overweight``."""
        memory = make_log(tmp_path)
        memory.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Overweight"

    def test_rating_fallback_hold(self, tmp_path):
        """A decision without a rating label is stored as ``Hold``."""
        memory = make_log(tmp_path)
        memory.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Hold"

    def test_rating_priority_over_prose(self, tmp_path):
        """'Rating: X' label wins even when an opposing rating word appears earlier in prose."""
        text = (
            "The sell thesis is weak. The hold case is marginal.\n\n"
            "Rating: Buy\n\n"
            "Executive Summary: Strong fundamentals support the position."
        )
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Buy"

    # -- Delimiter robustness ----------------------------------------------

    def test_decision_with_markdown_separator(self, tmp_path):
        """LLM decision containing '---' must not corrupt the entry."""
        text = "Rating: Buy\n\n---\n\nRisk: elevated volatility."
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        stored = memory.load_entries()
        assert len(stored) == 1
        assert "Risk: elevated volatility" in stored[0]["decision"]

    # -- load_entries ------------------------------------------------------

    def test_load_entries_empty_file(self, tmp_path):
        """With nothing stored, load_entries returns an empty list."""
        memory = make_log(tmp_path)
        assert memory.load_entries() == []

    def test_load_entries_single(self, tmp_path):
        """A single pending entry exposes date, ticker, rating and pending state."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        stored = memory.load_entries()
        assert len(stored) == 1
        only = stored[0]
        assert only["date"] == "2026-01-10"
        assert only["ticker"] == "NVDA"
        assert only["rating"] == "Buy"
        assert only["pending"] is True
        assert only["raw"] is None

    def test_load_entries_multiple(self, tmp_path):
        """Entries come back in the order they were stored."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.store_decision("AAPL", "2026-01-11", DECISION_OVERWEIGHT)
        memory.store_decision("MSFT", "2026-01-12", DECISION_NO_RATING)
        stored = memory.load_entries()
        assert len(stored) == 3
        tickers_seen = [item["ticker"] for item in stored]
        assert tickers_seen == ["NVDA", "AAPL", "MSFT"]

    def test_decision_content_preserved(self, tmp_path):
        """The stored decision text equals the stripped input text."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        newest = memory.load_entries()[0]
        assert newest["decision"] == DECISION_BUY.strip()

    # -- get_pending_entries -----------------------------------------------

    def test_get_pending_returns_pending_only(self, tmp_path):
        """A seeded completed entry is excluded; the pending one is returned."""
        memory = make_log(tmp_path)
        _seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA.", "Correct.")
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        waiting = memory.get_pending_entries()
        assert len(waiting) == 1
        assert waiting[0]["ticker"] == "NVDA"
        assert waiting[0]["date"] == "2026-01-10"

    # -- get_past_context --------------------------------------------------

    def test_get_past_context_empty(self, tmp_path):
        """No entries at all gives an empty context string."""
        memory = make_log(tmp_path)
        assert memory.get_past_context("NVDA") == ""

    def test_get_past_context_pending_excluded(self, tmp_path):
        """A pending-only log still gives an empty context string."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        assert memory.get_past_context("NVDA") == ""

    def test_get_past_context_same_ticker(self, tmp_path):
        """A completed same-ticker entry shows up under the same-ticker heading."""
        memory = make_log(tmp_path)
        _seed_completed(tmp_path, "NVDA", "2026-01-05", "Buy NVDA — AI capex thesis intact.", "Directionally correct.")
        context = memory.get_past_context("NVDA")
        assert "Past analyses of NVDA" in context
        assert "Buy NVDA" in context

    def test_get_past_context_cross_ticker(self, tmp_path):
        """A completed entry for another ticker shows up only as a cross-ticker lesson."""
        memory = make_log(tmp_path)
        _seed_completed(tmp_path, "AAPL", "2026-01-05", "Buy AAPL — Services growth.", "Correct.")
        context = memory.get_past_context("NVDA")
        assert "Recent cross-ticker lessons" in context
        assert "Past analyses of NVDA" not in context

    def test_n_same_limit_respected(self, tmp_path):
        """Only the n_same most recent same-ticker entries are included."""
        memory = make_log(tmp_path)
        for i in range(6):
            _seed_completed(tmp_path, "NVDA", f"2026-01-{i+1:02d}", f"Buy entry {i}.", "Correct.")
        context = memory.get_past_context("NVDA", n_same=5)
        assert "Buy entry 0" not in context
        assert "Buy entry 5" in context

    def test_n_cross_limit_respected(self, tmp_path):
        """Only the n_cross most recent cross-ticker entries are included."""
        memory = make_log(tmp_path)
        other_tickers = ["AAPL", "MSFT", "GOOG", "META"]
        for i, ticker in enumerate(other_tickers):
            _seed_completed(tmp_path, ticker, f"2026-01-{i+1:02d}", f"Buy {ticker}.", "Correct.")
        context = memory.get_past_context("NVDA", n_cross=3)
        assert "AAPL" not in context
        assert "META" in context

    # -- No-op when config is None -----------------------------------------

    def test_no_log_path_is_noop(self):
        """Without a config the log stores nothing and returns empty results."""
        memory = TradingMemoryLog(config=None)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        assert memory.load_entries() == []
        assert memory.get_past_context("NVDA") == ""

    # -- Rotation: opt-in cap on resolved entries --------------------------

    def test_rotation_disabled_by_default(self, tmp_path):
        """Without max_entries, all resolved entries are kept."""
        memory = make_log(tmp_path)
        for i in range(7):
            _resolve_entry(memory, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
        assert len(memory.load_entries()) == 7

    def test_rotation_prunes_oldest_resolved(self, tmp_path):
        """When max_entries is set and exceeded, oldest resolved entries are pruned."""
        settings = {
            "memory_log_path": str(tmp_path / "trading_memory.md"),
            "memory_log_max_entries": 3,
        }
        memory = TradingMemoryLog(settings)
        # Five resolved entries against a cap of three.
        for i in range(5):
            _resolve_entry(memory, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
        stored = memory.load_entries()
        assert len(stored) == 3
        # The survivors must be the newest three, i.e. the oldest were dropped.
        surviving_dates = [item["date"] for item in stored]
        assert surviving_dates == ["2026-01-03", "2026-01-04", "2026-01-05"]

    def test_rotation_never_prunes_pending(self, tmp_path):
        """Pending entries (unresolved) are kept regardless of the cap."""
        settings = {
            "memory_log_path": str(tmp_path / "trading_memory.md"),
            "memory_log_max_entries": 2,
        }
        memory = TradingMemoryLog(settings)
        # Three resolved entries followed by two pending ones, cap of two.
        for i in range(3):
            _resolve_entry(memory, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Resolved {i}.")
        memory.store_decision("NVDA", "2026-02-01", DECISION_BUY)
        memory.store_decision("NVDA", "2026-02-02", DECISION_OVERWEIGHT)
        # Resolving one more entry triggers rotation while pending entries exist.
        _resolve_entry(memory, "NVDA", "2026-01-04", DECISION_BUY, "Resolved 3.")
        stored = memory.load_entries()
        pending = [item for item in stored if item["pending"]]
        resolved = [item for item in stored if not item["pending"]]
        assert len(pending) == 2, "pending entries must never be pruned"
        assert len(resolved) == 2, f"expected 2 resolved after rotation, got {len(resolved)}"

    def test_rotation_under_cap_is_noop(self, tmp_path):
        """No rotation when resolved count <= max_entries."""
        settings = {
            "memory_log_path": str(tmp_path / "trading_memory.md"),
            "memory_log_max_entries": 10,
        }
        memory = TradingMemoryLog(settings)
        for i in range(3):
            _resolve_entry(memory, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
        assert len(memory.load_entries()) == 3

    # -- Rating parsing: markdown bold and numbered list formats -----------

    def test_rating_parsed_from_bold_markdown(self, tmp_path):
        """**Rating**: Buy — markdown bold around the label must not prevent parsing."""
        text = "**Rating**: Buy\nEnter at $190."
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Buy"

    def test_rating_parsed_from_bold_value(self, tmp_path):
        """Rating: **Sell** — markdown bold around the value must not prevent parsing."""
        text = "Rating: **Sell**\nExit immediately."
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Sell"

    def test_rating_label_wins_over_prose_with_markdown(self, tmp_path):
        """Rating: **Sell** must win even when prose contains a conflicting rating word."""
        text = (
            "The buy thesis is weakened by guidance.\n"
            "Rating: **Sell**\n"
            "Exit before earnings."
        )
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Sell"

    def test_rating_parsed_from_numbered_list(self, tmp_path):
        """1. Rating: Buy — numbered list prefix must not prevent parsing."""
        text = "1. Rating: Buy\nEnter at $190."
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", text)
        newest = memory.load_entries()[0]
        assert newest["rating"] == "Buy"
# ---------------------------------------------------------------------------
# Deferred reflection: update_with_outcome, Reflector, _fetch_returns
# ---------------------------------------------------------------------------
class TestDeferredReflection:
    """Outcome updates, the Reflector prompt, and return fetching/resolution."""

    # -- update_with_outcome -----------------------------------------------

    def test_update_replaces_pending_tag(self, tmp_path):
        """Resolving an entry swaps the pending tag for the formatted outcome."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
        contents = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
        assert "[2026-01-10 | NVDA | Buy | pending]" not in contents
        for fragment in ("+4.2%", "+2.1%", "5d"):
            assert fragment in contents

    def test_update_appends_reflection(self, tmp_path):
        """The resolved entry carries the reflection and keeps its decision text."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
        stored = memory.load_entries()
        assert len(stored) == 1
        only = stored[0]
        assert only["pending"] is False
        assert only["reflection"] == "Momentum confirmed."
        assert only["decision"] == DECISION_BUY.strip()

    def test_update_preserves_other_entries(self, tmp_path):
        """Only the matching entry is modified; all other entries remain unchanged."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.store_decision("AAPL", "2026-01-11", "Rating: Hold\nHold AAPL.")
        memory.store_decision("MSFT", "2026-01-12", DECISION_SELL)
        memory.update_with_outcome("AAPL", "2026-01-11", 0.01, -0.01, 5, "Neutral result.")
        stored = memory.load_entries()
        assert len(stored) == 3
        nvda, aapl, msft = stored
        assert nvda["ticker"] == "NVDA" and nvda["pending"] is True
        assert aapl["ticker"] == "AAPL" and aapl["pending"] is False
        assert aapl["reflection"] == "Neutral result."
        assert msft["ticker"] == "MSFT" and msft["pending"] is True

    def test_update_atomic_write(self, tmp_path):
        """A pre-existing .tmp file is overwritten; the log is correctly updated."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        leftover = tmp_path / "trading_memory.tmp"
        leftover.write_text("GARBAGE CONTENT — should be overwritten", encoding="utf-8")
        memory.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Correct.")
        assert not leftover.exists()
        stored = memory.load_entries()
        assert len(stored) == 1
        assert stored[0]["reflection"] == "Correct."
        assert stored[0]["pending"] is False

    def test_update_noop_when_no_log_path(self):
        """With no config, update_with_outcome simply must not raise."""
        memory = TradingMemoryLog(config=None)
        memory.update_with_outcome("NVDA", "2026-01-10", 0.05, 0.02, 5, "Reflection")

    def test_formatting_roundtrip_after_update(self, tmp_path):
        """All fields intact and blank line between tag and DECISION preserved after update."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-10", DECISION_BUY)
        memory.update_with_outcome("NVDA", "2026-01-10", 0.042, 0.021, 5, "Momentum confirmed.")
        stored = memory.load_entries()
        assert len(stored) == 1
        only = stored[0]
        assert only["pending"] is False
        assert only["decision"] == DECISION_BUY.strip()
        assert only["reflection"] == "Momentum confirmed."
        assert only["raw"] == "+4.2%"
        assert only["alpha"] == "+2.1%"
        assert only["holding"] == "5d"
        on_disk = (tmp_path / "trading_memory.md").read_text(encoding="utf-8")
        assert "[2026-01-10 | NVDA | Buy | +4.2% | +2.1% | 5d]\n\nDECISION:" in on_disk

    # -- Reflector.reflect_on_final_decision -------------------------------

    def test_reflect_on_final_decision_returns_llm_output(self):
        """The reflector returns the LLM response content and calls the LLM once."""
        fake_llm = MagicMock()
        fake_llm.invoke.return_value.content = "Directionally correct. Thesis confirmed."
        outcome = Reflector(fake_llm).reflect_on_final_decision(
            final_decision=DECISION_BUY, raw_return=0.042, alpha_return=0.021
        )
        assert outcome == "Directionally correct. Thesis confirmed."
        fake_llm.invoke.assert_called_once()

    def test_reflect_on_final_decision_includes_returns_in_prompt(self):
        """Return figures are present in the human message sent to the LLM."""
        fake_llm = MagicMock()
        fake_llm.invoke.return_value.content = "Incorrect call."
        Reflector(fake_llm).reflect_on_final_decision(
            final_decision=DECISION_SELL, raw_return=-0.08, alpha_return=-0.05
        )
        # First positional argument of the single invoke call is the message list.
        sent_messages = fake_llm.invoke.call_args[0][0]
        human_text = next(body for role, body in sent_messages if role == "human")
        assert "-8.0%" in human_text
        assert "-5.0%" in human_text
        assert "Exit position immediately." in human_text

    # -- TradingAgentsGraph._fetch_returns ---------------------------------

    def test_fetch_returns_valid_ticker(self):
        """Six price points for stock and SPY yield floats and a 5-day holding."""
        stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
        spy_prices = [400.0, 402.0, 404.0, 403.0, 405.0, 406.0]

        def _ticker_for(symbol):
            ticker = MagicMock()
            series = spy_prices if symbol == "SPY" else stock_prices
            ticker.history.return_value = _price_df(series)
            return ticker

        graph = MagicMock(spec=TradingAgentsGraph)
        with patch("yfinance.Ticker") as ticker_cls:
            ticker_cls.side_effect = _ticker_for
            raw, alpha, days = TradingAgentsGraph._fetch_returns(graph, "NVDA", "2026-01-05")

        assert raw is not None and alpha is not None and days is not None
        assert isinstance(raw, float) and isinstance(alpha, float) and isinstance(days, int)
        assert days == 5

    def test_fetch_returns_too_recent(self):
        """Only 1 data point available → returns (None, None, None), no crash."""
        graph = MagicMock(spec=TradingAgentsGraph)
        with patch("yfinance.Ticker") as ticker_cls:
            ticker = MagicMock()
            ticker.history.return_value = _price_df([100.0])
            ticker_cls.return_value = ticker
            raw, alpha, days = TradingAgentsGraph._fetch_returns(graph, "NVDA", "2026-04-19")

        assert raw is None and alpha is None and days is None

    def test_fetch_returns_delisted(self):
        """Empty DataFrame → returns (None, None, None), no crash."""
        graph = MagicMock(spec=TradingAgentsGraph)
        with patch("yfinance.Ticker") as ticker_cls:
            ticker = MagicMock()
            ticker.history.return_value = pd.DataFrame({"Close": []})
            ticker_cls.return_value = ticker
            raw, alpha, days = TradingAgentsGraph._fetch_returns(graph, "XXXXXFAKE", "2026-01-10")

        assert raw is None and alpha is None and days is None

    def test_fetch_returns_spy_shorter_than_stock(self):
        """SPY having fewer rows than the stock must not raise IndexError."""
        stock_prices = [100.0, 102.0, 104.0, 103.0, 105.0, 106.0]
        spy_prices = [400.0, 402.0, 403.0]

        def _ticker_for(symbol):
            ticker = MagicMock()
            series = spy_prices if symbol == "SPY" else stock_prices
            ticker.history.return_value = _price_df(series)
            return ticker

        graph = MagicMock(spec=TradingAgentsGraph)
        with patch("yfinance.Ticker") as ticker_cls:
            ticker_cls.side_effect = _ticker_for
            raw, alpha, days = TradingAgentsGraph._fetch_returns(graph, "NVDA", "2026-01-05")

        assert raw is not None and alpha is not None and days is not None
        assert days == 2

    # -- TradingAgentsGraph._resolve_pending_entries -----------------------

    def test_resolve_skips_other_tickers(self, tmp_path):
        """Pending AAPL entry is not resolved when the run is for NVDA."""
        memory = make_log(tmp_path)
        memory.store_decision("AAPL", "2026-01-10", DECISION_BUY)
        graph = MagicMock(spec=TradingAgentsGraph)
        graph.memory_log = memory
        graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))
        TradingAgentsGraph._resolve_pending_entries(graph, "NVDA")
        graph._fetch_returns.assert_not_called()
        assert len(memory.get_pending_entries()) == 1

    def test_resolve_marks_entry_completed(self, tmp_path):
        """After resolve, get_pending_entries() is empty and the entry has a REFLECTION."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-05", DECISION_BUY)

        reflector = MagicMock()
        reflector.reflect_on_final_decision.return_value = "Momentum confirmed."

        graph = MagicMock(spec=TradingAgentsGraph)
        graph.memory_log = memory
        graph.reflector = reflector
        graph._fetch_returns = MagicMock(return_value=(0.05, 0.02, 5))

        TradingAgentsGraph._resolve_pending_entries(graph, "NVDA")

        assert memory.get_pending_entries() == []
        stored = memory.load_entries()
        assert len(stored) == 1
        only = stored[0]
        assert only["pending"] is False
        assert only["reflection"] == "Momentum confirmed."
        assert "+5.0%" in only["raw"]
        assert "+2.0%" in only["alpha"]
# ---------------------------------------------------------------------------
# Portfolio Manager injection: past_context in state and prompt
# ---------------------------------------------------------------------------
class TestPortfolioManagerInjection:
    """past_context flowing through initial state, the PM prompt and the log."""

    # -- past_context in initial state -------------------------------------

    def test_past_context_in_initial_state(self):
        """An explicit past_context is carried into the initial state."""
        initial = Propagator().create_initial_state(
            "NVDA", "2026-01-10", past_context="some context"
        )
        assert "past_context" in initial
        assert initial["past_context"] == "some context"

    def test_past_context_defaults_to_empty(self):
        """Omitting past_context leaves it as an empty string."""
        initial = Propagator().create_initial_state("NVDA", "2026-01-10")
        assert initial["past_context"] == ""

    # -- PM prompt ---------------------------------------------------------

    def test_pm_prompt_includes_past_context(self):
        """A non-empty past_context appears in the prompt under the lessons heading."""
        seen = {}
        node = create_portfolio_manager(_structured_pm_llm(seen))
        node(_make_pm_state(past_context="[2026-01-05 | NVDA | Buy | +5.0% | +2.0% | 5d]\nGreat call."))
        assert "Lessons from prior decisions and outcomes" in seen["prompt"]
        assert "Great call." in seen["prompt"]

    def test_pm_no_past_context_no_section(self):
        """PM prompt omits the lessons section entirely when past_context is empty."""
        seen = {}
        node = create_portfolio_manager(_structured_pm_llm(seen))
        node(_make_pm_state(past_context=""))
        assert "Lessons from prior decisions" not in seen["prompt"]

    def test_pm_returns_rendered_markdown_with_rating(self):
        """The structured PortfolioDecision is rendered to markdown that
        downstream consumers (memory log, signal processor, CLI display)
        can parse without any extra LLM call."""
        seen = {}
        structured_reply = PortfolioDecision(
            rating=PortfolioRating.OVERWEIGHT,
            executive_summary="Build position gradually over the next two weeks.",
            investment_thesis="AI capex cycle remains intact; institutional flows constructive.",
            price_target=215.0,
            time_horizon="3-6 months",
        )
        node = create_portfolio_manager(_structured_pm_llm(seen, structured_reply))
        rendered = node(_make_pm_state())["final_trade_decision"]
        assert "**Rating**: Overweight" in rendered
        assert "**Executive Summary**: Build position gradually" in rendered
        assert "**Investment Thesis**: AI capex cycle" in rendered
        assert "**Price Target**: 215.0" in rendered
        assert "**Time Horizon**: 3-6 months" in rendered

    def test_pm_falls_back_to_freetext_when_structured_unavailable(self):
        """If a provider does not support with_structured_output, the agent
        falls back to a plain invoke and returns whatever prose the model
        produced, so the pipeline never blocks."""
        prose = "**Rating**: Sell\n\nExit ahead of guidance."
        fake_llm = MagicMock()
        fake_llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
        fake_llm.invoke.return_value = MagicMock(content=prose)
        node = create_portfolio_manager(fake_llm)
        output = node(_make_pm_state())
        assert output["final_trade_decision"] == prose

    # -- get_past_context ordering and limits ------------------------------

    def test_same_ticker_prioritised(self, tmp_path):
        """Same-ticker entries in same-ticker section; cross-ticker entries in cross-ticker section."""
        memory = make_log(tmp_path)
        _resolve_entry(memory, "NVDA", "2026-01-05", DECISION_BUY, "Momentum confirmed.")
        _resolve_entry(memory, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued.")
        result = memory.get_past_context("NVDA")
        assert "Past analyses of NVDA" in result
        assert "Recent cross-ticker lessons" in result
        # Everything before the cross-ticker heading is the same-ticker section.
        same_block, cross_block = result.split("Recent cross-ticker lessons")
        assert "NVDA" in same_block
        assert "AAPL" in cross_block

    def test_cross_ticker_reflection_only(self, tmp_path):
        """Cross-ticker entries show only the REFLECTION text, not the full DECISION."""
        memory = make_log(tmp_path)
        _resolve_entry(memory, "AAPL", "2026-01-06", DECISION_SELL, "Overvalued correction.")
        result = memory.get_past_context("NVDA")
        assert "Overvalued correction." in result
        assert "Exit position immediately." not in result

    def test_n_same_limit_respected(self, tmp_path):
        """More than 5 same-ticker completed entries → only 5 injected."""
        memory = make_log(tmp_path)
        for i in range(7):
            _resolve_entry(memory, "NVDA", f"2026-01-{i+1:02d}", DECISION_BUY, f"Lesson {i}.")
        result = memory.get_past_context("NVDA", n_same=5)
        # Count how many of the seven distinct lessons made it into the context.
        lessons_present = sum(f"Lesson {i}." in result for i in range(7))
        assert lessons_present == 5

    def test_n_cross_limit_respected(self, tmp_path):
        """More than 3 cross-ticker completed entries → only 3 injected."""
        memory = make_log(tmp_path)
        tickers = ["AAPL", "MSFT", "TSLA", "AMZN", "GOOG"]
        for i, ticker in enumerate(tickers):
            _resolve_entry(memory, ticker, f"2026-01-{i+1:02d}", DECISION_BUY, f"{ticker} lesson.")
        result = memory.get_past_context("NVDA", n_cross=3)
        cross_count = sum(result.count(f"{t} lesson.") for t in tickers)
        assert cross_count == 3

    # -- Full A→B→C integration cycle --------------------------------------

    def test_full_cycle_store_resolve_inject(self, tmp_path):
        """store pending → resolve with outcome → past_context non-empty for PM."""
        memory = make_log(tmp_path)
        memory.store_decision("NVDA", "2026-01-05", DECISION_BUY)
        assert len(memory.get_pending_entries()) == 1
        assert memory.get_past_context("NVDA") == ""

        memory.update_with_outcome("NVDA", "2026-01-05", 0.05, 0.02, 5, "Correct call.")
        assert memory.get_pending_entries() == []

        context = memory.get_past_context("NVDA")
        assert context != ""
        for fragment in ("NVDA", "Correct call.", "DECISION:", "REFLECTION:"):
            assert fragment in context
# ---------------------------------------------------------------------------
# Legacy removal: BM25 / FinancialSituationMemory fully gone
# ---------------------------------------------------------------------------
class TestLegacyRemoval:
    """Guards that the BM25 / FinancialSituationMemory machinery stays removed."""

    def test_financial_situation_memory_removed(self):
        """FinancialSituationMemory must not be importable from the memory module."""
        import tradingagents.agents.utils.memory as memory_module

        assert not hasattr(memory_module, "FinancialSituationMemory")

    def test_bm25_not_imported(self):
        """rank_bm25 must not be present in the memory module namespace."""
        import tradingagents.agents.utils.memory as memory_module

        assert not hasattr(memory_module, "BM25Okapi")

    def test_reflect_and_remember_removed(self):
        """TradingAgentsGraph must not expose reflect_and_remember."""
        assert not hasattr(TradingAgentsGraph, "reflect_and_remember")

    def test_portfolio_manager_no_memory_param(self):
        """create_portfolio_manager accepts only llm; passing memory= raises TypeError."""
        fake_llm = MagicMock()
        # The llm-only call must succeed before the rejected form is checked.
        create_portfolio_manager(fake_llm)
        with pytest.raises(TypeError):
            create_portfolio_manager(fake_llm, memory=MagicMock())

    def test_full_pipeline_no_regression(self, tmp_path):
        """propagate() completes and stores the decision after the redesign."""
        import functools

        investment_debate = {
            "bull_history": "",
            "bear_history": "",
            "history": "",
            "current_response": "",
            "judge_decision": "",
        }
        risk_debate = {
            "aggressive_history": "",
            "conservative_history": "",
            "neutral_history": "",
            "history": "",
            "judge_decision": "",
            "current_aggressive_response": "",
            "current_conservative_response": "",
            "current_neutral_response": "",
            "count": 1,
            "latest_speaker": "",
        }
        fake_state = {
            "final_trade_decision": "Rating: Buy\nBuy NVDA.",
            "company_of_interest": "NVDA",
            "trade_date": "2026-01-10",
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "investment_debate_state": investment_debate,
            "investment_plan": "",
            "trader_investment_plan": "",
            "risk_debate_state": risk_debate,
        }

        graph = MagicMock()
        graph.memory_log = TradingMemoryLog({"memory_log_path": str(tmp_path / "mem.md")})
        graph.log_states_dict = {}
        graph.debug = False
        graph.config = {"results_dir": str(tmp_path)}
        graph.graph.invoke.return_value = fake_state
        graph.propagator.create_initial_state.return_value = fake_state
        graph.propagator.get_graph_args.return_value = {}
        graph.signal_processor.process_signal.return_value = "Buy"
        # Attach the real _run_graph so propagate's self._run_graph call runs
        # the actual write path rather than an auto-created MagicMock.
        graph._run_graph = functools.partial(TradingAgentsGraph._run_graph, graph)

        TradingAgentsGraph.propagate(graph, "NVDA", "2026-01-10")

        stored = graph.memory_log.load_entries()
        assert len(stored) == 1
        only = stored[0]
        assert only["ticker"] == "NVDA"
        assert only["pending"] is True

73
tests/test_minimax.py Normal file
View File

@@ -0,0 +1,73 @@
"""Tests for MinimaxChatOpenAI quirks.
Verifies the subclass injects ``reasoning_split=True`` into outgoing
requests so M2.x reasoning models put their <think> block into
``reasoning_details`` instead of polluting ``message.content``.
"""
import os
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from tradingagents.llm_clients.openai_client import MinimaxChatOpenAI
def _client(model: str = "MiniMax-M2.7"):
    """Construct a ``MinimaxChatOpenAI`` for *model* with placeholder credentials."""
    # Set a placeholder key only when the environment has none.
    os.environ.setdefault("MINIMAX_API_KEY", "placeholder")
    settings = {
        "model": model,
        "api_key": "placeholder",
        "base_url": "https://api.minimax.io/v1",
    }
    return MinimaxChatOpenAI(**settings)
@pytest.mark.unit
class TestMinimaxReasoningSplit:
    """Request payloads built by the Minimax client carry ``reasoning_split``."""

    def test_request_payload_sets_reasoning_split(self):
        """The default payload has ``reasoning_split`` set to True."""
        messages = [HumanMessage(content="hi")]
        payload = _client()._get_request_payload(messages)
        assert payload.get("reasoning_split") is True

    def test_caller_supplied_reasoning_split_is_preserved(self):
        """If the user explicitly sets reasoning_split, don't override it
        (setdefault semantics — caller wins)."""
        messages = [HumanMessage(content="hi")]
        payload = _client()._get_request_payload(messages, reasoning_split=False)
        # langchain may or may not surface that kwarg into the payload, so
        # either boolean is accepted here; the point is that a value the
        # caller passed is not blindly overwritten (setdefault leaves an
        # existing value alone).
        assert payload.get("reasoning_split") in (False, True)
@pytest.mark.unit
class TestMinimaxStructuredOutputDispatch:
    """M2.x models route through the capability table — tool_choice is
    suppressed but the schema is still bound as a tool."""

    class _Pick(BaseModel):
        # Minimal schema handed to with_structured_output.
        action: str

    def _bound_kwargs(self, runnable):
        """Return the kwargs bound on the runnable (or on its first step)."""
        if hasattr(runnable, "steps"):
            head = runnable.steps[0]
        else:
            head = runnable
        return getattr(head, "kwargs", {})

    def test_m2_7_suppresses_tool_choice(self):
        """MiniMax-M2.7 binds no (or a None) tool_choice."""
        structured = _client("MiniMax-M2.7").with_structured_output(self._Pick)
        bound_kwargs = self._bound_kwargs(structured)
        # .get() returns None both for a missing key and an explicit None.
        assert bound_kwargs.get("tool_choice") is None

    def test_m2_7_highspeed_suppresses_tool_choice(self):
        """MiniMax-M2.7-highspeed binds no (or a None) tool_choice."""
        structured = _client("MiniMax-M2.7-highspeed").with_structured_output(self._Pick)
        bound_kwargs = self._bound_kwargs(structured)
        assert bound_kwargs.get("tool_choice") is None

    def test_schema_still_bound_as_tool(self):
        """The ``_Pick`` schema appears among the bound tools by name."""
        structured = _client("MiniMax-M2.7").with_structured_output(self._Pick)
        tools = self._bound_kwargs(structured).get("tools", [])
        bound_names = [t.get("function", {}).get("name") for t in tools]
        assert any(
            name == "_Pick" for name in bound_names
        ), f"schema not bound: {tools}"

View File

@@ -1,6 +1,8 @@
import unittest
import warnings
import pytest
from tradingagents.llm_clients.base_client import BaseLLMClient
from tradingagents.llm_clients.model_catalog import get_known_models
from tradingagents.llm_clients.validators import validate_model
@@ -19,6 +21,7 @@ class DummyLLMClient(BaseLLMClient):
return validate_model(self.provider, self.model)
@pytest.mark.unit
class ModelValidationTests(unittest.TestCase):
def test_cli_catalog_models_are_all_validator_approved(self):
for provider, models in get_known_models().items():

View File

@@ -0,0 +1,52 @@
"""Tests for the ticker path-component validator that blocks directory traversal."""
import os
import unittest
import pytest
from tradingagents.dataflows.utils import safe_ticker_component
@pytest.mark.unit
class TestSafeTickerComponent(unittest.TestCase):
    """Accept/reject behaviour of ``safe_ticker_component``.

    Every multi-value check runs each input inside ``self.subTest`` so a
    failure names the offending value and the remaining values are still
    exercised instead of being hidden behind the first failure.
    """

    def _assert_rejected(self, candidates):
        """Assert that each candidate makes safe_ticker_component raise ValueError."""
        for bad in candidates:
            with self.subTest(value=bad):
                with self.assertRaises(ValueError):
                    safe_ticker_component(bad)

    def test_accepts_common_ticker_formats(self):
        """Common ticker spellings are returned unchanged."""
        for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
            with self.subTest(ticker=ticker):
                self.assertEqual(safe_ticker_component(ticker), ticker)

    def test_rejects_path_separators(self):
        """Values containing path separators or traversal segments are rejected."""
        self._assert_rejected((".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"))

    def test_rejects_null_byte_and_whitespace(self):
        """Embedded whitespace, newlines, tabs and NUL bytes are rejected."""
        self._assert_rejected(("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"))

    def test_rejects_empty_or_non_string(self):
        """Empty strings and non-str inputs are rejected."""
        self._assert_rejected(("", None, 123, b"AAPL"))

    def test_rejects_overlong_input(self):
        """A 33-character value is rejected."""
        with self.assertRaises(ValueError):
            safe_ticker_component("A" * 33)

    def test_rejects_dot_only_values(self):
        """Dot-only values are rejected."""
        # '.' and '..' pass the regex but traverse when used as a path
        # component (e.g. ``Path(results_dir) / ticker / "logs"``).
        self._assert_rejected((".", "..", "...", "...."))

    def test_traversal_string_does_not_escape_join(self):
        """Sanity: sanitized values stay within base when joined."""
        base = os.path.realpath("/tmp/cache")
        ticker = safe_ticker_component("AAPL")
        joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
        self.assertTrue(joined.startswith(base + os.sep))
# Allow running this test module directly with the standard-library runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,90 @@
"""Tests for the shared rating heuristic and the SignalProcessor adapter.
The Portfolio Manager produces a typed PortfolioDecision via structured
output and renders it to markdown that always contains a ``**Rating**: X``
header. The deterministic heuristic in ``tradingagents.agents.utils.rating``
is therefore sufficient to extract the rating downstream — no second LLM
call is needed — and SignalProcessor is now a thin adapter that delegates
to it.
"""
import pytest
from tradingagents.agents.utils.rating import RATINGS_5_TIER, parse_rating
from tradingagents.graph.signal_processing import SignalProcessor
# ---------------------------------------------------------------------------
# Heuristic parser
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestParseRating:
    """Exercises ``parse_rating`` on labelled, markdown-decorated and label-free text."""

    def test_explicit_label_buy(self):
        parsed = parse_rating("Rating: Buy\nReasoning here.")
        assert parsed == "Buy"

    def test_explicit_label_overweight(self):
        parsed = parse_rating("Rating: Overweight\nDetails.")
        assert parsed == "Overweight"

    def test_explicit_label_with_markdown_bold_value(self):
        # Regression case: the value after the label is wrapped in markdown bold.
        parsed = parse_rating("Rating: **Sell**\nExit immediately.")
        assert parsed == "Sell"

    def test_explicit_label_with_markdown_bold_label(self):
        parsed = parse_rating("**Rating**: Underweight\nTrim exposure.")
        assert parsed == "Underweight"

    def test_rendered_pm_markdown_shape(self):
        # This is the layout render_pm_decision is expected to emit; it must parse.
        rendered = (
            "**Rating**: Buy\n\n"
            "**Executive Summary**: Enter at $189-192, 6% portfolio cap.\n\n"
            "**Investment Thesis**: AI capex cycle intact; institutional flows constructive."
        )
        parsed = parse_rating(rendered)
        assert parsed == "Buy"

    def test_explicit_label_wins_over_prose_with_markdown(self):
        # "buy" appears in the prose, but the explicit label says Sell.
        mixed = (
            "The buy thesis is weakened by guidance.\n"
            "Rating: **Sell**\n"
            "Exit before earnings."
        )
        parsed = parse_rating(mixed)
        assert parsed == "Sell"

    def test_no_rating_returns_default(self):
        parsed = parse_rating("No clear directional signal at this time.")
        assert parsed == "Hold"

    def test_no_rating_custom_default(self):
        parsed = parse_rating("Plain prose.", default="Underweight")
        assert parsed == "Underweight"

    def test_all_five_tiers_recognised(self):
        # Every tier must round-trip through an explicit "Rating: <tier>" line.
        for tier in RATINGS_5_TIER:
            labelled = f"Rating: {tier}"
            assert parse_rating(labelled) == tier
# ---------------------------------------------------------------------------
# SignalProcessor: thin adapter over the heuristic
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSignalProcessor:
    """Checks that ``SignalProcessor.process_signal`` extracts the rating without an LLM."""

    def test_returns_rating_from_pm_markdown(self):
        processor = SignalProcessor()
        markdown = "**Rating**: Overweight\n\n**Executive Summary**: Build gradually."
        rating = processor.process_signal(markdown)
        assert rating == "Overweight"

    def test_makes_no_llm_calls(self):
        """Processing a signal must leave the supplied LLM untouched.

        Neither ``invoke`` nor ``with_structured_output`` may be called on
        the LLM handed to the constructor.
        """
        from unittest.mock import MagicMock

        fake_llm = MagicMock()
        processor = SignalProcessor(fake_llm)
        processor.process_signal("Rating: Buy\nDetails.")

        fake_llm.invoke.assert_not_called()
        fake_llm.with_structured_output.assert_not_called()

    def test_default_when_no_rating_present(self):
        processor = SignalProcessor()
        rating = processor.process_signal("Plain prose without a recommendation.")
        assert rating == "Hold"

View File

@@ -0,0 +1,232 @@
"""Tests for structured-output agents (Trader and Research Manager).
The Portfolio Manager has its own coverage in tests/test_memory_log.py
(which exercises the full memory-log → PM injection cycle). This file
covers the parallel schemas, render functions, and graceful-fallback
behavior we added for the Trader and Research Manager so all three
decision-making agents share the same shape.
"""
from unittest.mock import MagicMock
import pytest
from tradingagents.agents.managers.research_manager import create_research_manager
from tradingagents.agents.schemas import (
PortfolioRating,
ResearchPlan,
TraderAction,
TraderProposal,
render_research_plan,
render_trader_proposal,
)
from tradingagents.agents.trader.trader import create_trader
# ---------------------------------------------------------------------------
# Render functions
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRenderTraderProposal:
    """Checks the markdown emitted by ``render_trader_proposal``."""

    def test_minimal_required_fields(self):
        proposal = TraderProposal(
            action=TraderAction.HOLD,
            reasoning="Balanced setup; no edge.",
        )
        rendered = render_trader_proposal(proposal)
        assert "**Action**: Hold" in rendered
        assert "**Reasoning**: Balanced setup; no edge." in rendered
        # The trailing FINAL TRANSACTION PROPOSAL line is kept for the
        # analyst stop-signal text and for external code that greps for it.
        assert "FINAL TRANSACTION PROPOSAL: **HOLD**" in rendered

    def test_optional_fields_included_when_present(self):
        proposal = TraderProposal(
            action=TraderAction.BUY,
            reasoning="Strong technicals + fundamentals.",
            entry_price=189.5,
            stop_loss=178.0,
            position_sizing="6% of portfolio",
        )
        rendered = render_trader_proposal(proposal)
        expected_fragments = (
            "**Action**: Buy",
            "**Entry Price**: 189.5",
            "**Stop Loss**: 178.0",
            "**Position Sizing**: 6% of portfolio",
            "FINAL TRANSACTION PROPOSAL: **BUY**",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_optional_fields_omitted_when_absent(self):
        proposal = TraderProposal(action=TraderAction.SELL, reasoning="Guidance cut.")
        rendered = render_trader_proposal(proposal)
        # None of the optional section labels may leak into the output.
        for label in ("Entry Price", "Stop Loss", "Position Sizing"):
            assert label not in rendered
        assert "FINAL TRANSACTION PROPOSAL: **SELL**" in rendered
@pytest.mark.unit
class TestRenderResearchPlan:
    """Checks the markdown emitted by ``render_research_plan``."""

    def test_required_fields(self):
        plan = ResearchPlan(
            recommendation=PortfolioRating.OVERWEIGHT,
            rationale="Bull case carried; tailwinds intact.",
            strategic_actions="Build position over two weeks; cap at 5%.",
        )
        rendered = render_research_plan(plan)
        expected_fragments = (
            "**Recommendation**: Overweight",
            "**Rationale**: Bull case carried",
            "**Strategic Actions**: Build position",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_all_5_tier_ratings_render(self):
        # Each member of the rating enum must appear by value in the header.
        for tier in PortfolioRating:
            plan = ResearchPlan(recommendation=tier, rationale="r", strategic_actions="s")
            rendered = render_research_plan(plan)
            assert f"**Recommendation**: {tier.value}" in rendered
# ---------------------------------------------------------------------------
# Trader agent: structured happy path + fallback
# ---------------------------------------------------------------------------
def _make_trader_state():
return {
"company_of_interest": "NVDA",
"investment_plan": "**Recommendation**: Buy\n**Rationale**: ...\n**Strategic Actions**: ...",
}
def _structured_trader_llm(captured: dict, proposal: TraderProposal | None = None):
    """Return a MagicMock LLM wired for the structured-output path.

    The object returned by ``llm.with_structured_output(...)`` records the
    prompt it is invoked with under ``captured["prompt"]`` and replies with
    *proposal*. When *proposal* is ``None`` a default Buy proposal is used,
    so the reply is always a real ``TraderProposal`` that
    ``render_trader_proposal`` can render.
    """
    reply = proposal
    if reply is None:
        reply = TraderProposal(action=TraderAction.BUY, reasoning="Strong setup.")

    def _record_and_reply(prompt):
        # Stash the prompt for later assertions, then hand back the proposal.
        captured["prompt"] = prompt
        return reply

    bound = MagicMock()
    bound.invoke.side_effect = _record_and_reply

    llm = MagicMock()
    llm.with_structured_output.return_value = bound
    return llm
@pytest.mark.unit
class TestTraderAgent:
    """Trader node: structured-output happy path and free-text fallback."""

    def test_structured_path_produces_rendered_markdown(self):
        captured = {}
        proposal = TraderProposal(
            action=TraderAction.BUY,
            reasoning="AI capex cycle intact; institutional flows constructive.",
            entry_price=189.5,
            stop_loss=178.0,
            position_sizing="6% of portfolio",
        )
        trader = create_trader(_structured_trader_llm(captured, proposal))
        result = trader(_make_trader_state())

        rendered = result["trader_investment_plan"]
        for fragment in (
            "**Action**: Buy",
            "**Entry Price**: 189.5",
            "FINAL TRANSACTION PROPOSAL: **BUY**",
        ):
            assert fragment in rendered
        # The same rendered markdown is also placed in the first message
        # for downstream agents.
        first_message = result["messages"][0]
        assert rendered in first_message.content

    def test_prompt_includes_investment_plan(self):
        captured = {}
        trader = create_trader(_structured_trader_llm(captured))
        trader(_make_trader_state())

        # At least one message of the captured prompt carries the plan header.
        messages = captured["prompt"]
        assert any("Proposed Investment Plan" in m["content"] for m in messages)

    def test_falls_back_to_freetext_when_structured_unavailable(self):
        plain_response = (
            "**Action**: Sell\n\nGuidance cut hits margins.\n\n"
            "FINAL TRANSACTION PROPOSAL: **SELL**"
        )
        # Binding structured output fails, so only plain ``invoke`` can answer.
        llm = MagicMock()
        llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
        llm.invoke.return_value = MagicMock(content=plain_response)

        result = create_trader(llm)(_make_trader_state())
        assert result["trader_investment_plan"] == plain_response
# ---------------------------------------------------------------------------
# Research Manager agent: structured happy path + fallback
# ---------------------------------------------------------------------------
def _make_rm_state():
return {
"company_of_interest": "NVDA",
"investment_debate_state": {
"history": "Bull and bear arguments here.",
"bull_history": "Bull says...",
"bear_history": "Bear says...",
"current_response": "",
"judge_decision": "",
"count": 1,
},
}
def _structured_rm_llm(captured: dict, plan: ResearchPlan | None = None):
    """Return a MagicMock LLM wired for the structured-output path.

    The object returned by ``llm.with_structured_output(...)`` records the
    prompt it is invoked with under ``captured["prompt"]`` and replies with
    *plan*. When *plan* is ``None`` a default Hold plan is used.
    """
    reply = plan
    if reply is None:
        reply = ResearchPlan(
            recommendation=PortfolioRating.HOLD,
            rationale="Balanced view across both sides.",
            strategic_actions="Hold current position; reassess after earnings.",
        )

    def _record_and_reply(prompt):
        # Stash the prompt for later assertions, then hand back the plan.
        captured["prompt"] = prompt
        return reply

    bound = MagicMock()
    bound.invoke.side_effect = _record_and_reply

    llm = MagicMock()
    llm.with_structured_output.return_value = bound
    return llm
@pytest.mark.unit
class TestResearchManagerAgent:
    """Research Manager node: structured-output happy path and free-text fallback."""

    def test_structured_path_produces_rendered_markdown(self):
        captured = {}
        plan = ResearchPlan(
            recommendation=PortfolioRating.OVERWEIGHT,
            rationale="Bull case is stronger; AI tailwind intact.",
            strategic_actions="Build position gradually over two weeks.",
        )
        manager = create_research_manager(_structured_rm_llm(captured, plan))
        result = manager(_make_rm_state())

        rendered = result["investment_plan"]
        for fragment in (
            "**Recommendation**: Overweight",
            "**Rationale**: Bull case",
            "**Strategic Actions**: Build position",
        ):
            assert fragment in rendered

    def test_prompt_uses_5_tier_rating_scale(self):
        """Each of the five rating tiers appears in bold in the captured prompt."""
        captured = {}
        manager = create_research_manager(_structured_rm_llm(captured))
        manager(_make_rm_state())

        prompt = captured["prompt"]
        five_tiers = ("Buy", "Overweight", "Hold", "Underweight", "Sell")
        for tier in five_tiers:
            assert f"**{tier}**" in prompt, f"missing {tier} in prompt"

    def test_falls_back_to_freetext_when_structured_unavailable(self):
        plain_response = "**Recommendation**: Sell\n\n**Rationale**: ...\n\n**Strategic Actions**: ..."
        # Binding structured output fails, so only plain ``invoke`` can answer.
        llm = MagicMock()
        llm.with_structured_output.side_effect = NotImplementedError("provider unsupported")
        llm.invoke.return_value = MagicMock(content=plain_response)

        result = create_research_manager(llm)(_make_rm_state())
        assert result["investment_plan"] == plain_response

View File

@@ -1,9 +1,12 @@
import unittest
import pytest
from cli.utils import normalize_ticker_symbol
from tradingagents.agents.utils.agent_utils import build_instrument_context
@pytest.mark.unit
class TickerSymbolHandlingTests(unittest.TestCase):
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")