From 9f7abfcbd576686685210f2dc6b8ec52c5d744ba Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 11 May 2026 06:12:34 +0000 Subject: [PATCH] feat(cli): detect missing provider API keys and persist to .env MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a canonical PROVIDER_API_KEY_ENV mapping (14 providers including the three dual-region pairs) and an ensure_api_key() helper. When the selected provider's key is absent from the environment, the CLI prompts via questionary.password, writes the value to .env via python-dotenv's set_key (preserves existing lines), and exports it into os.environ so the run continues without restart. Wired into cli/main.py right after the region prompts so qwen-cn, glm-cn, and minimax-cn each check their own region-specific key. openai_client refactored to consult the same mapping, eliminating its private duplicate of provider→env-var data. --- cli/utils.py | 49 ++++++- tests/test_api_key_env.py | 149 +++++++++++++++++++++ tradingagents/llm_clients/api_key_env.py | 44 ++++++ tradingagents/llm_clients/openai_client.py | 43 +++--- 4 files changed, 261 insertions(+), 24 deletions(-) create mode 100644 tests/test_api_key_env.py create mode 100644 tradingagents/llm_clients/api_key_env.py diff --git a/cli/utils.py b/cli/utils.py index 1ccc12302..5fd0b806c 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,9 +1,13 @@ -import questionary +import os +from pathlib import Path from typing import List, Optional, Tuple, Dict +import questionary +from dotenv import find_dotenv, set_key from rich.console import Console from cli.models import AnalystType +from tradingagents.llm_clients.api_key_env import get_api_key_env from tradingagents.llm_clients.model_catalog import get_model_options console = Console() @@ -409,6 +413,49 @@ def ask_minimax_region() -> tuple[str, str]: ).ask() +def ensure_api_key(provider: str) -> Optional[str]: + """Make sure the API key for `provider` is available in the environment. + + If the env var is already set, returns its value untouched. Otherwise + interactively prompts the user, persists the value to the project's + .env file via python-dotenv's set_key (creating .env if needed), and + exports it into os.environ so the current process picks it up. + + Returns None for providers that do not require a key (e.g. ollama) + and for providers not found in the canonical mapping. + """ + env_var = get_api_key_env(provider) + if env_var is None: + return None # ollama / unknown — no key check possible + + existing = os.environ.get(env_var) + if existing: + return existing + + console.print( + f"\n[yellow]{env_var} is not set in your environment.[/yellow]" + ) + key = questionary.password( + f"Paste your {env_var} (will be saved to .env):", + style=questionary.Style([ + ("text", "fg:cyan"), + ("highlighted", "noinherit"), + ]), + ).ask() + if not key: + console.print( + f"[red]Skipped. API calls will fail until {env_var} is set.[/red]" + ) + return None + + env_path = find_dotenv(usecwd=True) or str(Path.cwd() / ".env") + Path(env_path).touch(exist_ok=True) + set_key(env_path, env_var, key) + os.environ[env_var] = key + console.print(f"[green]Saved {env_var} to {env_path}[/green]") + return key + + def ask_output_language() -> str: """Ask for report output language.""" choice = questionary.select( diff --git a/tests/test_api_key_env.py b/tests/test_api_key_env.py new file mode 100644 index 000000000..dde5a4886 --- /dev/null +++ b/tests/test_api_key_env.py @@ -0,0 +1,149 @@ +"""Tests for the canonical provider->env-var mapping and the CLI key-prompt helper.""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tradingagents.llm_clients.api_key_env import PROVIDER_API_KEY_ENV, get_api_key_env + + +# ---- Mapping coverage ----------------------------------------------------- + + +def test_every_select_llm_provider_choice_has_an_entry(): + """select_llm_provider() must not present a provider the mapping doesn't know about.""" + # Mirrors the dropdown order in cli/utils.select_llm_provider so the two + # stay in lockstep. Region-specific keys (qwen-cn / minimax-cn / glm-cn) + # are reached via the secondary region prompt, so they must also be present. + expected = { + "openai", "google", "anthropic", "xai", "deepseek", + "qwen", "qwen-cn", + "glm", "glm-cn", + "minimax", "minimax-cn", + "openrouter", "azure", "ollama", + } + assert expected.issubset(PROVIDER_API_KEY_ENV.keys()) + + +@pytest.mark.parametrize( + "provider,env_var", + [ + ("openai", "OPENAI_API_KEY"), + ("anthropic", "ANTHROPIC_API_KEY"), + ("google", "GOOGLE_API_KEY"), + ("azure", "AZURE_OPENAI_API_KEY"), + ("xai", "XAI_API_KEY"), + ("deepseek", "DEEPSEEK_API_KEY"), + ("qwen", "DASHSCOPE_API_KEY"), + ("qwen-cn", "DASHSCOPE_CN_API_KEY"), + ("glm", "ZHIPU_API_KEY"), + ("glm-cn", "ZHIPU_CN_API_KEY"), + ("minimax", "MINIMAX_API_KEY"), + ("minimax-cn", "MINIMAX_CN_API_KEY"), + ("openrouter", "OPENROUTER_API_KEY"), + ], +) +def test_known_providers_resolve(provider, env_var): + assert get_api_key_env(provider) == env_var + + +def test_ollama_has_no_key(): + assert get_api_key_env("ollama") is None + + +def test_unknown_provider_returns_none(): + assert get_api_key_env("not-a-real-provider") is None + + +def test_case_insensitive_lookup(): + assert get_api_key_env("OpenAI") == "OPENAI_API_KEY" + assert get_api_key_env("QWEN-CN") == "DASHSCOPE_CN_API_KEY" + + +# ---- ensure_api_key behavior --------------------------------------------- + + +@pytest.fixture +def cli_utils(monkeypatch): + """Import cli.utils with a fresh environment so module-level state is consistent.""" + import importlib + import cli.utils as cli_utils_module + return importlib.reload(cli_utils_module) + + +def test_ensure_api_key_returns_existing(monkeypatch, cli_utils): + monkeypatch.setenv("OPENAI_API_KEY", "sk-already-set") + result = cli_utils.ensure_api_key("openai") + assert result == "sk-already-set" + + +def test_ensure_api_key_no_op_for_ollama(monkeypatch, cli_utils): + # Even with no env var set, ollama should not prompt and should return None. + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with patch.object(cli_utils, "questionary") as mock_q: + result = cli_utils.ensure_api_key("ollama") + assert result is None + mock_q.password.assert_not_called() + + +def test_ensure_api_key_unknown_provider_no_prompt(monkeypatch, cli_utils): + with patch.object(cli_utils, "questionary") as mock_q: + result = cli_utils.ensure_api_key("totally-fake-provider") + assert result is None + mock_q.password.assert_not_called() + + +def test_ensure_api_key_prompts_and_writes_to_env(monkeypatch, tmp_path, cli_utils): + """When key is missing, user-pasted value must be written to .env AND os.environ.""" + monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-deepseek-test")})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + result = cli_utils.ensure_api_key("deepseek") + + assert result == "sk-deepseek-test" + assert os.environ["DEEPSEEK_API_KEY"] == "sk-deepseek-test" + env_file = tmp_path / ".env" + assert env_file.exists() + assert "DEEPSEEK_API_KEY" in env_file.read_text() + assert "sk-deepseek-test" in env_file.read_text() + + +def test_ensure_api_key_user_cancels_returns_none(monkeypatch, tmp_path, cli_utils): + """Empty prompt response (user cancelled) must not write to .env.""" + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: None)})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + result = cli_utils.ensure_api_key("xai") + + assert result is None + assert "XAI_API_KEY" not in os.environ + # .env may or may not exist depending on find_dotenv's walk, but if it + # does it must not contain the key. + env_file = tmp_path / ".env" + if env_file.exists(): + assert "XAI_API_KEY" not in env_file.read_text() + + +def test_ensure_api_key_updates_existing_env_file(monkeypatch, tmp_path, cli_utils): + """An existing .env with other keys must be preserved on writeback.""" + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.chdir(tmp_path) + env_file = tmp_path / ".env" + env_file.write_text("OPENAI_API_KEY=sk-existing\nOTHER=value\n") + + fake_prompt = type("P", (), {"ask": staticmethod(lambda: "sk-openrouter-new")})() + with patch.object(cli_utils.questionary, "password", return_value=fake_prompt): + cli_utils.ensure_api_key("openrouter") + + content = env_file.read_text() + assert "OPENAI_API_KEY" in content and "sk-existing" in content + assert "OTHER=value" in content + assert "OPENROUTER_API_KEY" in content and "sk-openrouter-new" in content diff --git a/tradingagents/llm_clients/api_key_env.py b/tradingagents/llm_clients/api_key_env.py new file mode 100644 index 000000000..ff03d441a --- /dev/null +++ b/tradingagents/llm_clients/api_key_env.py @@ -0,0 +1,44 @@ +"""Canonical provider -> API-key env-var mapping. + +A single source of truth for which environment variable holds the API +key for each supported LLM provider. Used by the CLI's interactive key +prompt (cli/utils.ensure_api_key) and by anything else that needs to +ask "does this provider require a key, and which env var is it?". + +When adding a new provider, register its env var here so the CLI flow +prompts for it automatically instead of failing on first API call. +""" + +from __future__ import annotations + +from typing import Optional + + +PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "azure": "AZURE_OPENAI_API_KEY", + "xai": "XAI_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + # Dual-region providers each carry their own account; keys are not + # interchangeable between the international and China endpoints. + "qwen": "DASHSCOPE_API_KEY", + "qwen-cn": "DASHSCOPE_CN_API_KEY", + "glm": "ZHIPU_API_KEY", + "glm-cn": "ZHIPU_CN_API_KEY", + "minimax": "MINIMAX_API_KEY", + "minimax-cn": "MINIMAX_CN_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + # Local runtimes do not authenticate. + "ollama": None, +} + + +def get_api_key_env(provider: str) -> Optional[str]: + """Return the env var name for `provider`'s API key, or None if not applicable. + + Unknown providers also return None — callers should treat that as + "no key check possible" rather than as "no key required". + """ + return PROVIDER_API_KEY_ENV.get(provider.lower()) diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 89c67e31d..771b28127 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -4,6 +4,7 @@ from typing import Any, Optional from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI +from .api_key_env import get_api_key_env from .base_client import BaseLLMClient, normalize_content from .capabilities import get_capabilities from .validators import validate_model @@ -135,26 +136,22 @@ _PASSTHROUGH_KWARGS = ( "api_key", "callbacks", "http_client", "http_async_client", ) -# Provider base URLs and API key env vars -_PROVIDER_CONFIG = { - "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), - "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"), - # DashScope exposes two regional endpoints with separate accounts; an - # international key won't authenticate against the China endpoint and - # vice versa (fixes issue #758). - "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"), - "qwen-cn": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_CN_API_KEY"), - # Zhipu exposes the same GLM models under two brands with separate - # accounts: Z.AI (international, api.z.ai) and BigModel - # (open.bigmodel.cn, China). Keys aren't interchangeable across them. - "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"), - "glm-cn": ("https://open.bigmodel.cn/api/paas/v4/", "ZHIPU_CN_API_KEY"), - # MiniMax exposes two regional endpoints with separate keys; mainland - # Chinese users hit .com while global users hit .io. - "minimax": ("https://api.minimax.io/v1", "MINIMAX_API_KEY"), - "minimax-cn": ("https://api.minimaxi.com/v1", "MINIMAX_CN_API_KEY"), - "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), - "ollama": ("http://localhost:11434/v1", None), +# Provider base URLs. API-key env vars live in api_key_env.PROVIDER_API_KEY_ENV +# (one canonical mapping consulted by both this client and the CLI's +# interactive key-prompt). Dual-region providers (qwen/glm/minimax) keep +# separate endpoints because international and China accounts cannot share +# credentials (#758). +_PROVIDER_BASE_URL = { + "xai": "https://api.x.ai/v1", + "deepseek": "https://api.deepseek.com", + "qwen": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "qwen-cn": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "glm": "https://api.z.ai/api/paas/v4/", + "glm-cn": "https://open.bigmodel.cn/api/paas/v4/", + "minimax": "https://api.minimax.io/v1", + "minimax-cn": "https://api.minimaxi.com/v1", + "openrouter": "https://openrouter.ai/api/v1", + "ollama": "http://localhost:11434/v1", } @@ -185,9 +182,9 @@ class OpenAIClient(BaseLLMClient): # Provider-specific base URL and auth. An explicit base_url on the # client (e.g. a corporate proxy) takes precedence over the # provider default so users can route through their own gateway. - if self.provider in _PROVIDER_CONFIG: - default_base, api_key_env = _PROVIDER_CONFIG[self.provider] - llm_kwargs["base_url"] = self.base_url or default_base + if self.provider in _PROVIDER_BASE_URL: + llm_kwargs["base_url"] = self.base_url or _PROVIDER_BASE_URL[self.provider] + api_key_env = get_api_key_env(self.provider) if api_key_env: api_key = os.environ.get(api_key_env) if api_key: