mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(cli): detect missing provider API keys and persist to .env
Adds a canonical PROVIDER_API_KEY_ENV mapping (14 providers including the three dual-region pairs) and an ensure_api_key() helper. When the selected provider's key is absent from the environment, the CLI prompts via questionary.password, writes the value to .env via python-dotenv's set_key (preserves existing lines), and exports it into os.environ so the run continues without restart. Wired into cli/main.py right after the region prompts so qwen-cn, glm-cn, and minimax-cn each check their own region-specific key. openai_client refactored to consult the same mapping, eliminating its private duplicate of provider→env-var data.
This commit is contained in:
49
cli/utils.py
49
cli/utils.py
@@ -1,9 +1,13 @@
|
|||||||
import questionary
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict
|
from typing import List, Optional, Tuple, Dict
|
||||||
|
|
||||||
|
import questionary
|
||||||
|
from dotenv import find_dotenv, set_key
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
@@ -409,6 +413,49 @@ def ask_minimax_region() -> tuple[str, str]:
|
|||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_api_key(provider: str) -> Optional[str]:
|
||||||
|
"""Make sure the API key for `provider` is available in the environment.
|
||||||
|
|
||||||
|
If the env var is already set, returns its value untouched. Otherwise
|
||||||
|
interactively prompts the user, persists the value to the project's
|
||||||
|
.env file via python-dotenv's set_key (creating .env if needed), and
|
||||||
|
exports it into os.environ so the current process picks it up.
|
||||||
|
|
||||||
|
Returns None for providers that do not require a key (e.g. ollama)
|
||||||
|
and for providers not found in the canonical mapping.
|
||||||
|
"""
|
||||||
|
env_var = get_api_key_env(provider)
|
||||||
|
if env_var is None:
|
||||||
|
return None # ollama / unknown — no key check possible
|
||||||
|
|
||||||
|
existing = os.environ.get(env_var)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"\n[yellow]{env_var} is not set in your environment.[/yellow]"
|
||||||
|
)
|
||||||
|
key = questionary.password(
|
||||||
|
f"Paste your {env_var} (will be saved to .env):",
|
||||||
|
style=questionary.Style([
|
||||||
|
("text", "fg:cyan"),
|
||||||
|
("highlighted", "noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
if not key:
|
||||||
|
console.print(
|
||||||
|
f"[red]Skipped. API calls will fail until {env_var} is set.[/red]"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env")
|
||||||
|
Path(env_path).touch(exist_ok=True)
|
||||||
|
set_key(env_path, env_var, key)
|
||||||
|
os.environ[env_var] = key
|
||||||
|
console.print(f"[green]Saved {env_var} to {env_path}[/green]")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
def ask_output_language() -> str:
|
def ask_output_language() -> str:
|
||||||
"""Ask for report output language."""
|
"""Ask for report output language."""
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|||||||
149
tests/test_api_key_env.py
Normal file
149
tests/test_api_key_env.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Mapping coverage -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_every_select_llm_provider_choice_has_an_entry():
|
||||||
|
"""select_llm_provider() must not present a provider the mapping doesn't know about."""
|
||||||
|
# Mirrors the dropdown order in cli/utils.select_llm_provider so the two
|
||||||
|
# stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn)
|
||||||
|
# are reached via the secondary region prompt, so they must also be present.
|
||||||
|
expected = {
|
||||||
|
"openai", "google", "anthropic", "xai", "deepseek",
|
||||||
|
"qwen", "qwen-cn",
|
||||||
|
"glm", "glm-cn",
|
||||||
|
"minimax", "minimax-cn",
|
||||||
|
"openrouter", "azure", "ollama",
|
||||||
|
}
|
||||||
|
assert expected.issubset(PROVIDER_API_KEY_ENV.keys())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider,env_var",
|
||||||
|
[
|
||||||
|
("openai", "OPENAI_API_KEY"),
|
||||||
|
("anthropic", "ANTHROPIC_API_KEY"),
|
||||||
|
("google", "GOOGLE_API_KEY"),
|
||||||
|
("azure", "AZURE_OPENAI_API_KEY"),
|
||||||
|
("xai", "XAI_API_KEY"),
|
||||||
|
("deepseek", "DEEPSEEK_API_KEY"),
|
||||||
|
("qwen", "DASHSCOPE_API_KEY"),
|
||||||
|
("qwen-cn", "DASHSCOPE_CN_API_KEY"),
|
||||||
|
("glm", "ZHIPU_API_KEY"),
|
||||||
|
("glm-cn", "ZHIPU_CN_API_KEY"),
|
||||||
|
("minimax", "MINIMAX_API_KEY"),
|
||||||
|
("minimax-cn", "MINIMAX_CN_API_KEY"),
|
||||||
|
("openrouter", "OPENROUTER_API_KEY"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_known_providers_resolve(provider, env_var):
|
||||||
|
assert get_api_key_env(provider) == env_var
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_has_no_key():
|
||||||
|
assert get_api_key_env("ollama") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_provider_returns_none():
|
||||||
|
assert get_api_key_env("not-a-real-provider") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_insensitive_lookup():
|
||||||
|
assert get_api_key_env("OpenAI") == "OPENAI_API_KEY"
|
||||||
|
assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- ensure_api_key behavior ---------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_utils(monkeypatch):
|
||||||
|
"""Import cli.utils with a fresh environment so module-level state is consistent."""
|
||||||
|
import importlib
|
||||||
|
import cli.utils as cli_utils_module
|
||||||
|
return importlib.reload(cli_utils_module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_returns_existing(monkeypatch, cli_utils):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set")
|
||||||
|
result = cli_utils.ensure_api_key("openai")
|
||||||
|
assert result == "sk-already-set"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils):
|
||||||
|
# Even with no env var set, ollama should not prompt and should return None.
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
with patch.object(cli_utils, "questionary") as mock_q:
|
||||||
|
result = cli_utils.ensure_api_key("ollama")
|
||||||
|
assert result is None
|
||||||
|
mock_q.password.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils):
|
||||||
|
with patch.object(cli_utils, "questionary") as mock_q:
|
||||||
|
result = cli_utils.ensure_api_key("totally-fake-provider")
|
||||||
|
assert result is None
|
||||||
|
mock_q.password.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""When key is missing, user-pasted value must be written to .env AND os.environ."""
|
||||||
|
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
result = cli_utils.ensure_api_key("deepseek")
|
||||||
|
|
||||||
|
assert result == "sk-deepseek-test"
|
||||||
|
assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test"
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
assert env_file.exists()
|
||||||
|
assert "DEEPSEEK_API_KEY" in env_file.read_text()
|
||||||
|
assert "sk-deepseek-test" in env_file.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""Empty prompt response (user cancelled) must not write to .env."""
|
||||||
|
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
result = cli_utils.ensure_api_key("xai")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert "XAI_API_KEY" not in os.environ
|
||||||
|
# .env may or may not exist depending on find_dotenv's walk, but if it
|
||||||
|
# does it must not contain the key.
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
if env_file.exists():
|
||||||
|
assert "XAI_API_KEY" not in env_file.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils):
|
||||||
|
"""An existing .env with other keys must be preserved on writeback."""
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n")
|
||||||
|
|
||||||
|
fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})()
|
||||||
|
with patch.object(cli_utils.questionary, "password", return_value=fake_prompt):
|
||||||
|
cli_utils.ensure_api_key("openrouter")
|
||||||
|
|
||||||
|
content = env_file.read_text()
|
||||||
|
assert "OPENAI_API_KEY" in content and "sk-existing" in content
|
||||||
|
assert "OTHER=value" in content
|
||||||
|
assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content
|
||||||
44
tradingagents/llm_clients/api_key_env.py
Normal file
44
tradingagents/llm_clients/api_key_env.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Canonical provider -> API-key env-var mapping.
|
||||||
|
|
||||||
|
A single source of truth for which environment variable holds the API
|
||||||
|
key for each supported LLM provider. Used by the CLI's interactive key
|
||||||
|
prompt (cli/utils.ensure_api_key) and by anything else that needs to
|
||||||
|
ask "does this provider require a key, and which env var is it?".
|
||||||
|
|
||||||
|
When adding a new provider, register its env var here so the CLI flow
|
||||||
|
prompts for it automatically instead of failing on first API call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
|
"google": "GOOGLE_API_KEY",
|
||||||
|
"azure": "AZURE_OPENAI_API_KEY",
|
||||||
|
"xai": "XAI_API_KEY",
|
||||||
|
"deepseek": "DEEPSEEK_API_KEY",
|
||||||
|
# Dual-region providers each carry their own account; keys are not
|
||||||
|
# interchangeable between the international and China endpoints.
|
||||||
|
"qwen": "DASHSCOPE_API_KEY",
|
||||||
|
"qwen-cn": "DASHSCOPE_CN_API_KEY",
|
||||||
|
"glm": "ZHIPU_API_KEY",
|
||||||
|
"glm-cn": "ZHIPU_CN_API_KEY",
|
||||||
|
"minimax": "MINIMAX_API_KEY",
|
||||||
|
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||||
|
"openrouter": "OPENROUTER_API_KEY",
|
||||||
|
# Local runtimes do not authenticate.
|
||||||
|
"ollama": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_key_env(provider: str) -> Optional[str]:
|
||||||
|
"""Return the env var name for `provider`'s API key, or None if not applicable.
|
||||||
|
|
||||||
|
Unknown providers also return None — callers should treat that as
|
||||||
|
"no key check possible" rather than as "no key required".
|
||||||
|
"""
|
||||||
|
return PROVIDER_API_KEY_ENV.get(provider.lower())
|
||||||
@@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from .api_key_env import get_api_key_env
|
||||||
from .base_client import BaseLLMClient, normalize_content
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
from .capabilities import get_capabilities
|
from .capabilities import get_capabilities
|
||||||
from .validators import validate_model
|
from .validators import validate_model
|
||||||
@@ -135,26 +136,22 @@ _PASSTHROUGH_KWARGS = (
|
|||||||
"api_key", "callbacks", "http_client", "http_async_client",
|
"api_key", "callbacks", "http_client", "http_async_client",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Provider base URLs and API key env vars
|
# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV
|
||||||
_PROVIDER_CONFIG = {
|
# (one canonical mapping consulted by both this client and the CLI's
|
||||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep
|
||||||
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
|
# separate endpoints because international and China accounts cannot share
|
||||||
# DashScope exposes two regional endpoints with separate accounts; an
|
# credentials (#758).
|
||||||
# international key won't authenticate against the China endpoint and
|
_PROVIDER_BASE_URL = {
|
||||||
# vice versa (fixes issue #758).
|
"xai": "https://api.x.ai/v1",
|
||||||
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
|
"deepseek": "https://api.deepseek.com",
|
||||||
"qwen-cn": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_CN_API_KEY"),
|
"qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||||
# Zhipu exposes the same GLM models under two brands with separate
|
"qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
# accounts: Z.AI (international, api.z.ai) and BigModel
|
"glm": "https://api.z.ai/api/paas/v4/",
|
||||||
# (open.bigmodel.cn, China). Keys aren't interchangeable across them.
|
"glm-cn": "https://open.bigmodel.cn/api/paas/v4/",
|
||||||
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
|
"minimax": "https://api.minimax.io/v1",
|
||||||
"glm-cn": ("https://open.bigmodel.cn/api/paas/v4/", "ZHIPU_CN_API_KEY"),
|
"minimax-cn": "https://api.minimaxi.com/v1",
|
||||||
# MiniMax exposes two regional endpoints with separate keys; mainland
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
# Chinese users hit .com while global users hit .io.
|
"ollama": "http://localhost:11434/v1",
|
||||||
"minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"),
|
|
||||||
"minimax-cn": ("https://api.minimaxi.com/v1", "MINIMAX_CN_API_KEY"),
|
|
||||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
|
||||||
"ollama": ("http://localhost:11434/v1", None),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -185,9 +182,9 @@ class OpenAIClient(BaseLLMClient):
|
|||||||
# Provider-specific base URL and auth. An explicit base_url on the
|
# Provider-specific base URL and auth. An explicit base_url on the
|
||||||
# client (e.g. a corporate proxy) takes precedence over the
|
# client (e.g. a corporate proxy) takes precedence over the
|
||||||
# provider default so users can route through their own gateway.
|
# provider default so users can route through their own gateway.
|
||||||
if self.provider in _PROVIDER_CONFIG:
|
if self.provider in _PROVIDER_BASE_URL:
|
||||||
default_base, api_key_env = _PROVIDER_CONFIG[self.provider]
|
llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider]
|
||||||
llm_kwargs["base_url"] = self.base_url or default_base
|
api_key_env = get_api_key_env(self.provider)
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
api_key = os.environ.get(api_key_env)
|
api_key = os.environ.get(api_key_env)
|
||||||
if api_key:
|
if api_key:
|
||||||
|
|||||||
Reference in New Issue
Block a user