mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
feat(llm): add Amazon Bedrock as a first-class provider
Bedrock uses the Converse API (langchain-aws) and the AWS credential chain, so it has its own client like Anthropic/Google rather than the OpenAI-compatible registry. langchain-aws is an optional dependency (pip install ".[bedrock]"), lazy-imported with a clear install hint; importing the package never requires it. The model name is a Bedrock model ID / inference profile ID.
This commit is contained in:
@@ -21,6 +21,12 @@ NVIDIA_API_KEY=
|
|||||||
# optional (local servers need none).
|
# optional (local servers need none).
|
||||||
#OPENAI_COMPATIBLE_API_KEY=
|
#OPENAI_COMPATIBLE_API_KEY=
|
||||||
|
|
||||||
|
# AWS Bedrock (provider "bedrock", install with: pip install ".[bedrock]").
|
||||||
|
# Auth uses the standard AWS credential chain; set the region (and optionally a
|
||||||
|
# named profile). No single API key.
|
||||||
|
#AWS_DEFAULT_REGION=us-west-2
|
||||||
|
#AWS_PROFILE=
|
||||||
|
|
||||||
# Optional: point at a remote Ollama server. When unset, defaults to
|
# Optional: point at a remote Ollama server. When unset, defaults to
|
||||||
# the local instance at http://localhost:11434/v1. Convention follows
|
# the local instance at http://localhost:11434/v1. Convention follows
|
||||||
# the broader Ollama ecosystem; both the CLI dropdown and programmatic
|
# the broader Ollama ecosystem; both the CLI dropdown and programmatic
|
||||||
|
|||||||
@@ -153,7 +153,9 @@ export OPENROUTER_API_KEY=... # OpenRouter
|
|||||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
For Azure OpenAI, copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||||
|
|
||||||
|
For AWS Bedrock, install the extra with `pip install ".[bedrock]"`, set `llm_provider: "bedrock"`, configure AWS credentials (environment variables, `~/.aws/credentials`, or an IAM role) and `AWS_DEFAULT_REGION`, and use a Bedrock model ID, e.g. `us.anthropic.claude-opus-4-8-v1:0`.
|
||||||
|
|
||||||
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
For local models, configure Ollama with `llm_provider: "ollama"`. The default endpoint is `http://localhost:11434/v1`; set `OLLAMA_BASE_URL` to point at a remote `ollama-serve`. Pull models with `ollama pull <name>`, and pick "Custom model ID" in the CLI for any model not listed by default.
|
||||||
|
|
||||||
|
|||||||
@@ -316,6 +316,7 @@ def _llm_provider_table() -> list[tuple[str, str, str | None]]:
|
|||||||
("Groq", "groq", "https://api.groq.com/openai/v1"),
|
("Groq", "groq", "https://api.groq.com/openai/v1"),
|
||||||
("NVIDIA NIM", "nvidia", "https://integrate.api.nvidia.com/v1"),
|
("NVIDIA NIM", "nvidia", "https://integrate.api.nvidia.com/v1"),
|
||||||
("Azure OpenAI", "azure", None),
|
("Azure OpenAI", "azure", None),
|
||||||
|
("Amazon Bedrock", "bedrock", None),
|
||||||
("Ollama", "ollama", ollama_url),
|
("Ollama", "ollama", ollama_url),
|
||||||
("OpenAI-compatible (vLLM, LM Studio, llama.cpp, custom relay)", "openai_compatible", None),
|
("OpenAI-compatible (vLLM, LM Studio, llama.cpp, custom relay)", "openai_compatible", None),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ dev = [
|
|||||||
"pytest>=8.0",
|
"pytest>=8.0",
|
||||||
"pytest-subtests>=0.13",
|
"pytest-subtests>=0.13",
|
||||||
]
|
]
|
||||||
|
# Amazon Bedrock support (AWS SigV4 auth + boto3). Optional so the core install
|
||||||
|
# stays lean: pip install "tradingagents[bedrock]".
|
||||||
|
bedrock = [
|
||||||
|
"langchain-aws>=1.5.0",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
tradingagents = "cli.main:app"
|
tradingagents = "cli.main:app"
|
||||||
|
|||||||
46
tests/test_bedrock_provider.py
Normal file
46
tests/test_bedrock_provider.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Amazon Bedrock — first-class native client via the optional langchain-aws extra.
|
||||||
|
|
||||||
|
Auth uses the AWS credential chain (no single key env); the model is a Bedrock
|
||||||
|
model ID / inference profile ID; langchain-aws is imported lazily with a clear
|
||||||
|
install hint when the [bedrock] extra is absent.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.api_key_env import get_api_key_env
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_factory_routes_bedrock():
|
||||||
|
client = create_llm_client("bedrock", "us.anthropic.claude-opus-4-8-v1:0")
|
||||||
|
assert type(client).__name__ == "BedrockClient"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_bedrock_any_model_and_no_key_env():
|
||||||
|
assert validate_model("bedrock", "any.model-id:0") is True
|
||||||
|
# Bedrock uses the AWS credential chain, so there is no single key env.
|
||||||
|
assert get_api_key_env("bedrock") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_helpful_error_when_langchain_aws_absent(monkeypatch):
|
||||||
|
import tradingagents.llm_clients.bedrock_client as bc
|
||||||
|
monkeypatch.setattr(bc, "_BEDROCK_CLASS", None)
|
||||||
|
monkeypatch.setitem(sys.modules, "langchain_aws", None) # force ImportError on import
|
||||||
|
with pytest.raises(ImportError, match=r"bedrock"):
|
||||||
|
create_llm_client("bedrock", "m").get_llm()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_construction_when_extra_installed(monkeypatch):
|
||||||
|
pytest.importorskip("langchain_aws")
|
||||||
|
import tradingagents.llm_clients.bedrock_client as bc
|
||||||
|
monkeypatch.setattr(bc, "_BEDROCK_CLASS", None)
|
||||||
|
monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
|
||||||
|
llm = create_llm_client("bedrock", "us.anthropic.claude-sonnet-4-6-v1:0").get_llm()
|
||||||
|
assert type(llm).__name__ == "NormalizedChatBedrockConverse"
|
||||||
|
assert llm.region_name == "eu-west-1"
|
||||||
@@ -19,6 +19,8 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
|||||||
"anthropic": "ANTHROPIC_API_KEY",
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
"google": "GOOGLE_API_KEY",
|
"google": "GOOGLE_API_KEY",
|
||||||
"azure": "AZURE_OPENAI_API_KEY",
|
"azure": "AZURE_OPENAI_API_KEY",
|
||||||
|
# Bedrock authenticates via the AWS credential chain, not a single key env.
|
||||||
|
"bedrock": None,
|
||||||
"xai": "XAI_API_KEY",
|
"xai": "XAI_API_KEY",
|
||||||
"deepseek": "DEEPSEEK_API_KEY",
|
"deepseek": "DEEPSEEK_API_KEY",
|
||||||
# Dual-region providers each carry their own account; keys are not
|
# Dual-region providers each carry their own account; keys are not
|
||||||
|
|||||||
69
tradingagents/llm_clients/bedrock_client.py
Normal file
69
tradingagents/llm_clients/bedrock_client.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
# Bedrock has no global default region; us-west-2 hosts the broadest model set.
|
||||||
|
_DEFAULT_REGION = "us-west-2"
|
||||||
|
_BEDROCK_CLASS = None
|
||||||
|
|
||||||
|
|
||||||
|
def _bedrock_class():
|
||||||
|
"""Lazily import langchain-aws (the optional ``[bedrock]`` extra) and return a
|
||||||
|
ChatBedrockConverse subclass with normalized content output.
|
||||||
|
|
||||||
|
Imported on demand so the optional dependency (and boto3) isn't required by
|
||||||
|
the rest of the package; cached after the first call.
|
||||||
|
"""
|
||||||
|
global _BEDROCK_CLASS
|
||||||
|
if _BEDROCK_CLASS is not None:
|
||||||
|
return _BEDROCK_CLASS
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_aws import ChatBedrockConverse
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"AWS Bedrock support requires the optional 'langchain-aws' dependency. "
|
||||||
|
'Install it with: pip install "tradingagents[bedrock]"'
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
class NormalizedChatBedrockConverse(ChatBedrockConverse):
|
||||||
|
"""ChatBedrockConverse with normalized (string) content output."""
|
||||||
|
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
_BEDROCK_CLASS = NormalizedChatBedrockConverse
|
||||||
|
return _BEDROCK_CLASS
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockClient(BaseLLMClient):
|
||||||
|
"""Client for Amazon Bedrock via the Converse API (langchain-aws).
|
||||||
|
|
||||||
|
Authentication uses the standard AWS credential chain (env vars,
|
||||||
|
``~/.aws/credentials``, or an IAM role); set ``AWS_REGION`` /
|
||||||
|
``AWS_DEFAULT_REGION`` and optionally ``AWS_PROFILE``. The model name is a
|
||||||
|
Bedrock model ID or cross-region inference profile ID, e.g.
|
||||||
|
``us.anthropic.claude-opus-4-8-v1:0``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return a configured ChatBedrockConverse instance."""
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
chat_cls = _bedrock_class()
|
||||||
|
|
||||||
|
region = (
|
||||||
|
os.environ.get("AWS_REGION")
|
||||||
|
or os.environ.get("AWS_DEFAULT_REGION")
|
||||||
|
or _DEFAULT_REGION
|
||||||
|
)
|
||||||
|
llm_kwargs = {"model": self.model, "region_name": region}
|
||||||
|
for key in ("temperature", "max_tokens", "max_retries", "callbacks"):
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
return chat_cls(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate model for Bedrock (any model ID accepted)."""
|
||||||
|
return validate_model("bedrock", self.model)
|
||||||
@@ -43,6 +43,10 @@ def create_llm_client(
|
|||||||
from .azure_client import AzureOpenAIClient
|
from .azure_client import AzureOpenAIClient
|
||||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "bedrock":
|
||||||
|
from .bedrock_client import BedrockClient
|
||||||
|
return BedrockClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
from .openai_client import OpenAIClient, is_openai_compatible
|
from .openai_client import OpenAIClient, is_openai_compatible
|
||||||
if is_openai_compatible(provider_lower):
|
if is_openai_compatible(provider_lower):
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|||||||
@@ -193,6 +193,8 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||||||
"kimi": _CUSTOM_ONLY,
|
"kimi": _CUSTOM_ONLY,
|
||||||
"groq": _CUSTOM_ONLY,
|
"groq": _CUSTOM_ONLY,
|
||||||
"nvidia": _CUSTOM_ONLY,
|
"nvidia": _CUSTOM_ONLY,
|
||||||
|
# Bedrock model IDs / cross-region inference profile IDs are user-specified.
|
||||||
|
"bedrock": _CUSTOM_ONLY,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from .model_catalog import get_known_models
|
|||||||
# accepted without warning.
|
# accepted without warning.
|
||||||
_ANY_MODEL_PROVIDERS = (
|
_ANY_MODEL_PROVIDERS = (
|
||||||
"ollama", "openrouter", "openai_compatible",
|
"ollama", "openrouter", "openai_compatible",
|
||||||
"mistral", "kimi", "groq", "nvidia",
|
"mistral", "kimi", "groq", "nvidia", "bedrock",
|
||||||
)
|
)
|
||||||
|
|
||||||
VALID_MODELS = {
|
VALID_MODELS = {
|
||||||
|
|||||||
Reference in New Issue
Block a user