mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-06-16 21:06:15 +03:00
Clear the deferred full-repo lint backlog so the whole tree passes the strict ruff select (E,W,F,I,B,UP,C4,SIM). Mechanical fixes dominate: import sorting, pep585/604 annotations, dropped dead imports, and whitespace. The few semantic changes are behavior-preserving: declare __all__ on the agent_utils and alpha_vantage re-export hubs; expand 'from x import *' to explicit names; use immutable tuple defaults instead of mutable list defaults; contextlib.suppress for try/except/pass; and narrow an over-broad assertRaises.
77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
import threading
|
|
from typing import Any
|
|
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_core.outputs import LLMResult
|
|
|
|
|
|
class StatsCallbackHandler(BaseCallbackHandler):
    """Count LLM runs, tool runs, and token usage reported through LangChain callbacks.

    The four counters are plain ints on the instance. They are only mutated or
    read while holding ``self._lock``, so concurrent callback invocations cannot
    lose increments and ``get_stats`` never observes a half-applied update.
    """

    # Counter attribute names reported by get_stats(), in output order.
    _STAT_FIELDS = ("llm_calls", "tool_calls", "tokens_in", "tokens_out")

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.llm_calls = self.tool_calls = 0
        self.tokens_in = self.tokens_out = 0

    def _count_llm_call(self) -> None:
        """Add one to ``llm_calls`` while holding the lock."""
        with self._lock:
            self.llm_calls += 1

    def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        **kwargs: Any,
    ) -> None:
        """Count an LLM run that starts from string prompts."""
        self._count_llm_call()

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[Any]],
        **kwargs: Any,
    ) -> None:
        """Count a chat-model run that starts from lists of messages."""
        self._count_llm_call()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token usage carried by the first generation's AIMessage.

        Responses with no generations, generations without an ``AIMessage``,
        or messages whose ``usage_metadata`` is missing or empty leave the
        token totals unchanged.
        """
        try:
            first = response.generations[0][0]
        except (IndexError, TypeError):
            return

        msg = getattr(first, "message", None)
        if not isinstance(msg, AIMessage):
            return

        usage = getattr(msg, "usage_metadata", None)
        if not usage:
            # None or an empty mapping: nothing to add.
            return

        with self._lock:
            self.tokens_in += usage.get("input_tokens", 0)
            self.tokens_out += usage.get("output_tokens", 0)

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Count a tool run as it starts."""
        with self._lock:
            self.tool_calls += 1

    def get_stats(self) -> dict[str, Any]:
        """Return a consistent snapshot of all counters as a fresh dict."""
        with self._lock:
            return {field: getattr(self, field) for field in self._STAT_FIELDS}
|