From 7e9e7b83c7fcc18d941300b253c6ed24d985788d Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Fri, 1 May 2026 19:23:23 +0000 Subject: [PATCH] feat: DeepSeek V4 thinking-mode round-trip via DeepSeekChatOpenAI subclass Resolves #599: thinking-mode models require reasoning_content to be echoed back across turns; multi-turn agent runs failed with HTTP 400. The fix isolates DeepSeek's quirks (reasoning_content round-trip and the deepseek-reasoner no-tool_choice limitation) into a subclass so the general OpenAI-compatible client stays untouched. Adds DeepSeek V4 Pro/Flash to the catalog. 9 new tests; rationale documented in the class docstrings. Design adapted from #600; #611 closed in favour of this approach. --- tests/test_deepseek_reasoning.py | 169 +++++++++++++++++++++ tradingagents/llm_clients/model_catalog.py | 2 + tradingagents/llm_clients/openai_client.py | 108 ++++++++++--- 3 files changed, 262 insertions(+), 17 deletions(-) create mode 100644 tests/test_deepseek_reasoning.py diff --git a/tests/test_deepseek_reasoning.py b/tests/test_deepseek_reasoning.py new file mode 100644 index 00000000..fb300336 --- /dev/null +++ b/tests/test_deepseek_reasoning.py @@ -0,0 +1,169 @@ +"""Tests for DeepSeekChatOpenAI thinking-mode behaviour. + +Two pieces verified: + +1. ``reasoning_content`` is captured on receive into the AIMessage's + ``additional_kwargs`` and re-attached on send so DeepSeek's API + sees the same value across turns. +2. ``with_structured_output`` raises NotImplementedError for + ``deepseek-reasoner`` so the agent factories' free-text fallback + handles the request instead of failing at runtime. +""" + +import os + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.prompt_values import ChatPromptValue + +from tradingagents.llm_clients.openai_client import ( + DeepSeekChatOpenAI, + NormalizedChatOpenAI, + _input_to_messages, +) + + +# --------------------------------------------------------------------------- +# _input_to_messages — the helper that handles list / ChatPromptValue / other +# (Gemini bot review note: non-list inputs must also work) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestInputToMessages: + def test_list_input_returned_as_is(self): + msgs = [HumanMessage(content="hi")] + assert _input_to_messages(msgs) is msgs + + def test_chat_prompt_value_unwrapped(self): + msgs = [HumanMessage(content="hi")] + prompt_value = ChatPromptValue(messages=msgs) + assert _input_to_messages(prompt_value) == msgs + + def test_string_input_yields_empty_list(self): + # A bare string isn't a message-bearing input; the caller's normal + # langchain conversion happens upstream of _get_request_payload. + assert _input_to_messages("hello") == [] + + +# --------------------------------------------------------------------------- +# Reasoning content propagation across turns +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDeepSeekReasoningContent: + def _client(self): + os.environ.setdefault("DEEPSEEK_API_KEY", "placeholder") + return DeepSeekChatOpenAI( + model="deepseek-v4-flash", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + + def test_capture_on_receive(self): + """When the response carries reasoning_content, it lands on the + AIMessage's additional_kwargs so the next turn can echo it back.""" + client = self._client() + result = client._create_chat_result( + { + "model": "deepseek-v4-flash", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Plan: buy NVDA.", + "reasoning_content": "Step 1: trend is up. Step 2: ...", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + ) + ai = result.generations[0].message + assert ai.additional_kwargs["reasoning_content"] == "Step 1: trend is up. Step 2: ..." + + def test_propagate_on_send(self): + """When an outgoing AIMessage carries reasoning_content, the request + payload echoes it on the corresponding message dict.""" + client = self._client() + prior = AIMessage( + content="Plan", + additional_kwargs={"reasoning_content": "weighed bull case"}, + ) + new_user = HumanMessage(content="Refine.") + payload = client._get_request_payload([prior, new_user]) + # Find the assistant message in the payload + assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"] + assert assistant_dicts, "assistant message missing from outgoing payload" + assert assistant_dicts[0]["reasoning_content"] == "weighed bull case" + + def test_propagate_through_chat_prompt_value(self): + """Gemini bot review note: non-list inputs (ChatPromptValue) must + also propagate reasoning_content.""" + client = self._client() + prior = AIMessage( + content="Plan", + additional_kwargs={"reasoning_content": "weighed bull case"}, + ) + prompt_value = ChatPromptValue(messages=[prior, HumanMessage(content="Refine.")]) + payload = client._get_request_payload(prompt_value) + assistant_dicts = [m for m in payload["messages"] if m.get("role") == "assistant"] + assert assistant_dicts[0]["reasoning_content"] == "weighed bull case" + + +# --------------------------------------------------------------------------- +# deepseek-reasoner: structured output unavailable, falls through to free-text +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDeepSeekReasonerStructuredOutput: + def test_with_structured_output_raises_for_reasoner(self): + client = DeepSeekChatOpenAI( + model="deepseek-reasoner", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + from pydantic import BaseModel + + class _Sample(BaseModel): + answer: str + + with pytest.raises(NotImplementedError): + client.with_structured_output(_Sample) + + def test_with_structured_output_works_for_v4(self): + """V4 models (non-reasoner) accept tool_choice; structured output works.""" + client = DeepSeekChatOpenAI( + model="deepseek-v4-flash", + api_key="placeholder", + base_url="https://api.deepseek.com", + ) + from pydantic import BaseModel + + class _Sample(BaseModel): + answer: str + + # Should return a Runnable, not raise. (The actual API call would + # require a real key; we only assert binding succeeds.) + wrapped = client.with_structured_output(_Sample) + assert wrapped is not None + + +# --------------------------------------------------------------------------- +# Base class isolation: NormalizedChatOpenAI does NOT have DeepSeek behaviour +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBaseClassIsolation: + def test_normalized_does_not_propagate_reasoning_content(self): + """The general-purpose NormalizedChatOpenAI must not carry + DeepSeek-specific behaviour. Only the subclass does.""" + assert not hasattr(NormalizedChatOpenAI, "_get_request_payload") or ( + NormalizedChatOpenAI._get_request_payload + is NormalizedChatOpenAI.__bases__[0]._get_request_payload + ) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index a2c57ed8..9a723a8b 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -65,10 +65,12 @@ MODEL_OPTIONS: ProviderModeOptions = { }, "deepseek": { "quick": [ + ("DeepSeek V4 Flash - Latest V4 fast model", "deepseek-v4-flash"), ("DeepSeek V3.2", "deepseek-chat"), ("Custom model ID", "custom"), ], "deep": [ + ("DeepSeek V4 Pro - Latest V4 flagship model", "deepseek-v4-pro"), ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"), ("DeepSeek V3.2", "deepseek-chat"), ("Custom model ID", "custom"), diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index bbfcd39e..b74e26ef 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -1,6 +1,7 @@ import os from typing import Any, Optional +from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI from .base_client import BaseLLMClient, normalize_content @@ -11,29 +12,97 @@ class NormalizedChatOpenAI(ChatOpenAI): """ChatOpenAI with normalized content output. The Responses API returns content as a list of typed blocks - (reasoning, text, etc.). This normalizes to string for consistent - downstream handling. + (reasoning, text, etc.). ``invoke`` normalizes to string for + consistent downstream handling. ``with_structured_output`` defaults + to function-calling so the Responses-API parse path is avoided + (langchain-openai's parse path emits noisy + PydanticSerializationUnexpectedValue warnings per call without + affecting correctness). + + Provider-specific quirks (e.g. DeepSeek's thinking mode) live in + purpose-built subclasses below so this base class stays small. """ def invoke(self, input, config=None, **kwargs): return normalize_content(super().invoke(input, config, **kwargs)) def with_structured_output(self, schema, *, method=None, **kwargs): - """Wrap with structured output, defaulting to function_calling for OpenAI. - - langchain-openai's Responses-API-parse path (the default for json_schema - when use_responses_api=True) calls response.model_dump(...) on the OpenAI - SDK's union-typed parsed response, which makes Pydantic emit ~20 - PydanticSerializationUnexpectedValue warnings per call. The function-calling - path returns a plain tool-call shape that does not trigger that - serialization, so it is the cleaner choice for our combination of - use_responses_api=True + with_structured_output. Both paths use OpenAI's - strict mode and produce the same typed Pydantic instance. - """ if method is None: method = "function_calling" return super().with_structured_output(schema, method=method, **kwargs) + +def _input_to_messages(input_: Any) -> list: + """Normalise a langchain LLM input to a list of message objects. + + Accepts a list of messages, a ``ChatPromptValue`` (from a + ChatPromptTemplate), or anything else (treated as no messages). + Used by providers that need to walk the outgoing message history; + in particular DeepSeek thinking-mode propagation must work for + both bare-list invocations and ChatPromptTemplate-driven ones, so + treating only ``list`` here would silently skip half the call sites. + """ + if isinstance(input_, list): + return input_ + if hasattr(input_, "to_messages"): + return input_.to_messages() + return [] + + +class DeepSeekChatOpenAI(NormalizedChatOpenAI): + """DeepSeek-specific overrides on top of the OpenAI-compatible client. + + Two quirks that don't apply to other OpenAI-compatible providers: + + 1. **Thinking-mode round-trip.** When DeepSeek's thinking models return + a response with ``reasoning_content``, that field must be echoed + back as part of the assistant message on the next turn or the API + fails with HTTP 400. ``_create_chat_result`` captures the field on + receive and ``_get_request_payload`` re-attaches it on send. + + 2. **deepseek-reasoner has no tool_choice.** Structured output via + function-calling is unavailable, so we raise NotImplementedError + and let the agent factories fall back to free-text generation + (see ``tradingagents/agents/utils/structured.py``). + """ + + def _get_request_payload(self, input_, *, stop=None, **kwargs): + payload = super()._get_request_payload(input_, stop=stop, **kwargs) + outgoing = payload.get("messages", []) + for message_dict, message in zip(outgoing, _input_to_messages(input_)): + if not isinstance(message, AIMessage): + continue + reasoning = message.additional_kwargs.get("reasoning_content") + if reasoning is not None: + message_dict["reasoning_content"] = reasoning + return payload + + def _create_chat_result(self, response, generation_info=None): + chat_result = super()._create_chat_result(response, generation_info) + response_dict = ( + response + if isinstance(response, dict) + else response.model_dump( + exclude={"choices": {"__all__": {"message": {"parsed"}}}} + ) + ) + for generation, choice in zip( + chat_result.generations, response_dict.get("choices", []) + ): + reasoning = choice.get("message", {}).get("reasoning_content") + if reasoning is not None: + generation.message.additional_kwargs["reasoning_content"] = reasoning + return chat_result + + def with_structured_output(self, schema, *, method=None, **kwargs): + if self.model_name == "deepseek-reasoner": + raise NotImplementedError( + "deepseek-reasoner does not support tool_choice; structured " + "output is unavailable. Agent factories fall back to " + "free-text generation automatically." + ) + return super().with_structured_output(schema, method=method, **kwargs) + # Kwargs forwarded from user config to ChatOpenAI _PASSTHROUGH_KWARGS = ( "timeout", "max_retries", "reasoning_effort", @@ -75,10 +144,12 @@ class OpenAIClient(BaseLLMClient): self.warn_if_unknown_model() llm_kwargs = {"model": self.model} - # Provider-specific base URL and auth + # Provider-specific base URL and auth. An explicit base_url on the + # client (e.g. a corporate proxy) takes precedence over the + # provider default so users can route through their own gateway. if self.provider in _PROVIDER_CONFIG: - base_url, api_key_env = _PROVIDER_CONFIG[self.provider] - llm_kwargs["base_url"] = base_url + default_base, api_key_env = _PROVIDER_CONFIG[self.provider] + llm_kwargs["base_url"] = self.base_url or default_base if api_key_env: api_key = os.environ.get(api_key_env) if api_key: @@ -98,7 +169,10 @@ class OpenAIClient(BaseLLMClient): if self.provider == "openai": llm_kwargs["use_responses_api"] = True - return NormalizedChatOpenAI(**llm_kwargs) + # DeepSeek's thinking-mode quirks live in their own subclass so the + # base NormalizedChatOpenAI stays free of provider-specific branches. + chat_cls = DeepSeekChatOpenAI if self.provider == "deepseek" else NormalizedChatOpenAI + return chat_cls(**llm_kwargs) def validate_model(self) -> bool: """Validate model for the provider."""