mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-01 14:33:10 +03:00
test: lazy-load LLM provider clients and add API-key fixtures so the test suite runs cleanly without credentials (#588)
This commit is contained in:
@@ -40,3 +40,15 @@ include = ["tradingagents*", "cli*"]
|
|||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
cli = ["static/*"]
|
cli = ["static/*"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-ra --strict-markers"
|
||||||
|
markers = [
|
||||||
|
"unit: fast isolated unit tests",
|
||||||
|
"integration: tests requiring external services",
|
||||||
|
"smoke: quick sanity-check tests",
|
||||||
|
]
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
]
|
||||||
|
|||||||
42
tests/conftest.py
Normal file
42
tests/conftest.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Shared pytest fixtures that prevent CI hangs when API keys are absent."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
for marker in ("unit", "integration", "smoke"):
|
||||||
|
config.addinivalue_line("markers", f"{marker}: {marker}-level tests")
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_ENV_VARS = (
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"GOOGLE_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"XAI_API_KEY",
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"DASHSCOPE_API_KEY",
|
||||||
|
"ZHIPU_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"AZURE_OPENAI_API_KEY",
|
||||||
|
"ALPHA_VANTAGE_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _dummy_api_keys(monkeypatch):
|
||||||
|
for env_var in _API_KEY_ENV_VARS:
|
||||||
|
monkeypatch.setenv(env_var, os.environ.get(env_var, "placeholder"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_llm_client():
|
||||||
|
client = MagicMock()
|
||||||
|
client.get_llm.return_value = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"tradingagents.llm_clients.factory.create_llm_client",
|
||||||
|
return_value=client,
|
||||||
|
):
|
||||||
|
yield client
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.google_client import GoogleClient
|
from tradingagents.llm_clients.google_client import GoogleClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
||||||
"""Verify GoogleClient accepts unified api_key parameter."""
|
"""Verify GoogleClient accepts unified api_key parameter."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from tradingagents.llm_clients.base_client import BaseLLMClient
|
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||||
from tradingagents.llm_clients.model_catalog import get_known_models
|
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||||
from tradingagents.llm_clients.validators import validate_model
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
@@ -19,6 +21,7 @@ class DummyLLMClient(BaseLLMClient):
|
|||||||
return validate_model(self.provider, self.model)
|
return validate_model(self.provider, self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
class ModelValidationTests(unittest.TestCase):
|
class ModelValidationTests(unittest.TestCase):
|
||||||
def test_cli_catalog_models_are_all_validator_approved(self):
|
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||||
for provider, models in get_known_models().items():
|
for provider, models in get_known_models().items():
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from cli.utils import normalize_ticker_symbol
|
from cli.utils import normalize_ticker_symbol
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
class TickerSymbolHandlingTests(unittest.TestCase):
|
class TickerSymbolHandlingTests(unittest.TestCase):
|
||||||
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
def test_normalize_ticker_symbol_preserves_exchange_suffix(self):
|
||||||
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
self.assertEqual(normalize_ticker_symbol(" cnc.to "), "CNC.TO")
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
|
||||||
from .anthropic_client import AnthropicClient
|
|
||||||
from .google_client import GoogleClient
|
|
||||||
from .azure_client import AzureOpenAIClient
|
|
||||||
|
|
||||||
# Providers that use the OpenAI-compatible chat completions API
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
_OPENAI_COMPATIBLE = (
|
_OPENAI_COMPATIBLE = (
|
||||||
@@ -20,6 +16,10 @@ def create_llm_client(
|
|||||||
) -> BaseLLMClient:
|
) -> BaseLLMClient:
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
|
Provider modules are imported lazily so that simply importing this
|
||||||
|
factory (e.g. during test collection) does not pull in heavy LLM SDKs
|
||||||
|
or fail when their API keys are absent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider name
|
provider: LLM provider name
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
@@ -35,15 +35,19 @@ def create_llm_client(
|
|||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in _OPENAI_COMPATIBLE:
|
if provider_lower in _OPENAI_COMPATIBLE:
|
||||||
|
from .openai_client import OpenAIClient
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "anthropic":
|
if provider_lower == "anthropic":
|
||||||
|
from .anthropic_client import AnthropicClient
|
||||||
return AnthropicClient(model, base_url, **kwargs)
|
return AnthropicClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
|
from .google_client import GoogleClient
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "azure":
|
if provider_lower == "azure":
|
||||||
|
from .azure_client import AzureOpenAIClient
|
||||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|||||||
Reference in New Issue
Block a user