diff --git a/tests/test_reddit_fallback.py b/tests/test_reddit_fallback.py new file mode 100644 index 000000000..ecb944f71 --- /dev/null +++ b/tests/test_reddit_fallback.py @@ -0,0 +1,114 @@ +"""Tests for the Reddit RSS/Atom fallback when the JSON endpoint 403s (#862).""" + +from __future__ import annotations + +from unittest.mock import patch +from urllib.error import HTTPError + +import pytest + +from tradingagents.dataflows import reddit + + +_SAMPLE_ATOM = """ + + + NVDA earnings beat, stock pops + 2026-05-20T14:30:00+00:00 + <!-- SC_OFF --><div class="md"><p>Great <b>quarter</b> for NVDA&#39;s datacenter unit.</p></div><!-- SC_ON --> + + + Is NVDA overvalued? + 2026-05-19T09:00:00Z + <p>Forward P/E discussion</p> + + +""" + + +@pytest.mark.unit +class TestIsoToTimestamp: + def test_parses_offset_and_z(self): + assert reddit._iso_to_timestamp("2026-05-20T14:30:00+00:00") > 0 + assert reddit._iso_to_timestamp("2026-05-19T09:00:00Z") > 0 + + def test_none_and_garbage_return_none(self): + assert reddit._iso_to_timestamp(None) is None + assert reddit._iso_to_timestamp("not-a-date") is None + + +@pytest.mark.unit +class TestStripHtml: + def test_extracts_between_sc_markers_and_unescapes(self): + raw = "

Great quarter & more

