mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-17 05:16:14 +03:00
feat(llm): add Amazon Bedrock as a first-class provider
Bedrock uses the Converse API (langchain-aws) and the AWS credential chain, so it has its own client like Anthropic/Google rather than the OpenAI-compatible registry. langchain-aws is an optional dependency (pip install ".[bedrock]"), lazy-imported with a clear install hint; importing the package never requires it. The model name is a Bedrock model ID / inference profile ID.
This commit is contained in:
@@ -19,6 +19,8 @@ PROVIDER_API_KEY_ENV: dict[str, Optional[str]] = {
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
# Bedrock authenticates via the AWS credential chain, not a single key env.
|
||||
"bedrock": None,
|
||||
"xai": "XAI_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
# Dual-region providers each carry their own account; keys are not
|
||||
|
||||
69
tradingagents/llm_clients/bedrock_client.py
Normal file
69
tradingagents/llm_clients/bedrock_client.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
# Bedrock has no global default region; us-west-2 hosts the broadest model set.
|
||||
_DEFAULT_REGION = "us-west-2"
|
||||
_BEDROCK_CLASS = None
|
||||
|
||||
|
||||
def _bedrock_class():
|
||||
"""Lazily import langchain-aws (the optional ``[bedrock]`` extra) and return a
|
||||
ChatBedrockConverse subclass with normalized content output.
|
||||
|
||||
Imported on demand so the optional dependency (and boto3) isn't required by
|
||||
the rest of the package; cached after the first call.
|
||||
"""
|
||||
global _BEDROCK_CLASS
|
||||
if _BEDROCK_CLASS is not None:
|
||||
return _BEDROCK_CLASS
|
||||
|
||||
try:
|
||||
from langchain_aws import ChatBedrockConverse
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"AWS Bedrock support requires the optional 'langchain-aws' dependency. "
|
||||
'Install it with: pip install "tradingagents[bedrock]"'
|
||||
) from exc
|
||||
|
||||
class NormalizedChatBedrockConverse(ChatBedrockConverse):
|
||||
"""ChatBedrockConverse with normalized (string) content output."""
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
_BEDROCK_CLASS = NormalizedChatBedrockConverse
|
||||
return _BEDROCK_CLASS
|
||||
|
||||
|
||||
class BedrockClient(BaseLLMClient):
|
||||
"""Client for Amazon Bedrock via the Converse API (langchain-aws).
|
||||
|
||||
Authentication uses the standard AWS credential chain (env vars,
|
||||
``~/.aws/credentials``, or an IAM role); set ``AWS_REGION`` /
|
||||
``AWS_DEFAULT_REGION`` and optionally ``AWS_PROFILE``. The model name is a
|
||||
Bedrock model ID or cross-region inference profile ID, e.g.
|
||||
``us.anthropic.claude-opus-4-8-v1:0``.
|
||||
"""
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return a configured ChatBedrockConverse instance."""
|
||||
self.warn_if_unknown_model()
|
||||
chat_cls = _bedrock_class()
|
||||
|
||||
region = (
|
||||
os.environ.get("AWS_REGION")
|
||||
or os.environ.get("AWS_DEFAULT_REGION")
|
||||
or _DEFAULT_REGION
|
||||
)
|
||||
llm_kwargs = {"model": self.model, "region_name": region}
|
||||
for key in ("temperature", "max_tokens", "max_retries", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
return chat_cls(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for Bedrock (any model ID accepted)."""
|
||||
return validate_model("bedrock", self.model)
|
||||
@@ -43,6 +43,10 @@ def create_llm_client(
|
||||
from .azure_client import AzureOpenAIClient
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "bedrock":
|
||||
from .bedrock_client import BedrockClient
|
||||
return BedrockClient(model, base_url, **kwargs)
|
||||
|
||||
from .openai_client import OpenAIClient, is_openai_compatible
|
||||
if is_openai_compatible(provider_lower):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
@@ -193,6 +193,8 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||
"kimi": _CUSTOM_ONLY,
|
||||
"groq": _CUSTOM_ONLY,
|
||||
"nvidia": _CUSTOM_ONLY,
|
||||
# Bedrock model IDs / cross-region inference profile IDs are user-specified.
|
||||
"bedrock": _CUSTOM_ONLY,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from .model_catalog import get_known_models
|
||||
# accepted without warning.
|
||||
_ANY_MODEL_PROVIDERS = (
|
||||
"ollama", "openrouter", "openai_compatible",
|
||||
"mistral", "kimi", "groq", "nvidia",
|
||||
"mistral", "kimi", "groq", "nvidia", "bedrock",
|
||||
)
|
||||
|
||||
VALID_MODELS = {
|
||||
|
||||
Reference in New Issue
Block a user