mirror of
https://github.com/TauricResearch/TradingAgents.git
synced 2026-05-02 15:03:10 +03:00
fix(security): validate ticker before using as path component (#618)
The ticker symbol reaches three filesystem-path construction sites (load_ohlcv cache filename, checkpointer DB path, _log_state results directory) without validation. A value containing path separators or "../" escapes the configured cache / checkpoints / results directory. Two attack vectors: (1) programmatic callers passing an arbitrary ticker to propagate(); (2) prompt injection via fetched news content steering the LLM into tool calls with an attacker-chosen ticker. Fix: new safe_ticker_component() validator in tradingagents/dataflows/utils.py, applied at all three sites. It allows the standard ticker character set ([A-Za-z0-9._\-\^], up to 32 chars) and explicitly rejects dot-only values like "." and "..", which would otherwise pass the regex but traverse parent directories. Seven test cases cover the accepted formats (BRK-B, 7203.T, ^GSPC, etc.) and the rejected inputs (path separators, null bytes, whitespace, empty values, overlong strings, dot-only values). Closes #618.
This commit is contained in:
52
tests/test_safe_ticker_component.py
Normal file
52
tests/test_safe_ticker_component.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Tests for the ticker path-component validator that blocks directory traversal."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
class TestSafeTickerComponent(unittest.TestCase):
    """Check which ticker strings ``safe_ticker_component`` accepts or rejects.

    Every table-driven case runs inside ``subTest`` so a failure names the
    offending input and the remaining inputs are still exercised.
    """

    def _assert_all_rejected(self, candidates):
        """Assert that each candidate makes the validator raise ``ValueError``."""
        for bad in candidates:
            with self.subTest(ticker=bad):
                with self.assertRaises(ValueError):
                    safe_ticker_component(bad)

    def test_accepts_common_ticker_formats(self):
        for ticker in ("AAPL", "BRK-B", "BRK.A", "0700.HK", "7203.T", "BHP.AX", "^GSPC"):
            with self.subTest(ticker=ticker):
                self.assertEqual(safe_ticker_component(ticker), ticker)

    def test_rejects_path_separators(self):
        self._assert_all_rejected(
            (".", "..", "../etc", "a/b", "a\\b", "/abs", "..\\..\\x")
        )

    def test_rejects_null_byte_and_whitespace(self):
        self._assert_all_rejected(("AAP L", "AAPL\x00", "AAPL\n", "\tAAPL"))

    def test_rejects_empty_or_non_string(self):
        self._assert_all_rejected(("", None, 123, b"AAPL"))

    def test_rejects_overlong_input(self):
        with self.assertRaises(ValueError):
            safe_ticker_component("A" * 33)

    def test_rejects_dot_only_values(self):
        # '.' and '..' pass the regex but traverse when used as a path
        # component (e.g. ``Path(results_dir) / ticker / "logs"``).
        self._assert_all_rejected((".", "..", "...", "...."))

    def test_traversal_string_does_not_escape_join(self):
        """Sanity: sanitized values stay within base when joined."""
        base = os.path.realpath("/tmp/cache")
        # Include accepted values that contain dots: embedded dots are legal
        # in a ticker and must not turn the joined path into a traversal.
        for raw in ("AAPL", "BRK.A", "^GSPC", "A..B"):
            with self.subTest(ticker=raw):
                ticker = safe_ticker_component(raw)
                joined = os.path.realpath(os.path.join(base, f"{ticker}.csv"))
                self.assertTrue(joined.startswith(base + os.sep))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -8,6 +8,7 @@ from stockstats import wrap
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
import os
|
import os
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
|
from .utils import safe_ticker_component
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,6 +52,10 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
|||||||
subsequent calls the cache is reused. Rows after curr_date are
|
subsequent calls the cache is reused. Rows after curr_date are
|
||||||
filtered out so backtests never see future prices.
|
filtered out so backtests never see future prices.
|
||||||
"""
|
"""
|
||||||
|
# Reject ticker values that would escape the cache directory when
|
||||||
|
# interpolated into the cache filename (e.g. ``../../tmp/x``).
|
||||||
|
safe_symbol = safe_ticker_component(symbol)
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
curr_date_dt = pd.to_datetime(curr_date)
|
curr_date_dt = pd.to_datetime(curr_date)
|
||||||
|
|
||||||
@@ -63,7 +68,7 @@ def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
|||||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||||
data_file = os.path.join(
|
data_file = os.path.join(
|
||||||
config["data_cache_dir"],
|
config["data_cache_dir"],
|
||||||
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
|
f"{safe_symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.exists(data_file):
|
if os.path.exists(data_file):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import date, timedelta, datetime
|
from datetime import date, timedelta, datetime
|
||||||
@@ -6,6 +7,40 @@ from typing import Annotated
|
|||||||
|
|
||||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||||
|
|
||||||
|
# Tickers can contain letters, digits, dot, dash, underscore, and caret
# (for index symbols like ^GSPC). Anything else is rejected so the value
# never escapes a containing directory when interpolated into a path.
_TICKER_PATH_RE = re.compile(r"^[A-Za-z0-9._\-\^]+$")

# Upper bound on how much of a rejected value is echoed back in an error
# message, so an oversized input cannot bloat the exception text.
_TICKER_ERR_PREVIEW_LEN = 64


def _ticker_preview(value: object) -> str:
    """Return a bounded ``repr`` of ``value`` for use in error messages.

    Values whose length (for ``str``/``bytes``) or whose ``repr`` (for
    anything else) fits within ``_TICKER_ERR_PREVIEW_LEN`` are shown in
    full; longer ones are cut and annotated with their total length.
    """
    if isinstance(value, (str, bytes)):
        if len(value) <= _TICKER_ERR_PREVIEW_LEN:
            return repr(value)
        # Slice before repr() so a huge input is never fully rendered.
        head = value[:_TICKER_ERR_PREVIEW_LEN]
        return f"{head!r}... ({len(value)} chars total)"
    text = repr(value)
    if len(text) <= _TICKER_ERR_PREVIEW_LEN:
        return text
    return f"{text[:_TICKER_ERR_PREVIEW_LEN]}... ({len(text)} chars total)"


def safe_ticker_component(value: str, *, max_len: int = 32) -> str:
    """Validate ``value`` is safe to interpolate into a filesystem path.

    Tickers come from user CLI input or from LLM tool calls, both of which
    can be influenced by attacker-controlled content (e.g. prompt injection
    embedded in fetched news). Without validation, a value like
    ``"../../../etc/foo"`` flows into ``os.path.join`` / ``Path /`` and
    escapes the configured cache, checkpoint, or results directory.

    Args:
        value: Candidate ticker symbol.
        max_len: Maximum accepted length of ``value`` in characters.

    Returns:
        ``value`` unchanged when it matches the allowed pattern.

    Raises:
        ValueError: If ``value`` is not a non-empty ``str``, is longer than
            ``max_len``, contains a character outside the allowed set, or
            consists only of dots. The rejected value is echoed in the
            message in truncated form when it is very long.
    """
    if not isinstance(value, str) or not value:
        raise ValueError(
            f"ticker must be a non-empty string, got {_ticker_preview(value)}"
        )
    if len(value) > max_len:
        raise ValueError(
            f"ticker exceeds {max_len} chars: {_ticker_preview(value)}"
        )
    # fullmatch() requires the whole string to match, so a trailing newline
    # is rejected even though the pattern ends in ``$``.
    if not _TICKER_PATH_RE.fullmatch(value):
        raise ValueError(
            "ticker contains characters not allowed in a filesystem path: "
            f"{_ticker_preview(value)}"
        )
    # The regex above allows '.', so values like '.', '..', '...' would pass,
    # and as a path component they traverse the parent directory. Reject any
    # value that's only dots.
    if set(value) == {"."}:
        raise ValueError(
            f"ticker cannot consist solely of dots: {_ticker_preview(value)}"
        )
    return value
|
||||||
|
|
||||||
|
|
||||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||||
if save_path:
|
if save_path:
|
||||||
data.to_csv(save_path, encoding="utf-8")
|
data.to_csv(save_path, encoding="utf-8")
|
||||||
|
|||||||
@@ -13,12 +13,16 @@ from typing import Generator
|
|||||||
|
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
|
|
||||||
|
|
||||||
def _db_path(data_dir: str | Path, ticker: str) -> Path:
|
def _db_path(data_dir: str | Path, ticker: str) -> Path:
|
||||||
"""Return the SQLite checkpoint DB path for a ticker."""
|
"""Return the SQLite checkpoint DB path for a ticker."""
|
||||||
|
# Reject ticker values that would escape the checkpoints directory.
|
||||||
|
safe = safe_ticker_component(ticker).upper()
|
||||||
p = Path(data_dir) / "checkpoints"
|
p = Path(data_dir) / "checkpoints"
|
||||||
p.mkdir(parents=True, exist_ok=True)
|
p.mkdir(parents=True, exist_ok=True)
|
||||||
return p / f"{ticker.upper()}.db"
|
return p / f"{safe}.db"
|
||||||
|
|
||||||
|
|
||||||
def thread_id(ticker: str, date: str) -> str:
|
def thread_id(ticker: str, date: str) -> str:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from tradingagents.llm_clients import create_llm_client
|
|||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||||
|
from tradingagents.dataflows.utils import safe_ticker_component
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
InvestDebateState,
|
InvestDebateState,
|
||||||
@@ -378,8 +379,10 @@ class TradingAgentsGraph:
|
|||||||
"final_trade_decision": final_state["final_trade_decision"],
|
"final_trade_decision": final_state["final_trade_decision"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save to file
|
# Save to file. Reject ticker values that would escape the
|
||||||
directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
|
# results directory when joined as a path component.
|
||||||
|
safe_ticker = safe_ticker_component(self.ticker)
|
||||||
|
directory = Path(self.config["results_dir"]) / safe_ticker / "TradingAgentsStrategy_logs"
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
log_path = directory / f"full_states_log_{trade_date}.json"
|
log_path = directory / f"full_states_log_{trade_date}.json"
|
||||||
|
|||||||
Reference in New Issue
Block a user