Mirror of https://github.com/TauricResearch/TradingAgents.git, synced 2026-05-02 06:53:16 +03:00
feat: DeepSeek V4 thinking-mode round-trip via DeepSeekChatOpenAI subclass
Resolves #599: thinking-mode models require reasoning_content to be echoed back
across turns; without the echo, multi-turn agent runs failed with HTTP 400.

The fix isolates DeepSeek's quirks (the reasoning_content round-trip and the
deepseek-reasoner no-tool_choice limitation) into a subclass so the general
OpenAI-compatible client stays untouched. Adds DeepSeek V4 Pro/Flash to the
catalog. 9 new tests; rationale documented in the class docstrings. Design
adapted from #600; #611 closed in favour of this approach.
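For reviewers unfamiliar with the quirk, a minimal sketch of the two-turn wire
shape this commit preserves (message text is illustrative; the field name and
the echo requirement are the behaviour described above):

# Turn 1: a thinking model's response carries reasoning_content alongside content.
turn_one_assistant = {
    "role": "assistant",
    "content": "Plan: buy NVDA.",
    "reasoning_content": "Step 1: trend is up. Step 2: ...",
}

# Turn 2: the request history must echo that field on the assistant message,
# otherwise the API rejects the call with HTTP 400.
turn_two_messages = [
    {"role": "user", "content": "Analyse NVDA."},
    turn_one_assistant,  # reasoning_content echoed verbatim
    {"role": "user", "content": "Refine the plan."},
]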
tests/test_deepseek_reasoning.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""Tests for DeepSeekChatOpenAI thinking-mode behaviour.

Two pieces verified:

1. ``reasoning_content`` is captured on receive into the AIMessage's
   ``additional_kwargs`` and re-attached on send so DeepSeek's API
   sees the same value across turns.
2. ``with_structured_output`` raises NotImplementedError for
   ``deepseek-reasoner`` so the agent factories' free-text fallback
   handles the request instead of failing at runtime.
"""

import os

import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue

from tradingagents.llm_clients.openai_client import (
    DeepSeekChatOpenAI,
    NormalizedChatOpenAI,
    _input_to_messages,
)


# ---------------------------------------------------------------------------
# _input_to_messages — the helper that handles list / ChatPromptValue / other
# (Gemini bot review note: non-list inputs must also work)
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestInputToMessages:
    def test_list_input_returned_as_is(self):
        msgs = [HumanMessage(content="hi")]
        assert _input_to_messages(msgs) is msgs

    def test_chat_prompt_value_unwrapped(self):
        msgs = [HumanMessage(content="hi")]
        prompt_value = ChatPromptValue(messages=msgs)
        assert _input_to_messages(prompt_value) == msgs

    def test_string_input_yields_empty_list(self):
        # A bare string isn't a message-bearing input; the caller's normal
        # langchain conversion happens upstream of _get_request_payload.
        assert _input_to_messages("hello") == []


# ---------------------------------------------------------------------------
# Reasoning content propagation across turns
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestDeepSeekReasoningContent:
    def _client(self):
        os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder")
        return DeepSeekChatOpenAI(
            model="deepseek-v4-flash",
            api_key="placeholder",
            base_url="https://api.deepseek.com",
        )

    def test_capture_on_receive(self):
        """When the response carries reasoning_content, it lands on the
        AIMessage's additional_kwargs so the next turn can echo it back."""
        client = self._client()
        result = client._create_chat_result(
            {
                "model": "deepseek-v4-flash",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Plan: buy NVDA.",
                            "reasoning_content": "Step 1: trend is up. Step 2: ...",
                        },
                        "finish_reason": "stop",
                    }
                ],
                "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
            }
        )
        ai = result.generations[0].message
        assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..."

    def test_propagate_on_send(self):
        """When an outgoing AIMessage carries reasoning_content, the request
        payload echoes it on the corresponding message dict."""
        client = self._client()
        prior = AIMessage(
            content="Plan",
            additional_kwargs={"reasoning_content": "weighed bull case"},
        )
        new_user = HumanMessage(content="Refine.")
        payload = client._get_request_payload([prior, new_user])
        # Find the assistant message in the payload
        assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
        assert assistant_dicts, "assistant message missing from outgoing payload"
        assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"

    def test_propagate_through_chat_prompt_value(self):
        """Gemini bot review note: non-list inputs (ChatPromptValue) must
        also propagate reasoning_content."""
        client = self._client()
        prior = AIMessage(
            content="Plan",
            additional_kwargs={"reasoning_content": "weighed bull case"},
        )
        prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")])
        payload = client._get_request_payload(prompt_value)
        assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"]
        assert assistant_dicts[0]["reasoning_content"] == "weighed bull case"


# ---------------------------------------------------------------------------
# deepseek-reasoner: structured output unavailable, falls through to free-text
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestDeepSeekReasonerStructuredOutput:
    def test_with_structured_output_raises_for_reasoner(self):
        client = DeepSeekChatOpenAI(
            model="deepseek-reasoner",
            api_key="placeholder",
            base_url="https://api.deepseek.com",
        )
        from pydantic import BaseModel

        class _Sample(BaseModel):
            answer: str

        with pytest.raises(NotImplementedError):
            client.with_structured_output(_Sample)

    def test_with_structured_output_works_for_v4(self):
        """V4 models (non-reasoner) accept tool_choice; structured output works."""
        client = DeepSeekChatOpenAI(
            model="deepseek-v4-flash",
            api_key="placeholder",
            base_url="https://api.deepseek.com",
        )
        from pydantic import BaseModel

        class _Sample(BaseModel):
            answer: str

        # Should return a Runnable, not raise. (The actual API call would
        # require a real key; we only assert binding succeeds.)
        wrapped = client.with_structured_output(_Sample)
        assert wrapped is not None


# ---------------------------------------------------------------------------
# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestBaseClassIsolation:
    def test_normalized_does_not_propagate_reasoning_content(self):
        """The general-purpose NormalizedChatOpenAI must not carry
        DeepSeek-specific behaviour. Only the subclass does."""
        assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or (
            NormalizedChatOpenAI._get_request_payload
            is NormalizedChatOpenAI.__bases__[0]._get_request_payload
        )
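The fallback contract these reasoner tests pin down, sketched from the caller's
side; the helper name and shape here are hypothetical, the real logic lives in
tradingagents/agents/utils/structured.py:

def structured_or_freetext(client, schema, prompt):
    """Prefer structured output; fall back to free text when the model
    cannot bind tools (deepseek-reasoner raises NotImplementedError)."""
    try:
        return client.with_structured_output(schema).invoke(prompt)
    except NotImplementedError:
        return client.invoke(prompt)  # free text; the caller parses leniently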
@@ -65,10 +65,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
     },
     "deepseek": {
         "quick": [
+            ("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"),
             ("DeepSeek V3.2", "deepseek-chat"),
             ("Custom model ID", "custom"),
         ],
         "deep": [
+            ("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
             ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
             ("DeepSeek V3.2", "deepseek-chat"),
             ("Custom model ID", "custom"),
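For orientation, the catalog maps provider → mode → (display label, model id)
tuples; a hedged sketch of resolving a selection (the resolve helper is
hypothetical, only the tuple shape comes from the hunk above):

MODEL_OPTIONS = {
    "deepseek": {
        "deep": [
            ("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"),
            ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
        ],
    },
}

def resolve(provider: str, mode: str, index: int) -> str:
    # Each entry pairs a human-readable label with the model id sent to the API.
    label, model_id = MODEL_OPTIONS[provider][mode][index]
    return model_id

assert resolve("deepseek", "deep", 0) == "deepseek-v4-pro"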
@@ -1,6 +1,7 @@
 import os
 from typing import Any, Optional
 
+from langchain_core.messages import AIMessage
 from langchain_openai import ChatOpenAI
 
 from .base_client import BaseLLMClient, normalize_content
@@ -11,29 +12,97 @@ class NormalizedChatOpenAI(ChatOpenAI):
     """ChatOpenAI with normalized content output.
 
     The Responses API returns content as a list of typed blocks
-    (reasoning, text, etc.). This normalizes to string for consistent
-    downstream handling.
+    (reasoning, text, etc.). ``invoke`` normalizes to string for
+    consistent downstream handling. ``with_structured_output`` defaults
+    to function-calling so the Responses-API parse path is avoided
+    (langchain-openai's parse path emits noisy
+    PydanticSerializationUnexpectedValue warnings per call without
+    affecting correctness).
+
+    Provider-specific quirks (e.g. DeepSeek's thinking mode) live in
+    purpose-built subclasses below so this base class stays small.
     """
 
     def invoke(self, input, config=None, **kwargs):
         return normalize_content(super().invoke(input, config, **kwargs))
 
     def with_structured_output(self, schema, *, method=None, **kwargs):
+        """Wrap with structured output, defaulting to function_calling for OpenAI.
+
+        langchain-openai's Responses-API-parse path (the default for json_schema
+        when use_responses_api=True) calls response.model_dump(...) on the OpenAI
+        SDK's union-typed parsed response, which makes Pydantic emit ~20
+        PydanticSerializationUnexpectedValue warnings per call. The function-calling
+        path returns a plain tool-call shape that does not trigger that
+        serialization, so it is the cleaner choice for our combination of
+        use_responses_api=True + with_structured_output. Both paths use OpenAI's
+        strict mode and produce the same typed Pydantic instance.
+        """
         if method is None:
             method = "function_calling"
         return super().with_structured_output(schema, method=method, **kwargs)
 
 
+def _input_to_messages(input_: Any) -> list:
+    """Normalise a langchain LLM input to a list of message objects.
+
+    Accepts a list of messages, a ``ChatPromptValue`` (from a
+    ChatPromptTemplate), or anything else (treated as no messages).
+    Used by providers that need to walk the outgoing message history;
+    in particular DeepSeek thinking-mode propagation must work for
+    both bare-list invocations and ChatPromptTemplate-driven ones, so
+    treating only ``list`` here would silently skip half the call sites.
+    """
+    if isinstance(input_, list):
+        return input_
+    if hasattr(input_, "to_messages"):
+        return input_.to_messages()
+    return []
+
+
+class DeepSeekChatOpenAI(NormalizedChatOpenAI):
+    """DeepSeek-specific overrides on top of the OpenAI-compatible client.
+
+    Two quirks that don't apply to other OpenAI-compatible providers:
+
+    1. **Thinking-mode round-trip.** When DeepSeek's thinking models return
+       a response with ``reasoning_content``, that field must be echoed
+       back as part of the assistant message on the next turn or the API
+       fails with HTTP 400. ``_create_chat_result`` captures the field on
+       receive and ``_get_request_payload`` re-attaches it on send.
+
+    2. **deepseek-reasoner has no tool_choice.** Structured output via
+       function-calling is unavailable, so we raise NotImplementedError
+       and let the agent factories fall back to free-text generation
+       (see ``tradingagents/agents/utils/structured.py``).
+    """
+
+    def _get_request_payload(self, input_, *, stop=None, **kwargs):
+        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+        outgoing = payload.get("messages", [])
+        for message_dict, message in zip(outgoing, _input_to_messages(input_)):
+            if not isinstance(message, AIMessage):
+                continue
+            reasoning = message.additional_kwargs.get("reasoning_content")
+            if reasoning is not None:
+                message_dict["reasoning_content"] = reasoning
+        return payload
+
+    def _create_chat_result(self, response, generation_info=None):
+        chat_result = super()._create_chat_result(response, generation_info)
+        response_dict = (
+            response
+            if isinstance(response, dict)
+            else response.model_dump(
+                exclude={"choices": {"__all__": {"message": {"parsed"}}}}
+            )
+        )
+        for generation, choice in zip(
+            chat_result.generations, response_dict.get("choices", [])
+        ):
+            reasoning = choice.get("message", {}).get("reasoning_content")
+            if reasoning is not None:
+                generation.message.additional_kwargs["reasoning_content"] = reasoning
+        return chat_result
+
+    def with_structured_output(self, schema, *, method=None, **kwargs):
+        if self.model_name == "deepseek-reasoner":
+            raise NotImplementedError(
+                "deepseek-reasoner does not support tool_choice; structured "
+                "output is unavailable. Agent factories fall back to "
+                "free-text generation automatically."
+            )
+        return super().with_structured_output(schema, method=method, **kwargs)
+
+
 # Kwargs forwarded from user config to ChatOpenAI
 _PASSTHROUGH_KWARGS = (
     "timeout", "max_retries", "reasoning_effort",
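A sketch of the agent-side flow the subclass enables; illustrative only
(running it needs a real DEEPSEEK_API_KEY, and the model id is taken from the
catalog hunk above):

from langchain_core.messages import HumanMessage

client = DeepSeekChatOpenAI(
    model="deepseek-v4-pro",
    api_key="sk-...",                       # real key required to run
    base_url="https://api.deepseek.com",
)

history = [HumanMessage(content="Analyse NVDA.")]
first = client.invoke(history)              # reasoning_content captured on receive
history += [first, HumanMessage(content="Refine the plan.")]
second = client.invoke(history)             # and re-attached on send: no HTTP 400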
@@ -75,10 +144,12 @@ class OpenAIClient(BaseLLMClient):
         self.warn_if_unknown_model()
         llm_kwargs = {"model": self.model}
 
-        # Provider-specific base URL and auth
+        # Provider-specific base URL and auth. An explicit base_url on the
+        # client (e.g. a corporate proxy) takes precedence over the
+        # provider default so users can route through their own gateway.
         if self.provider in _PROVIDER_CONFIG:
-            base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
-            llm_kwargs["base_url"] = base_url
+            default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
+            llm_kwargs["base_url"] = self.base_url or default_base
             if api_key_env:
                 api_key = os.environ.get(api_key_env)
                 if api_key:
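An illustration of the precedence the new comment describes; the OpenAIClient
constructor arguments are assumptions inferred from the fields the hunk reads
(self.provider, self.model, self.base_url):

# Default: the provider catalog supplies the base URL.
client = OpenAIClient(provider="deepseek", model="deepseek-v4-flash")
# -> llm_kwargs["base_url"] == "https://api.deepseek.com"

# Override: an explicit base_url (e.g. a corporate proxy) wins.
client = OpenAIClient(
    provider="deepseek",
    model="deepseek-v4-flash",
    base_url="https://llm-gateway.internal.example/v1",   # hypothetical gateway
)
# -> llm_kwargs["base_url"] == "https://llm-gateway.internal.example/v1"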
@@ -98,7 +169,10 @@ class OpenAIClient(BaseLLMClient):
         if self.provider == "openai":
             llm_kwargs["use_responses_api"] = True
 
-        return NormalizedChatOpenAI(**llm_kwargs)
+        # DeepSeek's thinking-mode quirks live in their own subclass so the
+        # base NormalizedChatOpenAI stays free of provider-specific branches.
+        chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI
+        return chat_cls(**llm_kwargs)
 
     def validate_model(self) -> bool:
         """Validate model for the provider."""