" + assert reddit._strip_html(raw) == "Great quarter & more" + + def test_empty(self): + assert reddit._strip_html("") == "" + + +@pytest.mark.unit +class TestRssFallbackParsing: + def _patch_rss_response(self, xml_bytes): + class _Resp: + def __enter__(self_inner): + return self_inner + def __exit__(self_inner, *a): + return False + def read(self_inner): + return xml_bytes + return patch.object(reddit, "urlopen", return_value=_Resp()) + + def test_parses_atom_entries(self): + with self._patch_rss_response(_SAMPLE_ATOM.encode("utf-8")): + posts = reddit._fetch_subreddit_rss("NVDA", "stocks", limit=5, timeout=5.0) + assert len(posts) == 2 + assert posts[0]["title"] == "NVDA earnings beat, stock pops" + assert posts[0]["source"] == "rss" + assert posts[0]["score"] is None + assert posts[0]["num_comments"] is None + assert posts[0]["created_utc"] > 0 + assert "datacenter unit" in posts[0]["selftext"] + + def test_malformed_xml_fails_open(self): + with self._patch_rss_response(b"<>"): + assert reddit._fetch_subreddit_rss("NVDA", "stocks", 5, 5.0) == [] + + +@pytest.mark.unit +class TestJsonFallsBackToRss: + def test_403_triggers_rss(self): + err = HTTPError("url", 403, "Blocked", {}, None) + with patch.object(reddit, "urlopen", side_effect=err), \ + patch.object(reddit, "_fetch_subreddit_rss", return_value=[{"title": "x", "source": "rss", "score": None, "num_comments": None, "created_utc": None, "selftext": ""}]) as rss: + out = reddit._fetch_subreddit("NVDA", "stocks", 5, 5.0) + rss.assert_called_once() + assert out and out[0]["source"] == "rss" + + +@pytest.mark.unit +class TestFormatterHandlesRssPosts: + def test_rss_posts_omit_fake_counts_and_note_source(self): + rss_posts = [{ + "title": "NVDA pops", "score": None, "num_comments": None, + "created_utc": reddit._iso_to_timestamp("2026-05-20T14:30:00Z"), + "selftext": "great quarter", "source": "rss", + }] + with patch.object(reddit, "_fetch_subreddit", return_value=rss_posts): + out = reddit.fetch_reddit_posts("NVDA", subreddits=("stocks",), inter_request_delay=0) + assert "via RSS feed" in out + assert "↑" not in out # no fake score arrow + assert "NVDA pops" in out + assert "great quarter" in out + + def test_json_posts_still_show_counts(self): + json_posts = [{ + "title": "NVDA pops", "score": 1234, "num_comments": 56, + "created_utc": reddit._iso_to_timestamp("2026-05-20T14:30:00Z"), + "selftext": "", + }] + with patch.object(reddit, "_fetch_subreddit", return_value=json_posts): + out = reddit.fetch_reddit_posts("NVDA", subreddits=("stocks",), inter_request_delay=0) + assert "1234↑" in out + assert "56c" in out + assert "via RSS" not in out diff --git a/tradingagents/dataflows/reddit.py b/tradingagents/dataflows/reddit.py index c8e01b92c..82303d9b4 100644 --- a/tradingagents/dataflows/reddit.py +++ b/tradingagents/dataflows/reddit.py @@ -1,21 +1,30 @@ """Reddit search fetcher for ticker-specific discussion posts. -Uses Reddit's public JSON endpoints (``reddit.com/r/{sub}/search.json``) -which do not require an API key. Public throughput is ~10 requests per -minute per IP, well within budget for a single agent run that queries -a handful of finance subreddits per ticker. +Primary path is Reddit's public JSON search endpoint +(``reddit.com/r/{sub}/search.json``), which carries the richest data +(score, comment count, body). Reddit's WAF increasingly returns +``HTTP 403 Blocked`` on that endpoint (issue #862), so when the JSON request +fails we transparently fall back to the public Atom/RSS search feed +(``/search.rss``). The RSS feed is gated less aggressively and serves the +same descriptive User-Agent we already send; the fallback lacks score / +comment counts, so RSS-sourced posts are marked and the formatter omits those +metrics rather than printing fake zeros. -Returns formatted plaintext blocks ready for prompt injection. Degrades -gracefully — returns a placeholder string rather than raising, so callers -never have to special-case missing data. +No API key required either way. Returns formatted plaintext blocks ready for +prompt injection and degrades gracefully — returns a placeholder string +rather than raising, so callers never special-case missing data. """ from __future__ import annotations +import html import json import logging +import re import time -from typing import Iterable +import xml.etree.ElementTree as ET +from datetime import datetime +from typing import Iterable, Optional from urllib.error import HTTPError, URLError from urllib.parse import urlencode from urllib.request import Request, urlopen @@ -23,7 +32,13 @@ from urllib.request import Request, urlopen logger = logging.getLogger(__name__) _API = "https://www.reddit.com/r/{sub}/search.json?{qs}" +_RSS = "https://www.reddit.com/r/{sub}/search.rss?{qs}" +# A descriptive, identified User-Agent (per Reddit's API etiquette). Reddit +# blocks generic/anonymous tokens like bare "Mozilla/5.0" or "curl/…" but +# serves this one on both endpoints; the RSS feed accepts it even when the +# JSON search endpoint 403s, so no browser-spoofing is needed. _UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)" +_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"} # Default subreddits ordered roughly by signal density for ticker-specific # discussion. wallstreetbets has the most volume but most noise; stocks / @@ -31,29 +46,95 @@ _UA = "tradingagents/0.2 (+https://github.com/TauricResearch/TradingAgents)" DEFAULT_SUBREDDITS = ("wallstreetbets", "stocks", "investing") +def _search_qs(ticker: str, limit: int) -> str: + return urlencode({ + "q": ticker, + "restrict_sr": "on", + "sort": "new", + "t": "week", # last 7 days + "limit": limit, + }) + + +def _iso_to_timestamp(iso_str: Optional[str]) -> Optional[float]: + """Parse an Atom ``published`` timestamp to a UTC epoch, or None.""" + if not iso_str: + return None + try: + normalized = iso_str[:-1] + "+00:00" if iso_str.endswith("Z") else iso_str + return datetime.fromisoformat(normalized).timestamp() + except (ValueError, TypeError): + return None + + +def _strip_html(content: str) -> str: + """Reduce the HTML body Reddit embeds in an Atom entry to plain text.""" + if not content: + return "" + # Reddit wraps the real selftext between SC_OFF / SC_ON markers. + if "" in content and "" in content: + content = content.split("")[1].split("")[0] + text = re.sub(r"<[^>]+>", " ", content) + return " ".join(html.unescape(text).split()) + + +def _fetch_subreddit_rss( + ticker: str, + sub: str, + limit: int, + timeout: float, +) -> list[dict]: + """Fallback path: parse the public Atom search feed for a subreddit. + + Carries no score / comment counts, so those fields are left None and the + post is tagged ``source="rss"`` for honest display. + """ + url = _RSS.format(sub=sub, qs=_search_qs(ticker, limit)) + req = Request(url, headers={"User-Agent": _UA}) + try: + with urlopen(req, timeout=timeout) as resp: + root = ET.fromstring(resp.read()) + except (HTTPError, URLError, TimeoutError, ET.ParseError) as exc: + logger.warning("Reddit RSS fetch failed for r/%s · %s: %s", sub, ticker, exc) + return [] + + posts = [] + for entry in root.findall("atom:entry", _ATOM_NS)[:limit]: + title_el = entry.find("atom:title", _ATOM_NS) + published_el = entry.find("atom:published", _ATOM_NS) + content_el = entry.find("atom:content", _ATOM_NS) + posts.append({ + "title": (title_el.text if title_el is not None else "") or "", + "score": None, + "num_comments": None, + "created_utc": _iso_to_timestamp( + published_el.text if published_el is not None else None + ), + "selftext": _strip_html(content_el.text if content_el is not None else ""), + "source": "rss", + }) + return posts + + def _fetch_subreddit( ticker: str, sub: str, limit: int, timeout: float, ) -> list[dict]: - qs = urlencode({ - "q": ticker, - "restrict_sr": "on", - "sort": "new", - "t": "week", # last 7 days - "limit": limit, - }) - url = _API.format(sub=sub, qs=qs) + url = _API.format(sub=sub, qs=_search_qs(ticker, limit)) req = Request(url, headers={"User-Agent": _UA, "Accept": "application/json"}) try: with urlopen(req, timeout=timeout) as resp: payload = json.loads(resp.read()) + children = (payload.get("data") or {}).get("children") or [] + return [c.get("data", {}) for c in children if isinstance(c, dict)] except (HTTPError, URLError, json.JSONDecodeError, TimeoutError) as exc: - logger.warning("Reddit fetch failed for r/%s · %s: %s", sub, ticker, exc) - return [] - children = (payload.get("data") or {}).get("children") or [] - return [c.get("data", {}) for c in children if isinstance(c, dict)] + logger.warning( + "Reddit JSON fetch failed for r/%s · %s: %s — falling back to RSS feed.", + sub, ticker, exc, + ) + return _fetch_subreddit_rss(ticker, sub, limit, timeout) def fetch_reddit_posts( @@ -80,20 +161,28 @@ def fetch_reddit_posts( blocks.append(f"r/{sub}: ") continue - lines = [f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}:"] + via_rss = any(p.get("source") == "rss" for p in posts) + header = f"r/{sub} — {len(posts)} recent posts mentioning {ticker.upper()}" + header += " (via RSS feed; scores/comments unavailable):" if via_rss else ":" + lines = [header] for p in posts: title = (p.get("title") or "").replace("\n", " ").strip() - score = p.get("score", 0) - comments = p.get("num_comments", 0) + score = p.get("score") + comments = p.get("num_comments") created = p.get("created_utc") created_str = ( time.strftime("%Y-%m-%d", time.gmtime(created)) if created else "?" ) + # Score / comment counts are absent on the RSS fallback path — + # show them only when present rather than printing fake zeros. + meta = created_str + if score is not None and comments is not None: + meta += f" · {score:>4}↑ · {comments:>3}c" selftext = (p.get("selftext") or "").replace("\n", " ").strip() if len(selftext) > 240: selftext = selftext[:240] + "…" lines.append( - f" [{created_str} · {score:>4}↑ · {comments:>3}c] {title}" + f" [{meta}] {title}" + (f"\n body excerpt: {selftext}" if selftext else "") ) blocks.append("\n".join(lines))