2 Commits
v0.2.4 ... main

Author SHA1 Message Date
Yijia-Xiao
7e9e7b83c7 feat: DeepSeek V4 thinking-mode round-trip via DeepSeekChatOpenAI subclass
Resolves #599: thinking-mode models require reasoning_content to be
echoed back across turns; multi-turn agent runs failed with HTTP 400.

The fix isolates DeepSeek's quirks (reasoning_content round-trip and
the deepseek-reasoner no-tool_choice limitation) into a subclass so
the general OpenAI-compatible client stays untouched. Adds DeepSeek
V4 Pro/Flash to the catalog. 9 new tests; rationale documented in
the class docstrings.

Design adapted from #600; #611 closed in favour of this approach.
2026-05-01 19:23:23 +00:00
Yijia-Xiao
2c97bad45c fix(security): validate ticker before using as path component (#618)
The ticker symbol reaches three filesystem-path construction sites
(load_ohlcv cache filename, checkpointer DB path, _log_state results
directory) without validation. A value containing path separators or
"../" escapes the configured cache / checkpoints / results directory.

Two attack vectors:
- Programmatic callers passing arbitrary ticker to propagate()
- Prompt injection via fetched news content steering the LLM into
  tool calls with attacker-chosen ticker

Fix: new safe_ticker_component() validator in tradingagents/dataflows/
utils.py applied at all three sites. Allows the standard ticker
character set ([A-Za-z0-9._\-\^], up to 32 chars) and explicitly
rejects dot-only values like "." and ".." which would otherwise pass
the regex but traverse parent directories. Seven test cases cover
the accepted formats (BRK-B, 7203.T, ^GSPC, etc.) and the rejected
inputs (path separators, null bytes, whitespace, empty values,
overlong strings, dot-only values).

Closes #618.
2026-05-01 18:56:36 +00:00
8 changed files with 365 additions and 21 deletions

View File

@@ -0,0 +1,169 @@
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.
Two pieces verified:
1. ``reasoning_content`` is captured on receive into the AIMessage's
``additional_kwargs`` and re-attached on send so DeepSeek's API
sees the same value across turns.
2. ``with_structured_output`` raises NotImplementedError for
``deepseek-reasoner`` so the agent factories' free-text fallback
handles the request instead of failing at runtime.
"""
import os
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue
from tradingagents.llm_clients.openai_client import (
DeepSeekChatOpenAI,
NormalizedChatOpenAI,
_input_to_messages,
)
# ---------------------------------------------------------------------------
# _input_to_messages — the helper that handles list / ChatPromptValue / other
# (Gemini bot review note: non-list inputs must also work)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestInputToMessages:
def test_list_input_returned_as_is(self):
msgs = [HumanMessage(content="hi")]
assert _input_to_messages(msgs) is msgs
def test_chat_prompt_value_unwrapped(self):
msgs = [HumanMessage(content="hi")]
prompt_value = ChatPromptValue(messages=msgs)
assert _input_to_messages(prompt_value) == msgs
def test_string_input_yields_empty_list(self):
# A bare string isn't a message-bearing input; the caller's normal
# langchain conversion happens upstream of _get_request_payload.
assert _input_to_messages("hello") == []
# ---------------------------------------------------------------------------
# Reasoning content propagation across turns
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeepSeekReasoningContent:
def _client(self):
os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
return DeepSeekChatOpenAI(
model="deepseek-v4-flash",
api_key="placeholder",
base_url="https://api.deepseek.com",
)
def test_capture_on_receive(self):
"""When the response carries reasoning_content, it lands on the
AIMessage's additional_kwargs so the next turn can echo it back."""
client = self._client()
result = client._create_chat_result(
{
"model": "deepseek-v4-flash",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Plan: buy NVDA.",
"reasoning_content": "Step 1: trend is up. Step 2: ...",
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
)
ai = result.generations[0].message
assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."
def test_propagate_on_send(self):
"""When an outgoing AIMessage carries reasoning_content, the request
payload echoes it on the corresponding message dict."""
client = self._client()
prior = AIMessage(
content="Plan",
additional_kwargs={"reasoning_content": "weighed bull case"},
)
new_user = HumanMessage(content="Refine.")
payload = client._get_request_payload([prior, new_user])
# Find the assistant message in the payload
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
assert assistant_dicts, "assistant message missing from outgoing payload"
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
def test_propagate_through_chat_prompt_value(self):
"""Gemini bot review note: non-list inputs (ChatPromptValue) must
also propagate reasoning_content."""
client = self._client()
prior = AIMessage(
content="Plan",
additional_kwargs={"reasoning_content": "weighed bull case"},
)
prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
payload = client._get_request_payload(prompt_value)
assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"
# ---------------------------------------------------------------------------
# deepseek-reasoner: structured output unavailable, falls through to free-text
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeepSeekReasonerStructuredOutput:
def test_with_structured_output_raises_for_reasoner(self):
client = DeepSeekChatOpenAI(
model="deepseek-reasoner",
api_key="placeholder",
base_url="https://api.deepseek.com",
)
from pydantic import BaseModel
class _Sample(BaseModel):
answer: str
with pytest.raises(NotImplementedError):
client.with_structured_output(_Sample)
def test_with_structured_output_works_for_v4(self):
"""V4 models (non-reasoner) accept tool_choice; structured output works."""
client = DeepSeekChatOpenAI(
model="deepseek-v4-flash",
api_key="placeholder",
base_url="https://api.deepseek.com",
)
from pydantic import BaseModel
class _Sample(BaseModel):
answer: str
# Should return a Runnable, not raise. (The actual API call would
# require a real key; we only assert binding succeeds.)
wrapped = client.with_structured_output(_Sample)
assert wrapped is not None
# ---------------------------------------------------------------------------
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseClassIsolation:
def test_normalized_does_not_propagate_reasoning_content(self):
"""The general-purpose NormalizedChatOpenAI must not carry
DeepSeek-specific behaviour. Only the subclass does."""
assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
NormalizedChatOpenAI._get_request_payload
is NormalizedChatOpenAI.__bases__[0]._get_request_payload
)

View File

@@ -0,0 +1,52 @@
"""Tests for the ticker path-component validator that blocks directory traversal."""
import os
import unittest
import pytest
from tradingagents.dataflows.utils import safe_ticker_component
@pytest.mark.unit
class TestSafeTickerComponent(unittest.TestCase):
def test_accepts_common_ticker_formats(self):
for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
self.assertEqual(safe_ticker_component(ticker), ticker)
def test_rejects_path_separators(self):
for bad in (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_null_byte_and_whitespace(self):
for bad in ("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_empty_or_non_string(self):
for bad in ("", None, 123, b"AAPL"):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_rejects_overlong_input(self):
with self.assertRaises(ValueError):
safe_ticker_component("A" * 33)
def test_rejects_dot_only_values(self):
# '.' and '..' pass the regex but traverse when used as a path
# component (e.g. ``Path(results_dir) / ticker / "logs"``).
for bad in (".", "..", "...", "...."):
with self.assertRaises(ValueError):
safe_ticker_component(bad)
def test_traversal_string_does_not_escape_join(self):
"""Sanity: sanitized values stay within base when joined."""
base = os.path.realpath("/tmp/cache")
ticker = safe_ticker_component("AAPL")
joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
self.assertTrue(joined.startswith(base + os.sep))
if __name__ == "__main__":
unittest.main()

View File

@@ -8,6 +8,7 @@ from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
from .utils import safe_ticker_component
logger = logging.getLogger(__name__)
@@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
subsequent calls the cache is reused. Rows after curr_date are
filtered out so backtests never see future prices.
"""
# Reject ticker values that would escape the cache directory when
# interpolated into the cache filename (e.g. ``../../tmp/x``).
safe_symbol = safe_ticker_component(symbol)
config = get_config()
curr_date_dt = pd.to_datetime(curr_date)
@@ -63,7 +68,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
)
if os.path.exists(data_file):

View File

@@ -1,4 +1,5 @@
import os
import re
import json
import pandas as pd
from datetime import date, timedelta, datetime
@@ -6,6 +7,40 @@ from typing import Annotated
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
# Tickers can contain letters, digits, dot, dash, underscore, and caret
# (for index symbols like ^GSPC). Anything else is rejected so the value
# never escapes a containing directory when interpolated into a path.
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")
def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
"""Validate ``value`` is safe to interpolate into a filesystem path.
Tickers come from user CLI input or from LLM tool calls, both of which
can be influenced by attacker-controlled content (e.g. prompt injection
embedded in fetched news). Without validation, a value like
``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
escapes the configured cache, checkpoint, or results directory.
Returns ``value`` unchanged when it matches the allowed pattern; raises
``ValueError`` otherwise.
"""
if not isinstance(value, str) or not value:
raise ValueError(f"ticker must be a non-empty string, got {value!r}")
if len(value) > max_len:
raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")
if not _TICKER_PATH_RE.fullmatch(value):
raise ValueError(
f"ticker contains characters not allowed in a filesystem path: {value!r}"
)
# The regex above allows '.', so values like '.', '..', '...' would pass,
# and as a path component they traverse the parent directory. Reject any
# value that's only dots.
if set(value) == {"."}:
raise ValueError(f"ticker cannot consist solely of dots: {value!r}")
return value
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path:
data.to_csv(save_path, encoding="utf-8")

View File

@@ -13,12 +13,16 @@ from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver
from tradingagents.dataflows.utils import safe_ticker_component
def _db_path(data_dir: str | Path, ticker: str) -> Path:
"""Return the SQLite checkpoint DB path for a ticker."""
# Reject ticker values that would escape the checkpoints directory.
safe = safe_ticker_component(ticker).upper()
p = Path(data_dir) / "checkpoints"
p.mkdir(parents=True, exist_ok=True)
return p / f"{ticker.upper()}.db"
return p / f"{safe}.db"
def thread_id(ticker: str, date: str) -> str:

View File

@@ -18,6 +18,7 @@ from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import TradingMemoryLog
from tradingagents.dataflows.utils import safe_ticker_component
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@@ -378,8 +379,10 @@ class TradingAgentsGraph:
"final_trade_decision": final_state["final_trade_decision"],
}
# Save to file
directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
# Save to file. Reject ticker values that would escape the
# results directory when joined as a path component.
safe_ticker = safe_ticker_component(self.ticker)
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True)
log_path = directory / f"full_states_log_{trade_date}.json"

View File

@@ -65,10 +65,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
},
"deepseek": {
"quick": [
("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
"deep": [
("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),

View File

@@ -1,6 +1,7 @@
import os
from typing import Any, Optional
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient, normalize_content
@@ -11,29 +12,97 @@ class NormalizedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with normalized content output.
The Responses API returns content as a list of typed blocks
(reasoning, text, etc.). This normalizes to string for consistent
downstream handling.
(reasoning, text, etc.). ``invoke`` normalizes to string for
consistent downstream handling. ``with_structured_output`` defaults
to function-calling so the Responses-API parse path is avoided
(langchain-openai's parse path emits noisy
PydanticSerializationUnexpectedValue warnings per call without
affecting correctness).
Provider-specific quirks (e.g. DeepSeek's thinking mode) live in
purpose-built subclasses below so this base class stays small.
"""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
def with_structured_output(self, schema, *, method=None, **kwargs):
"""Wrap with structured output, defaulting to function_calling for OpenAI.
langchain-openai's Responses-API-parse path (the default for json_schema
when use_responses_api=True) calls response.model_dump(...) on the OpenAI
SDK's union-typed parsed response, which makes Pydantic emit ~20
PydanticSerializationUnexpectedValue warnings per call. The function-calling
path returns a plain tool-call shape that does not trigger that
serialization, so it is the cleaner choice for our combination of
use_responses_api=True + with_structured_output. Both paths use OpenAI's
strict mode and produce the same typed Pydantic instance.
"""
if method is None:
method = "function_calling"
return super().with_structured_output(schema, method=method, **kwargs)
def _input_to_messages(input_: Any) -> list:
"""Normalise a langchain LLM input to a list of message objects.
Accepts a list of messages, a ``ChatPromptValue`` (from a
ChatPromptTemplate), or anything else (treated as no messages).
Used by providers that need to walk the outgoing message history;
in particular DeepSeek thinking-mode propagation must work for
both bare-list invocations and ChatPromptTemplate-driven ones, so
treating only ``list`` here would silently skip half the call sites.
"""
if isinstance(input_, list):
return input_
if hasattr(input_, "to_messages"):
return input_.to_messages()
return []
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
"""DeepSeek-specific overrides on top of the OpenAI-compatible client.
Two quirks that don't apply to other OpenAI-compatible providers:
1. **Thinking-mode round-trip.** When DeepSeek's thinking models return
a response with ``reasoning_content``, that field must be echoed
back as part of the assistant message on the next turn or the API
fails with HTTP 400. ``_create_chat_result`` captures the field on
receive and ``_get_request_payload`` re-attaches it on send.
2. **deepseek-reasoner has no tool_choice.** Structured output via
function-calling is unavailable, so we raise NotImplementedError
and let the agent factories fall back to free-text generation
(see ``tradingagents/agents/utils/structured.py``).
"""
def _get_request_payload(self, input_, *, stop=None, **kwargs):
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
outgoing = payload.get("messages", [])
for message_dict, message in zip(outgoing, _input_to_messages(input_)):
if not isinstance(message, AIMessage):
continue
reasoning = message.additional_kwargs.get("reasoning_content")
if reasoning is not None:
message_dict["reasoning_content"] = reasoning
return payload
def _create_chat_result(self, response, generation_info=None):
chat_result = super()._create_chat_result(response, generation_info)
response_dict = (
response
if isinstance(response, dict)
else response.model_dump(
exclude={"choices": {"__all__": {"message": {"parsed"}}}}
)
)
for generation, choice in zip(
chat_result.generations, response_dict.get("choices", [])
):
reasoning = choice.get("message", {}).get("reasoning_content")
if reasoning is not None:
generation.message.additional_kwargs["reasoning_content"] = reasoning
return chat_result
def with_structured_output(self, schema, *, method=None, **kwargs):
if self.model_name == "deepseek-reasoner":
raise NotImplementedError(
"deepseek-reasoner does not support tool_choice; structured "
"output is unavailable. Agent factories fall back to "
"free-text generation automatically."
)
return super().with_structured_output(schema, method=method, **kwargs)
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "reasoning_effort",
@@ -75,10 +144,12 @@ class OpenAIClient(BaseLLMClient):
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model}
# Provider-specific base URL and auth
# Provider-specific base URL and auth. An explicit base_url on the
# client (e.g. a corporate proxy) takes precedence over the
# provider default so users can route through their own gateway.
if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url
default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = self.base_url or default_base
if api_key_env:
api_key = os.environ.get(api_key_env)
if api_key:
@@ -98,7 +169,10 @@ class OpenAIClient(BaseLLMClient):
if self.provider == "openai":
llm_kwargs["use_responses_api"] = True
return NormalizedChatOpenAI(**llm_kwargs)
# DeepSeek's thinking-mode quirks live in their own subclass so the
# base NormalizedChatOpenAI stays free of provider-specific branches.
chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI
return chat_cls(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for the provider."""