Files
tradingagents/tradingagents/llm_clients/openai_client.py
Yijia-Xiao d0dd0420ad feat(llm): GLM dual-region split + catalog refresh
Zhipu serves GLM under two brands with separate accounts (Z.AI
international vs BigModel China); the CLI URL pointed at one while
the openai_client default pointed at the other. Split into glm +
glm-cn with secondary region prompt (same UX as Qwen + MiniMax).
Catalog adds glm-5-turbo and glm-4.5-air per docs.z.ai.
2026-05-11 04:19:50 +00:00

229 lines
9.8 KiB
Python

import os
from typing import Any, Optional
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .capabilities import get_capabilities
from .validators import validate_model
class NormalizedChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant with normalized output and capability-driven binding.

    The Responses API hands back content as a list of typed blocks
    (reasoning, text, etc.); ``invoke`` runs every response through
    ``normalize_content`` so downstream code sees a consistent shape.

    ``with_structured_output`` looks the model up in the capability table
    (``capabilities.get_capabilities``) to choose the binding method and
    to decide whether ``tool_choice`` may be sent. For models that reject
    ``tool_choice`` (e.g. DeepSeek V4 and reasoner, per their official
    tool-calling guide) the schema is still bound as a tool, just without
    a ``tool_choice`` parameter.

    Quirks that go beyond structured output (e.g. DeepSeek's
    reasoning_content round-trip) belong in subclasses, keeping this
    base class small.
    """

    def invoke(self, input, config=None, **kwargs):
        """Invoke the underlying model and return the normalized response."""
        raw_response = super().invoke(input, config, **kwargs)
        return normalize_content(raw_response)

    def with_structured_output(self, schema, *, method=None, **kwargs):
        """Bind ``schema`` using the method the model's capabilities allow.

        An explicit ``method`` wins over the capability table's preferred
        method. Raises ``NotImplementedError`` when the table reports
        that the model has no structured-output method at all.
        """
        capabilities = get_capabilities(self.model_name)
        preferred = capabilities.preferred_structured_method
        if preferred == "none":
            raise NotImplementedError(
                f"{self.model_name} has no structured-output method available; "
                f"agent factories will fall back to free-text generation."
            )

        chosen_method = method if method else preferred

        # Some models reject tool_choice outright. Pre-seeding it with None
        # (unless the caller already passed one) suppresses langchain's
        # hardcoded value while the schema is still bound as a tool, which
        # is what DeepSeek's official tool-calling examples do.
        if chosen_method == "function_calling":
            if not capabilities.supports_tool_choice:
                if "tool_choice" not in kwargs:
                    kwargs["tool_choice"] = None

        return super().with_structured_output(
            schema, method=chosen_method, **kwargs
        )
def _input_to_messages(input_: Any) -> list:
"""Normalise a langchain LLM input to a list of message objects.
Accepts a list of messages, a ``ChatPromptValue`` (from a
ChatPromptTemplate), or anything else (treated as no messages).
Used by providers that need to walk the outgoing message history;
in particular DeepSeek thinking-mode propagation must work for
both bare-list invocations and ChatPromptTemplate-driven ones, so
treating only ``list`` here would silently skip half the call sites.
"""
if isinstance(input_, list):
return input_
if hasattr(input_, "to_messages"):
return input_.to_messages()
return []
class DeepSeekChatOpenAI(NormalizedChatOpenAI):
    """OpenAI-compatible client carrying the DeepSeek-only behaviour.

    The single DeepSeek-specific concern kept here is the thinking-mode
    round-trip. A response from a DeepSeek thinking model may carry
    ``reasoning_content``; that field has to be sent back on the
    assistant message in the following turn, otherwise the API answers
    with HTTP 400. ``_create_chat_result`` stores the field when a
    response arrives and ``_get_request_payload`` puts it back on the
    outgoing request.

    V4 and reasoner reject the ``tool_choice`` parameter; that is dealt
    with by the capability dispatch in
    ``NormalizedChatOpenAI.with_structured_output``, not in this class.
    """

    def _get_request_payload(self, input_, *, stop=None, **kwargs):
        """Build the request payload and re-attach stored reasoning content.

        Outgoing message dicts are paired positionally with the original
        message objects; only ``AIMessage`` instances that carry a
        non-``None`` ``reasoning_content`` in ``additional_kwargs`` are
        touched.
        """
        request = super()._get_request_payload(input_, stop=stop, **kwargs)
        wire_messages = request.get("messages", [])
        source_messages = _input_to_messages(input_)
        for wire_message, source in zip(wire_messages, source_messages):
            if isinstance(source, AIMessage):
                thinking = source.additional_kwargs.get("reasoning_content")
                if thinking is not None:
                    wire_message["reasoning_content"] = thinking
        return request

    def _create_chat_result(self, response, generation_info=None):
        """Build the chat result and copy ``reasoning_content`` onto it.

        ``response`` may already be a dict; otherwise it is dumped with
        ``model_dump`` (leaving out each choice's ``parsed`` message
        field). Each generation is paired positionally with its choice.
        """
        result = super()._create_chat_result(response, generation_info)

        if isinstance(response, dict):
            raw = response
        else:
            raw = response.model_dump(
                exclude={"choices": {"__all__": {"message": {"parsed"}}}}
            )

        choices = raw.get("choices", [])
        for generation, choice in zip(result.generations, choices):
            thinking = choice.get("message", {}).get("reasoning_content")
            if thinking is not None:
                generation.message.additional_kwargs["reasoning_content"] = thinking
        return result
class MinimaxChatOpenAI(NormalizedChatOpenAI):
    """OpenAI-compatible client carrying the MiniMax-only behaviour.

    By default the M2.x reasoning models place ``<think>...</think>``
    blocks straight into ``message.content``, which would pollute saved
    reports. According to
    platform.minimax.io/docs/api-reference/text-openai-api, sending
    ``reasoning_split=True`` in the request body moves the thinking
    block into ``reasoning_details`` and leaves ``content`` clean.

    M2.x accepts only the string enum ``{"none", "auto"}`` for
    ``tool_choice`` and rejects langchain's function-spec dict; that is
    dealt with by the capability dispatch in
    ``NormalizedChatOpenAI.with_structured_output``, not in this class.
    """

    def _get_request_payload(self, input_, *, stop=None, **kwargs):
        """Build the request payload, defaulting ``reasoning_split`` to True.

        A ``reasoning_split`` value already present in the payload is
        left untouched.
        """
        request = super()._get_request_payload(input_, stop=stop, **kwargs)
        if "reasoning_split" not in request:
            request["reasoning_split"] = True
        return request
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "reasoning_effort",
"api_key", "callbacks", "http_client", "http_async_client",
)
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
# DashScope exposes two regional endpoints with separate accounts; an
# international key won't authenticate against the China endpoint and
# vice versa (fixes issue #758).
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
"qwen-cn": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_CN_API_KEY"),
# Zhipu exposes the same GLM models under two brands with separate
# accounts: Z.AI (international, api.z.ai) and BigModel
# (open.bigmodel.cn, China). Keys aren't interchangeable across them.
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
"glm-cn": ("https://open.bigmodel.cn/api/paas/v4/", "ZHIPU_CN_API_KEY"),
# MiniMax exposes two regional endpoints with separate keys; mainland
# Chinese users hit .com while global users hit .io.
"minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"),
"minimax-cn": ("https://api.minimaxi.com/v1", "MINIMAX_CN_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
}
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI and OpenAI-compatible providers.

    Besides native OpenAI, every provider listed in ``_PROVIDER_CONFIG``
    (xAI, DeepSeek, Qwen, GLM, MiniMax, OpenRouter, Ollama, including the
    ``-cn`` regional variants) is served through this client.

    For native OpenAI models, uses the Responses API (/v1/responses) which
    supports reasoning_effort with function tools across all model families
    (GPT-4.1, GPT-5). Third-party compatible providers use standard Chat
    Completions.
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        provider: str = "openai",
        **kwargs,
    ):
        """Store the provider name (lower-cased) on top of the base client.

        Args:
            model: Model identifier forwarded to the chat class.
            base_url: Optional endpoint override; takes precedence over
                the provider's default base URL.
            provider: Provider name; matched case-insensitively against
                ``_PROVIDER_CONFIG``.
            **kwargs: Extra user configuration handed to the base class;
                the names in ``_PASSTHROUGH_KWARGS`` are later forwarded
                to the chat class.
        """
        super().__init__(model, base_url, **kwargs)
        self.provider = provider.lower()

    def get_llm(self) -> Any:
        """Return a configured chat-model instance for this provider.

        Returns:
            A ``DeepSeekChatOpenAI``, ``MinimaxChatOpenAI`` or
            ``NormalizedChatOpenAI`` instance, chosen by provider.

        Raises:
            ValueError: If the provider requires an API key and it is
                supplied neither by the provider's environment variable
                nor by an explicit ``api_key`` kwarg.
        """
        self.warn_if_unknown_model()
        llm_kwargs = {"model": self.model}

        # Provider-specific base URL and auth. An explicit base_url on the
        # client (e.g. a corporate proxy) takes precedence over the
        # provider default so users can route through their own gateway.
        if self.provider in _PROVIDER_CONFIG:
            default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
            llm_kwargs["base_url"] = self.base_url or default_base
            if api_key_env:
                api_key = os.environ.get(api_key_env)
                if api_key:
                    llm_kwargs["api_key"] = api_key
                elif not self.kwargs.get("api_key"):
                    # An explicit ``api_key`` kwarg is forwarded below and
                    # overrides the environment anyway, so a missing env
                    # var is only an error when no explicit key was given.
                    raise ValueError(
                        f"API key for provider '{self.provider}' is not set. "
                        f"Please set the {api_key_env} environment variable "
                        f"(e.g. add {api_key_env}=your_key to your .env file)."
                    )
            else:
                # No key env var configured for this provider (Ollama);
                # a placeholder value is passed instead.
                llm_kwargs["api_key"] = "ollama"
        elif self.base_url:
            llm_kwargs["base_url"] = self.base_url

        # Forward user-provided kwargs; these win over the values set above.
        for key in _PASSTHROUGH_KWARGS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        # Native OpenAI: use Responses API for consistent behavior across
        # all model families. Third-party providers use Chat Completions.
        if self.provider == "openai":
            llm_kwargs["use_responses_api"] = True

        # Provider-specific quirks live in their own subclasses so the
        # base NormalizedChatOpenAI stays free of provider branches.
        if self.provider == "deepseek":
            chat_cls = DeepSeekChatOpenAI
        elif self.provider in ("minimax", "minimax-cn"):
            chat_cls = MinimaxChatOpenAI
        else:
            chat_cls = NormalizedChatOpenAI
        return chat_cls(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate model for the provider."""
        # Inside the method body this name resolves to the module-level
        # ``validate_model`` imported from ``.validators``, not this method.
        return validate_model(self.provider, self.model)