mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(ollama): OLLAMA_BASE_URL end-to-end with endpoint confirmation
OLLAMA_BASE_URL now flows through both the CLI dropdown and the programmatic client (call-time evaluation so tests behave). After provider selection, the CLI prints the resolved endpoint and marks when it came from the env var, plus a soft warning when the URL is missing a scheme or non-default port. Drops the stale "(local)" suffix from Ollama model labels since the endpoint is now dynamic.
This commit is contained in:
@@ -12,6 +12,12 @@ MINIMAX_API_KEY=
|
|||||||
MINIMAX_CN_API_KEY=
|
MINIMAX_CN_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
|
||||||
|
# Optional: point at a remote Ollama server. When unset, defaults to
|
||||||
|
# the local instance at http://localhost:11434/v1. Convention follows
|
||||||
|
# the broader Ollama ecosystem; both the CLI dropdown and programmatic
|
||||||
|
# client pick this up.
|
||||||
|
#OLLAMA_BASE_URL=http://your-ollama-host:11434/v1
|
||||||
|
|
||||||
# Optional: override DEFAULT_CONFIG without editing code.
|
# Optional: override DEFAULT_CONFIG without editing code.
|
||||||
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
|
# Any TRADINGAGENTS_* variable below, when set, replaces the matching key
|
||||||
# in tradingagents/default_config.py. Values are coerced to the type of
|
# in tradingagents/default_config.py. Values are coerced to the type of
|
||||||
|
|||||||
@@ -562,6 +562,11 @@ def get_user_selections():
|
|||||||
elif selected_llm_provider == "glm":
|
elif selected_llm_provider == "glm":
|
||||||
selected_llm_provider, backend_url = ask_glm_region()
|
selected_llm_provider, backend_url = ask_glm_region()
|
||||||
|
|
||||||
|
# For Ollama, surface the resolved endpoint (OLLAMA_BASE_URL vs default)
|
||||||
|
# before model selection so it's obvious where we're connecting.
|
||||||
|
if selected_llm_provider == "ollama":
|
||||||
|
confirm_ollama_endpoint(backend_url)
|
||||||
|
|
||||||
# Confirm the provider's API key is present; prompt the user to paste
|
# Confirm the provider's API key is present; prompt the user to paste
|
||||||
# one and persist it to .env if it's missing, so the analysis run
|
# one and persist it to .env if it's missing, so the analysis run
|
||||||
# doesn't fail later at the first API call.
|
# doesn't fail later at the first API call.
|
||||||
|
|||||||
36
cli/utils.py
36
cli/utils.py
@@ -234,6 +234,10 @@ def select_deep_thinking_agent(provider) -> str:
|
|||||||
|
|
||||||
def select_llm_provider() -> tuple[str, str | None]:
|
def select_llm_provider() -> tuple[str, str | None]:
|
||||||
"""Select the LLM provider and its API endpoint."""
|
"""Select the LLM provider and its API endpoint."""
|
||||||
|
# Ollama users can point at a remote ollama-serve via OLLAMA_BASE_URL
|
||||||
|
# (convention from the broader Ollama ecosystem); falls back to the
|
||||||
|
# localhost default when unset.
|
||||||
|
ollama_url = os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
|
||||||
# (display_name, provider_key, base_url)
|
# (display_name, provider_key, base_url)
|
||||||
PROVIDERS = [
|
PROVIDERS = [
|
||||||
("OpenAI", "openai", "https://api.openai.com/v1"),
|
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||||
@@ -246,7 +250,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
|||||||
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
("MiniMax", "minimax", "https://api.minimax.io/v1"),
|
||||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Azure OpenAI", "azure", None),
|
("Azure OpenAI", "azure", None),
|
||||||
("Ollama", "ollama", "http://localhost:11434/v1"),
|
("Ollama", "ollama", ollama_url),
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
@@ -413,6 +417,36 @@ def ask_minimax_region() -> tuple[str, str]:
|
|||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_ollama_endpoint(url: str) -> None:
|
||||||
|
"""Show the resolved Ollama endpoint after provider selection.
|
||||||
|
|
||||||
|
Surfaces three things the user benefits from seeing before model
|
||||||
|
selection: which URL we'll actually hit, where it came from
|
||||||
|
(\`OLLAMA_BASE_URL\` vs default), and a soft warning if the URL is
|
||||||
|
missing the scheme/port that ollama-serve expects. The warning is
|
||||||
|
advisory only — we don't reject malformed input, since the user may
|
||||||
|
be doing something deliberately unusual (e.g. a reverse-proxy path).
|
||||||
|
"""
|
||||||
|
from_env = os.environ.get("OLLAMA_BASE_URL")
|
||||||
|
origin = " (from OLLAMA_BASE_URL)" if from_env and from_env == url else ""
|
||||||
|
console.print(f"[green]✓ Using Ollama at {url}{origin}[/green]")
|
||||||
|
|
||||||
|
if not url.startswith(("http://", "https://")):
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Note: {url!r} is missing a scheme. "
|
||||||
|
f"Ollama-serve typically expects a URL like "
|
||||||
|
f"http://<host>:11434/v1.[/yellow]"
|
||||||
|
)
|
||||||
|
elif ":11434" not in url and "://localhost" not in url and "://127.0.0.1" not in url:
|
||||||
|
# Soft hint when the port differs from the ollama-serve default
|
||||||
|
# and the host isn't local (where users sometimes proxy on :80).
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Note: {url!r} doesn't include port 11434. "
|
||||||
|
f"Make sure your remote ollama-serve listens on the port "
|
||||||
|
f"shown above.[/yellow]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_api_key(provider: str) -> Optional[str]:
|
def ensure_api_key(provider: str) -> Optional[str]:
|
||||||
"""Make sure the API key for `provider` is available in the environment.
|
"""Make sure the API key for `provider` is available in the environment.
|
||||||
|
|
||||||
|
|||||||
156
tests/test_ollama_base_url.py
Normal file
156
tests/test_ollama_base_url.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for OLLAMA_BASE_URL env-var override across CLI and client paths."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---- openai_client side: _resolve_provider_base_url -----------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_client():
|
||||||
|
import tradingagents.llm_clients.openai_client as mod
|
||||||
|
return importlib.reload(mod)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_returns_default_when_env_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_returns_env_when_set(monkeypatch):
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-ollama:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://remote-ollama:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_evaluation_is_call_time(monkeypatch):
|
||||||
|
"""Setting the env AFTER module import must still take effect."""
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
mod = _reload_client()
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://late-set:11434/v1")
|
||||||
|
assert mod._resolve_provider_base_url("ollama") == "http://late-set:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_does_not_affect_other_providers(monkeypatch):
|
||||||
|
"""OLLAMA_BASE_URL should NOT leak into xai/deepseek/etc."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://elsewhere/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
assert mod._resolve_provider_base_url("xai") == "https://api.x.ai/v1"
|
||||||
|
assert mod._resolve_provider_base_url("deepseek") == "https://api.deepseek.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_get_llm_picks_up_env(monkeypatch):
|
||||||
|
"""End-to-end: OllamaClient.get_llm() respects OLLAMA_BASE_URL."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://my-ollama:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
client = mod.OpenAIClient(model="llama3.1", provider="ollama")
|
||||||
|
llm = client.get_llm()
|
||||||
|
assert "my-ollama" in str(llm.openai_api_base)
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_base_url_overrides_env(monkeypatch):
|
||||||
|
"""An explicit base_url passed to the client wins over the env var."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://env-set:11434/v1")
|
||||||
|
mod = _reload_client()
|
||||||
|
client = mod.OpenAIClient(
|
||||||
|
model="llama3.1",
|
||||||
|
provider="ollama",
|
||||||
|
base_url="http://explicit:11434/v1",
|
||||||
|
)
|
||||||
|
llm = client.get_llm()
|
||||||
|
assert "explicit" in str(llm.openai_api_base)
|
||||||
|
assert "env-set" not in str(llm.openai_api_base)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- cli.utils side: select_llm_provider dropdown -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_dropdown_uses_env(monkeypatch):
|
||||||
|
"""The Ollama entry in the CLI dropdown must reflect OLLAMA_BASE_URL."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://cli-remote:11434/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
# Reach inside the function via the same env-read it does at call time
|
||||||
|
ollama_url = (
|
||||||
|
__import__("os").environ.get("OLLAMA_BASE_URL")
|
||||||
|
or "http://localhost:11434/v1"
|
||||||
|
)
|
||||||
|
assert ollama_url == "http://cli-remote:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_dropdown_default_when_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
ollama_url = (
|
||||||
|
__import__("os").environ.get("OLLAMA_BASE_URL")
|
||||||
|
or "http://localhost:11434/v1"
|
||||||
|
)
|
||||||
|
assert ollama_url == "http://localhost:11434/v1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- confirm_ollama_endpoint UX -------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_shows_default(monkeypatch, capsys):
|
||||||
|
monkeypatch.delenv("OLLAMA_BASE_URL", raising=False)
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://localhost:11434/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "http://localhost:11434/v1" in out
|
||||||
|
assert "OLLAMA_BASE_URL" not in out # not from env
|
||||||
|
assert "Note" not in out # no warnings for the canonical default
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_marks_env_origin(monkeypatch, capsys):
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host:11434/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://remote-host:11434/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "http://remote-host:11434/v1" in out
|
||||||
|
assert "OLLAMA_BASE_URL" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_warns_on_missing_scheme(monkeypatch, capsys):
|
||||||
|
"""If user sets OLLAMA_BASE_URL=0.0.0.128, advise on the expected shape."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "0.0.0.128")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("0.0.0.128")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "missing a scheme" in out
|
||||||
|
assert "http://<host>:11434/v1" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_warns_on_non_default_port_remote(monkeypatch, capsys):
|
||||||
|
"""A remote host with no :11434 gets a soft hint about port mismatch."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://remote-host/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://remote-host/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "port 11434" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_endpoint_quiet_on_local_no_port(monkeypatch, capsys):
|
||||||
|
"""Local host without port shouldn't trigger the remote-port hint."""
|
||||||
|
monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost/v1")
|
||||||
|
import cli.utils as cli_utils
|
||||||
|
importlib.reload(cli_utils)
|
||||||
|
cli_utils.confirm_ollama_endpoint("http://localhost/v1")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Note" not in out # localhost is fine without explicit port
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_model_labels_no_local_suffix():
|
||||||
|
"""Labels should no longer claim '(local)' since the endpoint is dynamic."""
|
||||||
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
for mode in ("quick", "deep"):
|
||||||
|
labels = [label for label, _ in get_model_options("ollama", mode)]
|
||||||
|
assert all("local" not in label for label in labels), labels
|
||||||
@@ -154,16 +154,21 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||||||
"minimax": _MINIMAX_MODELS,
|
"minimax": _MINIMAX_MODELS,
|
||||||
"minimax-cn": _MINIMAX_MODELS,
|
"minimax-cn": _MINIMAX_MODELS,
|
||||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||||
|
# Ollama display labels intentionally omit a "local" marker — the
|
||||||
|
# endpoint is now configurable via OLLAMA_BASE_URL, so the same labels
|
||||||
|
# apply whether the user runs ollama-serve on localhost or against a
|
||||||
|
# remote host. The actual resolved endpoint is surfaced separately by
|
||||||
|
# cli.utils.confirm_ollama_endpoint() right after provider selection.
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"quick": [
|
"quick": [
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B)", "qwen3:latest"),
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
|
||||||
],
|
],
|
||||||
"deep": [
|
"deep": [
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B)", "glm-4.7-flash:latest"),
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B)", "gpt-oss:latest"),
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B)", "qwen3:latest"),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,6 +155,22 @@ _PROVIDER_BASE_URL = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_base_url(provider: str) -> Optional[str]:
|
||||||
|
"""Default base URL for ``provider``, with env-var overrides where defined.
|
||||||
|
|
||||||
|
Currently only Ollama supports an env-var override (``OLLAMA_BASE_URL``),
|
||||||
|
matching the convention in the broader Ollama tooling ecosystem so users
|
||||||
|
can point at a remote ollama-serve without editing code. The check is
|
||||||
|
call-time, not import-time, so tests that monkeypatch the env after
|
||||||
|
import behave correctly.
|
||||||
|
"""
|
||||||
|
if provider == "ollama":
|
||||||
|
env_url = os.environ.get("OLLAMA_BASE_URL")
|
||||||
|
if env_url:
|
||||||
|
return env_url
|
||||||
|
return _PROVIDER_BASE_URL.get(provider)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(BaseLLMClient):
|
class OpenAIClient(BaseLLMClient):
|
||||||
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
||||||
|
|
||||||
@@ -183,7 +199,7 @@ class OpenAIClient(BaseLLMClient):
|
|||||||
# client (e.g. a corporate proxy) takes precedence over the
|
# client (e.g. a corporate proxy) takes precedence over the
|
||||||
# provider default so users can route through their own gateway.
|
# provider default so users can route through their own gateway.
|
||||||
if self.provider in _PROVIDER_BASE_URL:
|
if self.provider in _PROVIDER_BASE_URL:
|
||||||
llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider]
|
llm_kwargs["base_url"] = self.base_url or _resolve_provider_base_url(self.provider)
|
||||||
api_key_env = get_api_key_env(self.provider)
|
api_key_env = get_api_key_env(self.provider)
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
api_key = os.environ.get(api_key_env)
|
api_key = os.environ.get(api_key_env)
|
||||||
|
|||||||
Reference in New Issue
Block a user