From 895ed130f9df5a63889610c9e77af39567bd0a54 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Sun, 14 Jun 2026 04:24:54 +0000 Subject: [PATCH] feat(llm): add Amazon Bedrock as a first-class provider Bedrock uses the Converse API (langchain-aws) and the AWS credential chain, so it has its own client like Anthropic/Google rather than the OpenAI-compatible registry. langchain-aws is an optional dependency (pip install ".[bedrock]"), lazy-imported with a clear install hint; importing the package never requires it. The model name is a Bedrock model ID / inference profile ID. --- .env.example | 6 ++ README.md | 4 +- cli/utils.py | 1 + pyproject.toml | 5 ++ tests/test_bedrock_provider.py | 46 ++++++++++++++ tradingagents/llm_clients/api_key_env.py | 2 + tradingagents/llm_clients/bedrock_client.py | 69 +++++++++++++++++++++ tradingagents/llm_clients/factory.py | 4 ++ tradingagents/llm_clients/model_catalog.py | 2 + tradingagents/llm_clients/validators.py | 2 +- 10 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 tests/test_bedrock_provider.py create mode 100644 tradingagents/llm_clients/bedrock_client.py diff --git a/.env.example b/.env.example index aad73c9ad..1aab52424 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,12 @@ NVIDIA_API_KEY= # optional (local servers need none). #OPENAI_COMPATIBLE_API_KEY= +# AWS Bedrock (provider "bedrock", install with: pip install ".[bedrock]"). +# Auth uses the standard AWS credential chain; set the region (and optionally a +# named profile). No single API key. +#AWS_DEFAULT_REGION=us-west-2 +#AWS_PROFILE= + # Optional: point at a remote Ollama server. When unset, defaults to # the local instance at http://localhost:11434/v1. Convention follows # the broader Ollama ecosystem; both the CLI dropdown and programmatic diff --git a/README.md b/README.md index 20147d101..a047ae6ad 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,9 @@ export OPENROUTER_API_KEY=... # OpenRouter export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage ``` -For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials. +For Azure OpenAI, copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials. + +For AWS Bedrock, install the extra with `pip install ".[bedrock]"`, set `llm_provider: "bedrock"`, configure AWS credentials (environment variables, `~/.aws/credentials`, or an IAM role) and `AWS_DEFAULT_REGION`, and use a Bedrock model ID, e.g. `us.anthropic.claude-opus-4-8-v1:0`. For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull `, and pick "Custom model ID" in the CLI for any model not listed by default. diff --git a/cli/utils.py b/cli/utils.py index 3ee2d25c7..4583b6107 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -316,6 +316,7 @@ def _llm_provider_table() -> list[tuple[str, str, str | None]]: ("Groq", "groq", "https://api.groq.com/openai/v1"), ("NVIDIA NIM", "nvidia", "https://integrate.api.nvidia.com/v1"), ("Azure OpenAI", "azure", None), + ("Amazon Bedrock", "bedrock", None), ("Ollama", "ollama", ollama_url), ("OpenAI-compatible (vLLM, LM Studio, llama.cpp, custom relay)", "openai_compatible", None), ] diff --git a/pyproject.toml b/pyproject.toml index 40399819f..35c0c1d1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,11 @@ dev = [ "pytest>=8.0", "pytest-subtests>=0.13", ] +# Amazon Bedrock support (AWS SigV4 auth + boto3). Optional so the core install +# stays lean: pip install "tradingagents[bedrock]". +bedrock = [ + "langchain-aws>=1.5.0", +] [project.scripts] tradingagents = "cli.main:app" diff --git a/tests/test_bedrock_provider.py b/tests/test_bedrock_provider.py new file mode 100644 index 000000000..c868e9806 --- /dev/null +++ b/tests/test_bedrock_provider.py @@ -0,0 +1,46 @@ +"""Amazon Bedrock — first-class native client via the optional langchain-aws extra. + +Auth uses the AWS credential chain (no single key env); the model is a Bedrock +model ID / inference profile ID; langchain-aws is imported lazily with a clear +install hint when the [bedrock] extra is absent. +""" +import sys + +import pytest + +from tradingagents.llm_clients.api_key_env import get_api_key_env +from tradingagents.llm_clients.factory import create_llm_client +from tradingagents.llm_clients.validators import validate_model + + +@pytest.mark.unit +def test_factory_routes_bedrock(): + client = create_llm_client("bedrock", "us.anthropic.claude-opus-4-8-v1:0") + assert type(client).__name__ == "BedrockClient" + + +@pytest.mark.unit +def test_bedrock_any_model_and_no_key_env(): + assert validate_model("bedrock", "any.model-id:0") is True + # Bedrock uses the AWS credential chain, so there is no single key env. + assert get_api_key_env("bedrock") is None + + +@pytest.mark.unit +def test_helpful_error_when_langchain_aws_absent(monkeypatch): + import tradingagents.llm_clients.bedrock_client as bc + monkeypatch.setattr(bc, "_BEDROCK_CLASS", None) + monkeypatch.setitem(sys.modules, "langchain_aws", None) # force ImportError on import + with pytest.raises(ImportError, match=r"bedrock"): + create_llm_client("bedrock", "m").get_llm() + + +@pytest.mark.unit +def test_construction_when_extra_installed(monkeypatch): + pytest.importorskip("langchain_aws") + import tradingagents.llm_clients.bedrock_client as bc + monkeypatch.setattr(bc, "_BEDROCK_CLASS", None) + monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1") + llm = create_llm_client("bedrock", "us.anthropic.claude-sonnet-4-6-v1:0").get_llm() + assert type(llm).__name__ == "NormalizedChatBedrockConverse" + assert llm.region_name == "eu-west-1" diff --git a/tradingagents/llm_clients/api_key_env.py b/tradingagents/llm_clients/api_key_env.py index 4d909c400..8521e1602 100644 --- a/tradingagents/llm_clients/api_key_env.py +++ b/tradingagents/llm_clients/api_key_env.py @@ -19,6 +19,8 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = { "anthropic": "ANTHROPIC_API_KEY", "google": "GOOGLE_API_KEY", "azure": "AZURE_OPENAI_API_KEY", + # Bedrock authenticates via the AWS credential chain, not a single key env. + "bedrock": None, "xai": "XAI_API_KEY", "deepseek": "DEEPSEEK_API_KEY", # Dual-region providers each carry their own account; keys are not diff --git a/tradingagents/llm_clients/bedrock_client.py b/tradingagents/llm_clients/bedrock_client.py new file mode 100644 index 000000000..2ddb48dd3 --- /dev/null +++ b/tradingagents/llm_clients/bedrock_client.py @@ -0,0 +1,69 @@ +import os +from typing import Any + +from .base_client import BaseLLMClient, normalize_content +from .validators import validate_model + +# Bedrock has no global default region; us-west-2 hosts the broadest model set. +_DEFAULT_REGION = "us-west-2" +_BEDROCK_CLASS = None + + +def _bedrock_class(): + """Lazily import langchain-aws (the optional ``[bedrock]`` extra) and return a + ChatBedrockConverse subclass with normalized content output. + + Imported on demand so the optional dependency (and boto3) isn't required by + the rest of the package; cached after the first call. + """ + global _BEDROCK_CLASS + if _BEDROCK_CLASS is not None: + return _BEDROCK_CLASS + + try: + from langchain_aws import ChatBedrockConverse + except ImportError as exc: + raise ImportError( + "AWS Bedrock support requires the optional 'langchain-aws' dependency. " + 'Install it with: pip install "tradingagents[bedrock]"' + ) from exc + + class NormalizedChatBedrockConverse(ChatBedrockConverse): + """ChatBedrockConverse with normalized (string) content output.""" + + def invoke(self, input, config=None, **kwargs): + return normalize_content(super().invoke(input, config, **kwargs)) + + _BEDROCK_CLASS = NormalizedChatBedrockConverse + return _BEDROCK_CLASS + + +class BedrockClient(BaseLLMClient): + """Client for Amazon Bedrock via the Converse API (langchain-aws). + + Authentication uses the standard AWS credential chain (env vars, + ``~/.aws/credentials``, or an IAM role); set ``AWS_REGION`` / + ``AWS_DEFAULT_REGION`` and optionally ``AWS_PROFILE``. The model name is a + Bedrock model ID or cross-region inference profile ID, e.g. + ``us.anthropic.claude-opus-4-8-v1:0``. + """ + + def get_llm(self) -> Any: + """Return a configured ChatBedrockConverse instance.""" + self.warn_if_unknown_model() + chat_cls = _bedrock_class() + + region = ( + os.environ.get("AWS_REGION") + or os.environ.get("AWS_DEFAULT_REGION") + or _DEFAULT_REGION + ) + llm_kwargs = {"model": self.model, "region_name": region} + for key in ("temperature", "max_tokens", "max_retries", "callbacks"): + if key in self.kwargs: + llm_kwargs[key] = self.kwargs[key] + return chat_cls(**llm_kwargs) + + def validate_model(self) -> bool: + """Validate model for Bedrock (any model ID accepted).""" + return validate_model("bedrock", self.model) diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index d61d48514..d0dbdd63c 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -43,6 +43,10 @@ def create_llm_client( from .azure_client import AzureOpenAIClient return AzureOpenAIClient(model, base_url, **kwargs) + if provider_lower == "bedrock": + from .bedrock_client import BedrockClient + return BedrockClient(model, base_url, **kwargs) + from .openai_client import OpenAIClient, is_openai_compatible if is_openai_compatible(provider_lower): return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index f1fc45b09..418c898a2 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -193,6 +193,8 @@ MODEL_OPTIONS: ProviderModeOptions = { "kimi": _CUSTOM_ONLY, "groq": _CUSTOM_ONLY, "nvidia": _CUSTOM_ONLY, + # Bedrock model IDs / cross-region inference profile IDs are user-specified. + "bedrock": _CUSTOM_ONLY, } diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index bed09a51a..214cf1ee0 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -8,7 +8,7 @@ from .model_catalog import get_known_models # accepted without warning. _ANY_MODEL_PROVIDERS = ( "ollama", "openrouter", "openai_compatible", - "mistral", "kimi", "groq", "nvidia", + "mistral", "kimi", "groq", "nvidia", "bedrock", ) VALID_MODELS = {