mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(llm): unify OpenAI-compatible providers behind a registry + generic endpoint
The OpenAI-compatible family (OpenAI, xAI, DeepSeek, Qwen, GLM, MiniMax, OpenRouter, Ollama) all speak the same Chat Completions API and differ only by base_url, key, and two narrow wire-format quirks already isolated in subclasses. Replace the scattered base-URL dict, key handling, and client-class branches with one ProviderSpec registry that get_llm and the factory drive off; provider quirks stay in their subclasses. Add a generic "openai_compatible" provider for any OpenAI-compatible server (vLLM, LM Studio, llama.cpp, relays) via backend_url + optional key — adding a provider is now one registry row. Native Anthropic/Google keep their own clients (genuinely different APIs). Also fixes the env backend URL being ignored when the provider was chosen interactively (#978).
This commit is contained in:
@@ -157,6 +157,8 @@ For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise
|
||||
|
||||
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
||||
|
||||
For any other OpenAI-compatible server (vLLM, LM Studio, llama.cpp, or a custom relay), use `llm_provider: "openai_compatible"` and set the endpoint via `backend_url` (or `TRADINGAGENTS_LLM_BACKEND_URL`), e.g. `http://localhost:8000/v1` for vLLM or `http://localhost:1234/v1` for LM Studio. The model is whatever your server serves. No key is needed for local servers; set `OPENAI_COMPATIBLE_API_KEY` when the endpoint requires one.
|
||||
|
||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
|
||||
15
cli/main.py
15
cli/main.py
@@ -571,7 +571,9 @@ def get_user_selections():
|
||||
provider_from_env = bool(os.environ.get("TRADINGAGENTS_LLM_PROVIDER"))
|
||||
if provider_from_env:
|
||||
selected_llm_provider = DEFAULT_CONFIG["llm_provider"].lower()
|
||||
backend_url = DEFAULT_CONFIG["backend_url"] or provider_default_url(selected_llm_provider)
|
||||
backend_url = resolve_backend_url(
|
||||
selected_llm_provider, env_url=DEFAULT_CONFIG["backend_url"]
|
||||
)
|
||||
console.print(f"[green]✓ LLM provider from environment:[/green] {selected_llm_provider}")
|
||||
console.print(f"[green]✓ Backend URL:[/green] {backend_url}")
|
||||
# Still confirm/persist the API key so the run doesn't fail later.
|
||||
@@ -594,6 +596,17 @@ def get_user_selections():
|
||||
elif selected_llm_provider == "glm":
|
||||
selected_llm_provider, backend_url = ask_glm_region()
|
||||
|
||||
# Honor an explicit env backend URL even when the provider was chosen
|
||||
# interactively, so it isn't overwritten by the menu default (#978).
|
||||
backend_url = resolve_backend_url(
|
||||
selected_llm_provider, backend_url, env_url=DEFAULT_CONFIG["backend_url"]
|
||||
)
|
||||
|
||||
# The generic OpenAI-compatible endpoint has no default; ask for it if
|
||||
# neither the menu nor the environment supplied one.
|
||||
if selected_llm_provider == "openai_compatible" and not backend_url:
|
||||
backend_url = prompt_openai_compatible_url()
|
||||
|
||||
# For Ollama, surface the resolved endpoint (OLLAMA_BASE_URL vs default)
|
||||
# before model selection so it's obvious where we're connecting.
|
||||
if selected_llm_provider == "ollama":
|
||||
|
||||
35
cli/utils.py
35
cli/utils.py
@@ -313,6 +313,7 @@ def _llm_provider_table() -> list[tuple[str, str, str | None]]:
|
||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Azure OpenAI", "azure", None),
|
||||
("Ollama", "ollama", ollama_url),
|
||||
("OpenAI-compatible (vLLM, LM Studio, llama.cpp, custom relay)", "openai_compatible", None),
|
||||
]
|
||||
|
||||
|
||||
@@ -325,6 +326,33 @@ def provider_default_url(provider_key: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def resolve_backend_url(
|
||||
provider: str, menu_url: str | None = None, env_url: str | None = None
|
||||
) -> str | None:
|
||||
"""Resolve the backend URL with the correct precedence.
|
||||
|
||||
An explicit env override (``env_url``, from ``TRADINGAGENTS_LLM_BACKEND_URL``
|
||||
via ``DEFAULT_CONFIG['backend_url']``) is honored regardless of how the
|
||||
provider was chosen — interactively or from the environment (#978).
|
||||
Otherwise the menu/region URL, then the provider's default.
|
||||
"""
|
||||
return env_url or menu_url or provider_default_url(provider)
|
||||
|
||||
|
||||
def prompt_openai_compatible_url() -> str:
    """Prompt for a custom OpenAI-compatible endpoint base URL.

    Returns:
        The entered URL with surrounding whitespace removed.

    Raises:
        SystemExit: With status 1 when the prompt yields no value
            (presumably when the user cancels it — confirm against
            questionary's ``ask()`` behavior).
    """
    url = questionary.text(
        "Enter the OpenAI-compatible base URL "
        "(e.g. http://localhost:8000/v1 for vLLM, http://localhost:1234/v1 for LM Studio):",
        # The validator yields True for an accepted value, otherwise the
        # message string shown to the user.
        validate=lambda x: x.strip().startswith(("http://", "https://"))
        or "Enter a URL starting with http:// or https://",
    ).ask()

    if not url:
        console.print("\n[red]No endpoint URL provided. Exiting...[/red]")
        # Raise SystemExit directly instead of calling ``exit()``: that helper
        # is injected by the ``site`` module for interactive sessions and is
        # not guaranteed to exist in every runtime.
        raise SystemExit(1)

    return url.strip()
|
||||
|
||||
|
||||
def select_llm_provider() -> tuple[str, str | None]:
|
||||
"""Select the LLM provider and its API endpoint."""
|
||||
PROVIDERS = _llm_provider_table()
|
||||
@@ -538,6 +566,13 @@ def ensure_api_key(provider: str) -> Optional[str]:
|
||||
if env_var is None:
|
||||
return None # ollama / unknown — no key check possible
|
||||
|
||||
# Key-optional providers (generic OpenAI-compatible / local servers) read the
|
||||
# key when present but must never force an interactive prompt.
|
||||
from tradingagents.llm_clients.openai_client import OPENAI_COMPATIBLE_PROVIDERS
|
||||
spec = OPENAI_COMPATIBLE_PROVIDERS.get(provider.lower())
|
||||
if spec is not None and spec.key_optional:
|
||||
return os.environ.get(env_var)
|
||||
|
||||
existing = os.environ.get(env_var)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
@@ -24,7 +24,7 @@ def _resync_reloaded_modules():
|
||||
importlib.reload(cli.main)
|
||||
|
||||
|
||||
# ---- openai_client side: _resolve_provider_base_url -----------------------
|
||||
# ---- openai_client side: registry-driven base_url resolution --------------
|
||||
|
||||
|
||||
def _reload_client():
|
||||
@@ -32,16 +32,20 @@ def _reload_client():
|
||||
return importlib.reload(mod)
|
||||
|
||||
|
||||
def _base_url(mod, provider, **kwargs):
    """Return, as a string, the base URL of the LLM built for ``provider``.

    ``mod`` is the (re)loaded ``openai_client`` module; extra ``kwargs`` are
    forwarded to ``mod.OpenAIClient``. The model name is a fixed dummy value.
    """
    client = mod.OpenAIClient(model="m", provider=provider, **kwargs)
    return str(client.get_llm().openai_api_base)
|
||||
|
||||
|
||||
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_returns_env_when_set(monkeypatch):
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://remote-ollama:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||
@@ -49,15 +53,15 @@ def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||
mod = _reload_client()
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
||||
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
|
||||
assert _base_url(mod, "ollama") == "http://late-set:11434/v1"
|
||||
|
||||
|
||||
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
||||
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
||||
mod = _reload_client()
|
||||
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
|
||||
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
|
||||
assert _base_url(mod, "xai") == "https://api.x.ai/v1"
|
||||
assert _base_url(mod, "deepseek") == "https://api.deepseek.com"
|
||||
|
||||
|
||||
def test_client_get_llm_picks_up_env(monkeypatch):
|
||||
|
||||
75
tests/test_openai_compatible_provider.py
Normal file
75
tests/test_openai_compatible_provider.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Generic OpenAI-compatible provider (vLLM / LM Studio / llama.cpp / relays).
|
||||
|
||||
Verifies the user-supplied base_url is required and honored, the key is optional
|
||||
(keyless local default), Chat Completions (not the Responses API) is used, any
|
||||
model name is accepted, and the env backend URL precedence (#978).
|
||||
"""
|
||||
import os

import pytest

from tradingagents.llm_clients.api_key_env import get_api_key_env
from tradingagents.llm_clients.factory import create_llm_client
from tradingagents.llm_clients.validators import validate_model

# Note: assert by class NAME, not isinstance — other tests reload the
# openai_client module, which would otherwise create a second class identity.


def _api_key_value(llm):
    """Return the client's API key as plain text.

    The attribute may be a secret wrapper exposing ``get_secret_value()`` or a
    plain string; both forms are handled.
    """
    key = llm.openai_api_key
    return key.get_secret_value() if hasattr(key, "get_secret_value") else key


@pytest.mark.unit
def test_factory_routes_to_openai_client():
    """The generic provider is served by ``OpenAIClient``."""
    client = create_llm_client(
        provider="openai_compatible", model="my-model", base_url="http://localhost:8000/v1"
    )
    assert type(client).__name__ == "OpenAIClient"


@pytest.mark.unit
def test_base_url_required(monkeypatch):
    """Building the LLM without any base URL must fail loudly."""
    monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
    with pytest.raises(ValueError, match="requires a base_url"):
        create_llm_client(provider="openai_compatible", model="m").get_llm()


@pytest.mark.unit
def test_keyless_local_uses_placeholder_and_chat_completions(monkeypatch):
    """Without a key, a placeholder is sent and Chat Completions is used."""
    monkeypatch.delenv("OPENAI_COMPATIBLE_API_KEY", raising=False)
    llm = create_llm_client(
        provider="openai_compatible", model="qwen2.5", base_url="http://localhost:8000/v1"
    ).get_llm()
    assert type(llm).__name__ == "NormalizedChatOpenAI"
    assert str(llm.openai_api_base) == "http://localhost:8000/v1"
    # keyless local servers: a placeholder key is sent
    assert _api_key_value(llm) == "EMPTY"
    # must use Chat Completions, not OpenAI's Responses API
    assert getattr(llm, "use_responses_api", False) in (False, None)


@pytest.mark.unit
def test_optional_key_from_env(monkeypatch):
    """A key set in the environment is forwarded to the client."""
    monkeypatch.setenv("OPENAI_COMPATIBLE_API_KEY", "sk-relay-123")
    llm = create_llm_client(
        provider="openai_compatible", model="m", base_url="https://relay.example/v1"
    ).get_llm()
    assert _api_key_value(llm) == "sk-relay-123"


@pytest.mark.unit
def test_any_model_accepted_no_forced_key():
    """Any model name validates, and the provider is marked key-optional."""
    assert validate_model("openai_compatible", "literally-anything") is True
    # The key env exists (read for keyed relays) but the provider is marked
    # key-optional, so the CLI never forces a prompt and keyless servers work.
    assert get_api_key_env("openai_compatible") == "OPENAI_COMPATIBLE_API_KEY"
    from tradingagents.llm_clients.openai_client import OPENAI_COMPATIBLE_PROVIDERS

    assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True


@pytest.mark.unit
def test_env_backend_url_precedence():
    """The env URL beats the menu URL, which beats the provider default."""
    # #978: explicit env URL wins over the menu/default regardless of provider source.
    from cli.utils import resolve_backend_url

    assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url="http://proxy/v1") == "http://proxy/v1"
    assert resolve_backend_url("openai", "https://api.openai.com/v1", env_url=None) == "https://api.openai.com/v1"
    assert resolve_backend_url("deepseek", None, None) == "https://api.deepseek.com"
|
||||
55
tests/test_provider_registry.py
Normal file
55
tests/test_provider_registry.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""The OpenAI-compatible provider registry is the single source of truth for the
|
||||
family; this guards each provider's resolved config (base URL, subclass, auth,
|
||||
Responses API) so a future edit can't silently break one.
|
||||
"""
|
||||
import pytest

from tradingagents.llm_clients.openai_client import (
    OPENAI_COMPATIBLE_PROVIDERS,
    DeepSeekChatOpenAI,
    MinimaxChatOpenAI,
    NormalizedChatOpenAI,
    is_openai_compatible,
)


@pytest.mark.unit
def test_registry_membership():
    """Only OpenAI-compatible providers are reported as registry members."""
    assert is_openai_compatible("openai")
    assert is_openai_compatible("openai_compatible")  # the generic endpoint
    # native (different API) clients are intentionally NOT in the registry
    assert not is_openai_compatible("anthropic")
    assert not is_openai_compatible("google")
    assert not is_openai_compatible("azure")


@pytest.mark.unit
@pytest.mark.parametrize("provider,base_url,chat_class,responses", [
    ("openai", None, NormalizedChatOpenAI, True),
    ("xai", "https://api.x.ai/v1", NormalizedChatOpenAI, False),
    ("deepseek", "https://api.deepseek.com", DeepSeekChatOpenAI, False),
    ("qwen", "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
    ("qwen-cn", "https://dashscope.aliyuncs.com/compatible-mode/v1", NormalizedChatOpenAI, False),
    ("glm", "https://api.z.ai/api/paas/v4/", NormalizedChatOpenAI, False),
    ("glm-cn", "https://open.bigmodel.cn/api/paas/v4/", NormalizedChatOpenAI, False),
    ("minimax", "https://api.minimax.io/v1", MinimaxChatOpenAI, False),
    ("minimax-cn", "https://api.minimaxi.com/v1", MinimaxChatOpenAI, False),
    ("openrouter", "https://openrouter.ai/api/v1", NormalizedChatOpenAI, False),
    ("ollama", "http://localhost:11434/v1", NormalizedChatOpenAI, False),
])
def test_registry_spec(provider, base_url, chat_class, responses):
    """Each registry row keeps its expected endpoint, subclass and API mode."""
    spec = OPENAI_COMPATIBLE_PROVIDERS[provider]
    assert spec.base_url == base_url
    assert spec.chat_class is chat_class
    assert spec.use_responses_api is responses


@pytest.mark.unit
def test_key_optionality():
    """Key/base-URL requirements match each provider's deployment model."""
    # Local/generic endpoints are key-optional; hosted APIs require a key.
    assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].key_optional is True
    assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].key_optional is True
    assert OPENAI_COMPATIBLE_PROVIDERS["openai_compatible"].require_base_url is True
    assert OPENAI_COMPATIBLE_PROVIDERS["xai"].key_optional is False
    # OLLAMA_BASE_URL is the only base-URL env override.
    assert OPENAI_COMPATIBLE_PROVIDERS["ollama"].base_url_env == "OLLAMA_BASE_URL"
|
||||
@@ -32,6 +32,10 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
# Local runtimes do not authenticate.
|
||||
"ollama": None,
|
||||
# Generic OpenAI-compatible endpoint: the client reads this when set (keyed
|
||||
# relays), but it is marked key-optional in the provider registry so the CLI
|
||||
# never forces a prompt and keyless local servers still work.
|
||||
"openai_compatible": "OPENAI_COMPATIBLE_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -2,16 +2,6 @@ from typing import Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
"openai", "xai", "deepseek",
|
||||
"qwen", "qwen-cn",
|
||||
"glm", "glm-cn",
|
||||
"minimax", "minimax-cn",
|
||||
"ollama", "openrouter",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -38,10 +28,9 @@ def create_llm_client(
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in _OPENAI_COMPATIBLE:
|
||||
from .openai_client import OpenAIClient
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
# Native (non-OpenAI) APIs are matched first so their string check doesn't
|
||||
# import the OpenAI client. Everything else is OpenAI-compatible and routes
|
||||
# through the provider registry (single source of truth).
|
||||
if provider_lower == "anthropic":
|
||||
from .anthropic_client import AnthropicClient
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
@@ -54,4 +43,8 @@ def create_llm_client(
|
||||
from .azure_client import AzureOpenAIClient
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
from .openai_client import OpenAIClient, is_openai_compatible
|
||||
if is_openai_compatible(provider_lower):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
@@ -175,6 +175,12 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
# Generic OpenAI-compatible endpoint: the model is whatever the user's
|
||||
# server serves, so only "Custom model ID" is offered.
|
||||
"openai_compatible": {
|
||||
"quick": [("Custom model ID", "custom")],
|
||||
"deep": [("Custom model ID", "custom")],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -148,39 +149,56 @@ _PASSTHROUGH_KWARGS = (
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
|
||||
# (one canonical mapping consulted by both this client and the CLI's
|
||||
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
|
||||
# separate endpoints because international and China accounts cannot share
|
||||
# credentials (#758).
|
||||
_PROVIDER_BASE_URL = {
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"deepseek": "https://api.deepseek.com",
|
||||
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"glm": "https://api.z.ai/api/paas/v4/",
|
||||
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"minimax": "https://api.minimax.io/v1",
|
||||
"minimax-cn": "https://api.minimaxi.com/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"ollama": "http://localhost:11434/v1",
|
||||
@dataclass(frozen=True)
class ProviderSpec:
    """Declarative config for one OpenAI-compatible provider.

    The OpenAI-compatible family (OpenAI, xAI, DeepSeek, Qwen, GLM, MiniMax,
    OpenRouter, Ollama, and any user endpoint) all speak the same Chat
    Completions API and differ only by these fields — so one row here replaces
    the former per-provider base-URL dict, auth handling, and client-class
    branches. Native Anthropic / Google use their own clients (genuinely
    different APIs) and are intentionally NOT in this registry.

    The API-key env var stays in ``api_key_env.PROVIDER_API_KEY_ENV`` (the single
    source consulted by both this client and the CLI prompt); only behavior that
    is provider-specific (base URL, key optionality, wire-format quirks via
    ``chat_class``) lives here.
    """

    # Chat model class to instantiate; provider quirks live in the subclass.
    chat_class: type = NormalizedChatOpenAI
    # Default endpoint (None -> SDK default).
    base_url: Optional[str] = None
    # Env var that overrides base_url (e.g. OLLAMA_BASE_URL).
    base_url_env: Optional[str] = None
    # Don't require/prompt for a key; send a placeholder if unset.
    key_optional: bool = False
    # Sent when no key is available (keyless local servers).
    placeholder_key: str = "EMPTY"
    # Error if no base_url is resolved (generic endpoint).
    require_base_url: bool = False
    # Use the native OpenAI Responses API.
    use_responses_api: bool = False
|
||||
|
||||
|
||||
# Single source of truth for the OpenAI-compatible provider family. Dual-region
# providers (qwen/glm/minimax) keep separate endpoints because international and
# China accounts cannot share credentials (#758).
OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderSpec] = {
    "openai": ProviderSpec(use_responses_api=True),
    "xai": ProviderSpec(base_url="https://api.x.ai/v1"),
    "deepseek": ProviderSpec(
        base_url="https://api.deepseek.com", chat_class=DeepSeekChatOpenAI
    ),
    "qwen": ProviderSpec(base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1"),
    "qwen-cn": ProviderSpec(base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"),
    "glm": ProviderSpec(base_url="https://api.z.ai/api/paas/v4/"),
    "glm-cn": ProviderSpec(base_url="https://open.bigmodel.cn/api/paas/v4/"),
    "minimax": ProviderSpec(
        base_url="https://api.minimax.io/v1", chat_class=MinimaxChatOpenAI
    ),
    "minimax-cn": ProviderSpec(
        base_url="https://api.minimaxi.com/v1", chat_class=MinimaxChatOpenAI
    ),
    "openrouter": ProviderSpec(base_url="https://openrouter.ai/api/v1"),
    # Local runtime: endpoint can be redirected via OLLAMA_BASE_URL and no real
    # key is needed, so a fixed placeholder is sent.
    "ollama": ProviderSpec(
        base_url="http://localhost:11434/v1",
        base_url_env="OLLAMA_BASE_URL",
        key_optional=True,
        placeholder_key="ollama",
    ),
    # Generic endpoint: user supplies base_url; key optional (keyless local).
    "openai_compatible": ProviderSpec(require_base_url=True, key_optional=True),
}
|
||||
|
||||
|
||||
def _resolve_provider_base_url(provider: str) -> Optional[str]:
|
||||
"""Default base URL for ``provider``, with env-var overrides where defined.
|
||||
|
||||
Currently only Ollama supports an env-var override (``OLLAMA_BASE_URL``),
|
||||
matching the convention in the broader Ollama tooling ecosystem so users
|
||||
can point at a remote ollama-serve without editing code. The check is
|
||||
call-time, not import-time, so tests that monkeypatch the env after
|
||||
import behave correctly.
|
||||
"""
|
||||
if provider == "ollama":
|
||||
env_url = os.environ.get("OLLAMA_BASE_URL")
|
||||
if env_url:
|
||||
return env_url
|
||||
return _PROVIDER_BASE_URL.get(provider)
|
||||
def is_openai_compatible(provider: str) -> bool:
    """Whether ``provider`` is served by the OpenAI-compatible registry.

    The lookup is case-insensitive: ``provider`` is lower-cased before it is
    checked against the keys of ``OPENAI_COMPATIBLE_PROVIDERS``.
    """
    provider_key = provider.lower()
    return provider_key in OPENAI_COMPATIBLE_PROVIDERS
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
@@ -203,28 +221,47 @@ class OpenAIClient(BaseLLMClient):
|
||||
self.provider = provider.lower()
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatOpenAI instance."""
|
||||
"""Return a configured ChatOpenAI instance, driven by the provider registry."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
spec = OPENAI_COMPATIBLE_PROVIDERS.get(self.provider)
|
||||
chat_cls = NormalizedChatOpenAI
|
||||
|
||||
# Provider-specific base URL and auth. An explicit base_url on the
|
||||
# client (e.g. a corporate proxy) takes precedence over the
|
||||
# provider default so users can route through their own gateway.
|
||||
if self.provider in _PROVIDER_BASE_URL:
|
||||
llm_kwargs["base_url"] = self.base_url or _resolve_provider_base_url(self.provider)
|
||||
if spec is not None:
|
||||
chat_cls = spec.chat_class
|
||||
|
||||
# base_url precedence: explicit client base_url (carries the config /
|
||||
# TRADINGAGENTS_LLM_BACKEND_URL value) > provider env override (e.g.
|
||||
# OLLAMA_BASE_URL) > provider default. None means use the SDK default.
|
||||
env_base_url = os.environ.get(spec.base_url_env) if spec.base_url_env else None
|
||||
base_url = self.base_url or env_base_url or spec.base_url
|
||||
if spec.require_base_url and not base_url:
|
||||
raise ValueError(
|
||||
f"Provider '{self.provider}' requires a base_url. Set it via "
|
||||
"backend_url / TRADINGAGENTS_LLM_BACKEND_URL to your endpoint, "
|
||||
"e.g. http://localhost:8000/v1 (vLLM) or http://localhost:1234/v1 "
|
||||
"(LM Studio)."
|
||||
)
|
||||
if base_url:
|
||||
llm_kwargs["base_url"] = base_url
|
||||
|
||||
# API key: required unless key_optional; keyless local servers get a
|
||||
# placeholder. The env-var name is the single source in api_key_env.
|
||||
api_key_env = get_api_key_env(self.provider)
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
api_key = os.environ.get(api_key_env) if api_key_env else None
|
||||
if api_key:
|
||||
llm_kwargs["api_key"] = api_key
|
||||
else:
|
||||
elif spec.key_optional:
|
||||
llm_kwargs["api_key"] = spec.placeholder_key
|
||||
elif api_key_env:
|
||||
raise ValueError(
|
||||
f"API key for provider '{self.provider}' is not set. "
|
||||
f"Please set the {api_key_env} environment variable "
|
||||
f"(e.g. add {api_key_env}=your_key to your .env file)."
|
||||
)
|
||||
else:
|
||||
llm_kwargs["api_key"] = "ollama"
|
||||
|
||||
if spec.use_responses_api:
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
@@ -233,19 +270,7 @@ class OpenAIClient(BaseLLMClient):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
# Native OpenAI: use Responses API for consistent behavior across
|
||||
# all model families. Third-party providers use Chat Completions.
|
||||
if self.provider == "openai":
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
|
||||
# Provider-specific quirks live in their own subclasses so the
|
||||
# base NormalizedChatOpenAI stays free of provider branches.
|
||||
if self.provider == "deepseek":
|
||||
chat_cls = DeepSeekChatOpenAI
|
||||
elif self.provider in ("minimax", "minimax-cn"):
|
||||
chat_cls = MinimaxChatOpenAI
|
||||
else:
|
||||
chat_cls = NormalizedChatOpenAI
|
||||
# The subclass (provider quirks) comes from the registry spec.
|
||||
return chat_cls(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
|
||||
@@ -3,21 +3,25 @@
|
||||
from .model_catalog import get_known_models
|
||||
|
||||
|
||||
# Providers whose model names are user-defined (local servers, relays, any
|
||||
# OpenAI-compatible endpoint), so any model string is accepted without warning.
|
||||
_ANY_MODEL_PROVIDERS = ("ollama", "openrouter", "openai_compatible")
|
||||
|
||||
VALID_MODELS = {
|
||||
provider: models
|
||||
for provider, models in get_known_models().items()
|
||||
if provider not in ("ollama", "openrouter")
|
||||
if provider not in _ANY_MODEL_PROVIDERS
|
||||
}
|
||||
|
||||
|
||||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter - any model is accepted.
|
||||
For ollama, openrouter, and openai_compatible - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter"):
|
||||
if provider_lower in _ANY_MODEL_PROVIDERS:
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
|
||||
Reference in New Issue
Block a user