mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-02 15:03:10 +03:00
fix(security): validate ticker before using as path component (#618)
The ticker symbol reaches three filesystem-path construction sites (the load_ohlcv cache filename, the checkpointer DB path, and the _log_state results directory) without validation. A value containing path separators or "../" escapes the configured cache / checkpoints / results directory.

Two attack vectors:
- Programmatic callers passing an arbitrary ticker to propagate()
- Prompt injection via fetched news content steering the LLM into tool calls with an attacker-chosen ticker

Fix: a new safe_ticker_component() validator in tradingagents/dataflows/utils.py, applied at all three sites. It allows the standard ticker character set ([A-Za-z0-9._\-\^], up to 32 chars) and explicitly rejects dot-only values like "." and "..", which would otherwise pass the regex but traverse parent directories.

Seven test cases cover the accepted formats (BRK-B, 7203.T, ^GSPC, etc.) and the rejected inputs (path separators, null bytes, whitespace, empty values, overlong strings, dot-only values).

Closes #618.
This commit is contained in:
52
tests/test_safe_ticker_component.py
Normal file
52
tests/test_safe_ticker_component.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Tests for the ticker path-component validator that blocks directory traversal."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestSafeTickerComponent(unittest.TestCase):
    """Behavioural tests for ``safe_ticker_component``.

    Each multi-input check runs every candidate inside ``subTest`` so that a
    failure names the offending input and does not hide failures for the
    remaining inputs.
    """

    def _assert_all_rejected(self, candidates):
        """Assert that every value in ``candidates`` raises ``ValueError``."""
        for bad in candidates:
            with self.subTest(ticker=bad):
                # ``msg`` keeps the offending value visible even under runners
                # that do not report subTest parameters.
                with self.assertRaises(ValueError, msg=f"accepted {bad!r}"):
                    safe_ticker_component(bad)

    def test_accepts_common_ticker_formats(self):
        for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
            with self.subTest(ticker=ticker):
                self.assertEqual(safe_ticker_component(ticker), ticker)

    def test_rejects_path_separators(self):
        self._assert_all_rejected(
            (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x")
        )

    def test_rejects_null_byte_and_whitespace(self):
        self._assert_all_rejected(("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"))

    def test_rejects_empty_or_non_string(self):
        self._assert_all_rejected(("", None, 123, b"AAPL"))

    def test_rejects_overlong_input(self):
        # The default limit is 32 characters: exactly 32 is still accepted,
        # 33 is rejected.
        self.assertEqual(safe_ticker_component("A" * 32), "A" * 32)
        with self.assertRaises(ValueError):
            safe_ticker_component("A" * 33)

    def test_rejects_dot_only_values(self):
        # Dot-only strings consist solely of characters the ticker regex
        # allows, yet '.' and '..' are directory references when used as a
        # path component (e.g. ``Path(results_dir) / ticker / "logs"``).
        self._assert_all_rejected((".", "..", "...", "...."))

    def test_traversal_string_does_not_escape_join(self):
        """Sanity: sanitized values stay within base when joined."""
        base = os.path.realpath("/tmp/cache")
        ticker = safe_ticker_component("AAPL")
        joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
        self.assertTrue(joined.startswith(base + os.sep))
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -8,6 +8,7 @@ from stockstats import wrap
|
||||
from typing import Annotated
|
||||
import os
|
||||
from .config import get_config
|
||||
from .utils import safe_ticker_component
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||
subsequent calls the cache is reused. Rows after curr_date are
|
||||
filtered out so backtests never see future prices.
|
||||
"""
|
||||
# Reject ticker values that would escape the cache directory when
|
||||
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
||||
safe_symbol = safe_ticker_component(symbol)
|
||||
|
||||
config = get_config()
|
||||
curr_date_dt = pd.to_datetime(curr_date)
|
||||
|
||||
@@ -63,7 +68,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import pandas as pd
|
||||
from datetime import date, timedelta, datetime
|
||||
@@ -6,6 +7,40 @@ from typing import Annotated
|
||||
|
||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||
|
||||
# Allowed ticker alphabet: ASCII letters and digits plus '.', '_', '-' and
# '^' (index symbols such as ^GSPC). Path separators and every other
# character fall outside this set, so a matching value cannot introduce an
# extra path segment when interpolated into a path.
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")


def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
    """Return ``value`` if it is acceptable as a single path component.

    Ticker values originate from user CLI input or from LLM tool calls, and
    either source can be steered by attacker-controlled content (for example
    prompt injection embedded in fetched news). An unchecked value such as
    ``"../../../etc/foo"`` passed to ``os.path.join`` / ``Path /`` would
    leave the configured cache, checkpoint, or results directory.

    Args:
        value: Candidate ticker symbol.
        max_len: Maximum number of characters accepted (keyword-only).

    Returns:
        ``value`` itself, unmodified, when every check passes.

    Raises:
        ValueError: If ``value`` is not a non-empty ``str``, is longer than
            ``max_len``, contains a character outside the allowed set, or is
            made up of dots only.
    """
    is_nonempty_str = isinstance(value, str) and bool(value)
    if not is_nonempty_str:
        raise ValueError(f"ticker must be a non-empty string, got {value!r}")

    if len(value) > max_len:
        raise ValueError(f"ticker exceeds {max_len} chars: {value!r}")

    if _TICKER_PATH_RE.fullmatch(value) is None:
        raise ValueError(
            f"ticker contains characters not allowed in a filesystem path: {value!r}"
        )

    # '.' is part of the allowed alphabet, so '.', '..', '...' get through the
    # pattern check; '.' and '..' are directory references when used as a path
    # component, so every dot-only value is refused. ``value`` is known to be
    # non-empty at this point.
    if all(char == "." for char in value):
        raise ValueError(f"ticker cannot consist solely of dots: {value!r}")

    return value
|
||||
|
||||
|
||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||
if save_path:
|
||||
data.to_csv(save_path, encoding="utf-8")
|
||||
|
||||
@@ -13,12 +13,16 @@ from typing import Generator
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
|
||||
|
||||
def _db_path(data_dir: str | Path, ticker: str) -> Path:
    """Return the SQLite checkpoint DB path for a ticker.

    The ticker is validated with ``safe_ticker_component`` before any
    filesystem access, then upper-cased to form the file name. The
    ``checkpoints`` directory under ``data_dir`` is created if missing.

    Args:
        data_dir: Base directory that holds the ``checkpoints`` folder.
        ticker: Ticker symbol used as the database file stem.

    Returns:
        ``<data_dir>/checkpoints/<TICKER>.db``.

    Raises:
        ValueError: Propagated from ``safe_ticker_component`` when ``ticker``
            is not acceptable as a path component.
    """
    # Reject ticker values that would escape the checkpoints directory.
    safe = safe_ticker_component(ticker).upper()
    checkpoints_dir = Path(data_dir) / "checkpoints"
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    # Build the file name from the validated value only; the raw ``ticker``
    # must never be interpolated into the path.
    return checkpoints_dir / f"{safe}.db"
|
||||
|
||||
|
||||
def thread_id(ticker: str, date: str) -> str:
|
||||
|
||||
@@ -18,6 +18,7 @@ from tradingagents.llm_clients import create_llm_client
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.dataflows.utils import safe_ticker_component
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
@@ -378,8 +379,10 @@ class TradingAgentsGraph:
|
||||
"final_trade_decision": final_state["final_trade_decision"],
|
||||
}
|
||||
|
||||
# Save to file
|
||||
directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
|
||||
# Save to file. Reject ticker values that would escape the
|
||||
# results directory when joined as a path component.
|
||||
safe_ticker = safe_ticker_component(self.ticker)
|
||||
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
log_path = directory / f"full_states_log_{trade_date}.json"
|
||||
|
||||
Reference in New Issue
Block a user