mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-01 14:33:10 +03:00
feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion)
This commit is contained in:
76
cli/stats_handler.py
Normal file
76
cli/stats_handler.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class StatsCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
self.llm_calls = 0
|
||||
self.tool_calls = 0
|
||||
self.tokens_in = 0
|
||||
self.tokens_out = 0
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when an LLM starts."""
|
||||
with self._lock:
|
||||
self.llm_calls += 1
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when a chat model starts."""
|
||||
with self._lock:
|
||||
self.llm_calls += 1
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Extract token usage from LLM response."""
|
||||
try:
|
||||
generation = response.generations[0][0]
|
||||
except (IndexError, TypeError):
|
||||
return
|
||||
|
||||
usage_metadata = None
|
||||
if hasattr(generation, "message"):
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
|
||||
usage_metadata = message.usage_metadata
|
||||
|
||||
if usage_metadata:
|
||||
with self._lock:
|
||||
self.tokens_in += usage_metadata.get("input_tokens", 0)
|
||||
self.tokens_out += usage_metadata.get("output_tokens", 0)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment tool call counter when a tool starts."""
|
||||
with self._lock:
|
||||
self.tool_calls += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return current statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
"llm_calls": self.llm_calls,
|
||||
"tool_calls": self.tool_calls,
|
||||
"tokens_in": self.tokens_in,
|
||||
"tokens_out": self.tokens_out,
|
||||
}
|
||||
Reference in New Issue
Block a user