mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(llm): unify OpenAI-compatible providers behind a registry + generic endpoint
The OpenAI-compatible family (openai, xAI, DeepSeek, Qwen, GLM, MiniMax, OpenRouter, Ollama) all speak the same Chat Completions API and differ only by base_url, key, and two narrow wire-format quirks already isolated in subclasses. Replace the scattered base-URL dict, key handling, and client-class branches with one ProviderSpec registry that get_llm and the factory drive off; provider quirks stay in their subclasses. Add a generic "openai_compatible" provider for any OpenAI-compatible server (vLLM, LM Studio, llama.cpp, relays) via backend_url + optional key — adding a provider is now one registry row. Native Anthropic/Google keep their own clients (genuinely different APIs). Also fixes the env backend URL being ignored when the provider was chosen interactively (#978).
This commit is contained in:
@@ -24,7 +24,7 @@ def _resync_reloaded_modules():
|
||||
importlib.reload(cli.main)
|
||||
|
||||
|
||||
# ---- openai_client side: _resolve_provider_base_url -----------------------
|
||||
# ---- openai_client side: registry-driven base_url resolution --------------
|
||||
|
||||
|
||||
def _reload_client():
|
||||
@@ -32,16 +32,20 @@ def _reload_client():
|
||||
return importlib.reload(mod)
|
||||
|
||||
|
||||
def _base_url(mod, provider, **kwargs):
|
||||
return str(mod.OpenAIClient(model="m", provider=provider, **kwargs).get_llm().openai_api_base)
|
||||
|
||||
|
||||
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_returns_env_when_set(monkeypatch):
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://remote-ollama:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||
@@ -49,15 +53,15 @@ def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||
mod = _reload_client()
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://late-set:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
||||
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
|
||||
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
|
||||
assert _base_url(mod, "xai") == "https://api.x.ai/v1"
|
||||
assert _base_url(mod, "deepseek") == "https://api.deepseek.com"
|
||||
|
||||
|
||||
def test_client_get_llm_picks_up_env(monkeypatch):
|
||||
|
||||
75
tests/test_openai_compatible_provider.py
Normal file
75
tests/test_openai_compatible_provider.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Generic OpenAI-compatible provider (vLLM / LM Studio / llama.cpp / relays).
|
||||
|
||||
Verifies the user-supplied base_url is required and honored, the key is optional
|
||||
(keyless local default), Chat Completions (not the Responses API) is used, any
|
||||
model name is accepted, and the env backend URL precedence (#978).
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||
from tradingagents.llm_clients.factory import create_llm_client
|
||||
from tradingagents.llm_clients.validators import validate_model
|
||||
|
||||
# Note: assert by class NAME, not isinstance — other tests reload the
|
||||
# openai_client module, which would otherwise create a second class identity.
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_factory_routes_to_openai_client():
|
||||
client = create_llm_client(
|
||||
provider="openai_compatible", model="my-model", base_url="http://localhost:8000/v1"
|
||||
)
|
||||
assert type(client).__name__ == "OpenAIClient"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_base_url_required(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="requires a base_url"):
|
||||
create_llm_client(provider="openai_compatible", model="m").get_llm()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_keyless_local_uses_placeholder_and_chat_completions(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
|
||||
llm = create_llm_client(
|
||||
provider="openai_compatible", model="qwen2.5", base_url="http://localhost:8000/v1"
|
||||
).get_llm()
|
||||
assert type(llm).__name__ == "NormalizedChatOpenAI"
|
||||
assert str(llm.openai_api_base) == "http://localhost:8000/v1"
|
||||
# keyless local servers: a placeholder key is sent
|
||||
key = llm.openai_api_key.get_secret_value() if hasattr(llm.openai_api_key, "get_secret_value") else llm.openai_api_key
|
||||
assert key == "EMPTY"
|
||||
# must use Chat Completions, not OpenAI's Responses API
|
||||
assert getattr(llm, "use_responses_api", False) in (False, None)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_optional_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_COMPATIBLE_API_KEY", "sk-relay-123")
|
||||
llm = create_llm_client(
|
||||
provider="openai_compatible", model="m", base_url="https://relay.example/v1"
|
||||
).get_llm()
|
||||
key = llm.openai_api_key.get_secret_value() if hasattr(llm.openai_api_key, "get_secret_value") else llm.openai_api_key
|
||||
assert key == "sk-relay-123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_any_model_accepted_no_forced_key():
|
||||
assert validate_model("openai_compatible", "literally-anything") is True
|
||||
# The key env exists (read for keyed relays) but the provider is marked
|
||||
# key-optional, so the CLI never forces a prompt and keyless servers work.
|
||||
assert get_api_key_env("openai_compatible") == "OPENAI_COMPATIBLE_API_KEY"
|
||||
from tradingagents.llm_clients.openai_client import OPENAI_COMPATIBLE_PROVIDERS
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_env_backend_url_precedence():
|
||||
# #978: explicit env URL wins over the menu/default regardless of provider source.
|
||||
from cli.utils import resolve_backend_url
|
||||
assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url="http://proxy/v1") == "http://proxy/v1"
|
||||
assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url=None) == "https://api.openai.com/v1"
|
||||
assert resolve_backend_url("deepseek", None, None) == "https://api.deepseek.com"
|
||||
55
tests/test_provider_registry.py
Normal file
55
tests/test_provider_registry.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""The OpenAI-compatible provider registry is the single source of truth for the
|
||||
family; this guards each provider's resolved config (base URL, subclass, auth,
|
||||
Responses API) so a future edit can't silently break one.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from tradingagents.llm_clients.openai_client import (
|
||||
OPENAI_COMPATIBLE_PROVIDERS,
|
||||
DeepSeekChatOpenAI,
|
||||
MinimaxChatOpenAI,
|
||||
NormalizedChatOpenAI,
|
||||
is_openai_compatible,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_registry_membership():
|
||||
assert is_openai_compatible("openai")
|
||||
assert is_openai_compatible("openai_compatible") # the generic endpoint
|
||||
# native (different API) clients are intentionally NOT in the registry
|
||||
assert not is_openai_compatible("anthropic")
|
||||
assert not is_openai_compatible("google")
|
||||
assert not is_openai_compatible("azure")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.parametrize("provider,base_url,chat_class,responses", [
|
||||
("openai", None, NormalizedChatOpenAI, True),
|
||||
("xai", "https://api.x.ai/v1", NormalizedChatOpenAI, False),
|
||||
("deepseek", "https://api.deepseek.com", DeepSeekChatOpenAI, False),
|
||||
("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
|
||||
("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
|
||||
("glm", "https://api.z.ai/api/paas/v4/", NormalizedChatOpenAI, False),
|
||||
("glm-cn", "https://open.bigmodel.cn/api/paas/v4/", NormalizedChatOpenAI, False),
|
||||
("minimax", "https://api.minimax.io/v1", MinimaxChatOpenAI, False),
|
||||
("minimax-cn", "https://api.minimaxi.com/v1", MinimaxChatOpenAI, False),
|
||||
("openrouter", "https://openrouter.ai/api/v1", NormalizedChatOpenAI, False),
|
||||
("ollama", "http://localhost:11434/v1", NormalizedChatOpenAI, False),
|
||||
])
|
||||
def test_registry_spec(provider, base_url, chat_class, responses):
|
||||
spec = OPENAI_COMPATIBLE_PROVIDERS[provider]
|
||||
assert spec.base_url == base_url
|
||||
assert spec.chat_class is chat_class
|
||||
assert spec.use_responses_api is responses
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_key_optionality():
|
||||
# Local/generic endpoints are key-optional; hosted APIs require a key.
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].key_optional is True
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].require_base_url is True
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["xai"].key_optional is False
|
||||
# OLLAMA_BASE_URL is the only base-URL env override.
|
||||
assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].base_url_env == "OLLAMA_BASE_URL"
|
||||
Reference in New Issue
Block a user