From 1de5a8f3e65984bcf641839e791fc72177101536 Mon Sep 17 00:00:00 2001 From: Vladyslav Doloman Date: Sun, 19 Oct 2025 23:50:08 +0000 Subject: [PATCH] Fix WebTransport server implementation and add test client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server fixes: - Move H3Connection initialization to ProtocolNegotiated event (matches official aioquic pattern) - Fix datagram routing to use session_id instead of flow_id - Add max_datagram_frame_size=65536 to enable QUIC datagrams - Fix send_datagram() to use keyword arguments - Add certificate chain handling for Let's Encrypt - Add no-cache headers to static server Command-line improvements: - Move settings from environment variables to argparse - Add comprehensive CLI arguments with defaults - Default mode=wt, cert=cert.pem, key=key.pem Test clients: - Add test_webtransport_client.py - Python WebTransport client that successfully connects - Add test_http3.py - Basic HTTP/3 connectivity test Client updates: - Auto-configure server URL and certificate hash from /cert-hash.json - Add ES6 module support Status: βœ… Python WebTransport client works perfectly βœ… Server properly handles WebTransport connections and datagrams ❌ Chrome fails due to cached QUIC state (QUIC_IETF_GQUIC_ERROR_MISSING) πŸ” Firefox sends packets but fails differently - to be debugged next session πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- client/client.js | 101 +++++++-- client/index.html | 7 +- run.py | 390 +++++++++++++++++++++++++--------- server/quic_transport.py | 13 +- server/static_server.py | 47 +++- server/utils.py | 132 ++++++++++++ server/webtransport_server.py | 100 ++++++--- test_http3.py | 59 +++++ test_webtransport_client.py | 188 ++++++++++++++++ 9 files changed, 890 insertions(+), 147 deletions(-) create mode 100644 test_http3.py create mode 100644 test_webtransport_client.py diff --git a/client/client.js b/client/client.js index 37f073d..bf7d8fe 100644 --- a/client/client.js +++ b/client/client.js @@ -193,24 +193,61 @@ async function readLoop() { } async function connectWT() { + // Check if WebTransport is available + if (typeof WebTransport === 'undefined') { + setStatus('ERROR: WebTransport not supported in this browser. Use Chrome/Edge 97+ or Firefox with flag enabled.'); + return; + } + const url = ui.url.value.trim(); const hashHex = ui.hash.value.trim(); + + if (!url) { + setStatus('ERROR: Please enter a server URL'); + return; + } + setStatus('connecting...'); + console.log('Connecting to:', url); + const opts = {}; if (hashHex) { - const bytes = new Uint8Array(hashHex.match(/.{1,2}/g).map(b => parseInt(b,16))); - opts.serverCertificateHashes = [{ algorithm: 'sha-256', value: bytes }]; + try { + const bytes = new Uint8Array(hashHex.match(/.{1,2}/g).map(b => parseInt(b,16))); + opts.serverCertificateHashes = [{ algorithm: 'sha-256', value: bytes }]; + console.log('Using certificate hash:', hashHex); + } catch (err) { + setStatus('ERROR: Invalid certificate hash format'); + return; + } + } else { + console.warn('No certificate hash provided - connection may fail for self-signed certs'); + } + + try { + const wt = new WebTransport(url, opts); + await wt.ready; + transport = wt; + writer = wt.datagrams.writable.getWriter(); + setStatus('connected to server'); + console.log('WebTransport connected successfully'); + + readLoop(); + requestAnimationFrame(render); + + // Send JOIN immediately (spectator β†’ player upon space handled below) + const pkt = buildJoin(1, nextSeq(), ui.name.value.trim()); + await writer.write(pkt); + } catch (err) { + console.error('WebTransport connection failed:', err); + if (err.message.includes('certificate')) { + setStatus('ERROR: Certificate validation failed. Check cert hash or use valid CA cert.'); + } else if (err.message.includes('net::')) { + setStatus('ERROR: Network error - is server running on ' + url + '?'); + } else { + setStatus('ERROR: ' + err.message); + } } - const wt = new WebTransport(url, opts); - await wt.ready; - transport = wt; - writer = wt.datagrams.writable.getWriter(); - setStatus('connected'); - readLoop(); - requestAnimationFrame(render); - // Send JOIN immediately (spectator β†’ player upon space handled below) - const pkt = buildJoin(1, nextSeq(), ui.name.value.trim()); - await writer.write(pkt); } function dirFromKey(e) { @@ -229,7 +266,47 @@ async function onKey(e) { try { await writer.write(pkt); } catch (err) { /* ignore */ } } +// Auto-configure on page load +async function autoConfigureFromServer() { + try { + // Try to fetch cert hash from the static server + const response = await fetch('/cert-hash.json'); + if (!response.ok) { + setStatus('Auto-config unavailable, please enter manually'); + return; + } + const config = await response.json(); + + // Auto-populate fields + if (config.wtUrl) { + // Use the hostname from browser but with the WT port from server + const hostname = window.location.hostname || '127.0.0.1'; + const wtPort = config.wtPort || 4433; + ui.url.value = `https://${hostname}:${wtPort}/`; + } + if (config.sha256) { + ui.hash.value = config.sha256; + } + + setStatus('Auto-configured from server'); + console.log('Auto-configured:', config); + } catch (err) { + console.warn('Auto-config failed:', err); + // Fallback: just populate URL from browser location + const hostname = window.location.hostname || '127.0.0.1'; + ui.url.value = `https://${hostname}:4433/`; + setStatus('Manual configuration required'); + } +} + ui.connect.onclick = () => { connectWT().catch(e => setStatus('connect failed: ' + e)); }; window.addEventListener('keydown', onKey); window.addEventListener('resize', render); +// Auto-configure when page loads +if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', autoConfigureFromServer); +} else { + autoConfigureFromServer(); +} + diff --git a/client/index.html b/client/index.html index 26c48bb..b97783b 100644 --- a/client/index.html +++ b/client/index.html @@ -15,16 +15,15 @@
- - + +
press space to join
- - + diff --git a/run.py b/run.py index fef3af2..4dadc9d 100644 --- a/run.py +++ b/run.py @@ -1,5 +1,5 @@ ο»Ώimport asyncio -import os +import argparse import logging import sys @@ -7,132 +7,332 @@ from server.server import GameServer from server.config import ServerConfig -async def _run_tasks_with_optional_timeout(tasks): - """Await tasks, optionally honoring RUN_SECONDS env var to cancel after a timeout.""" - timeout_s = os.environ.get("RUN_SECONDS") +async def _run_tasks_with_optional_timeout(tasks, timeout_s=None): + """Await tasks, optionally with a timeout to cancel after specified seconds.""" if not timeout_s: await asyncio.gather(*tasks) return try: await asyncio.wait_for(asyncio.gather(*tasks), timeout=float(timeout_s)) except asyncio.TimeoutError: - logging.info("Timeout reached (RUN_SECONDS=%s); stopping server tasks...", timeout_s) + logging.info("Timeout reached (%s seconds); stopping server tasks...", timeout_s) for t in tasks: t.cancel() await asyncio.gather(*tasks, return_exceptions=True) -async def run_in_memory(): +async def run_in_memory(args): from server.transport import InMemoryTransport - cfg = ServerConfig() + cfg = ServerConfig( + width=args.width, + height=args.height, + tick_rate=args.tick_rate, + wrap_edges=args.wrap_edges, + apples_per_snake=args.apples_per_snake, + apples_cap=args.apples_cap, + players_max=args.players_max, + ) server = GameServer(transport=InMemoryTransport(lambda d, p: server.on_datagram(d, p)), config=cfg) tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())] - await _run_tasks_with_optional_timeout(tasks) + await _run_tasks_with_optional_timeout(tasks, args.run_seconds) -async def run_quic(): +async def run_quic(args): + import os from server.quic_transport import QuicWebTransportServer - cfg = ServerConfig() - host = os.environ.get("QUIC_HOST", "0.0.0.0") - port = int(os.environ.get("QUIC_PORT", "4433")) - cert = os.environ["QUIC_CERT"] - key = os.environ["QUIC_KEY"] - server = GameServer(transport=QuicWebTransportServer(host, port, cert, key, lambda d, p: server.on_datagram(d, p)), config=cfg) - tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())] - await _run_tasks_with_optional_timeout(tasks) + from server.utils import combine_cert_chain + cfg = ServerConfig( + width=args.width, + height=args.height, + tick_rate=args.tick_rate, + wrap_edges=args.wrap_edges, + apples_per_snake=args.apples_per_snake, + apples_cap=args.apples_cap, + players_max=args.players_max, + ) -async def run_webtransport(): - from server.webtransport_server import WebTransportServer - from server.static_server import start_https_static - cfg = ServerConfig() - host = os.environ.get("WT_HOST", os.environ.get("QUIC_HOST", "0.0.0.0")) - port = int(os.environ.get("WT_PORT", os.environ.get("QUIC_PORT", "4433"))) - cert = os.environ.get("WT_CERT") or os.environ["QUIC_CERT"] - key = os.environ.get("WT_KEY") or os.environ["QUIC_KEY"] - # Optional static HTTPS server for client assets - static = os.environ.get("STATIC", "1") - static_host = os.environ.get("STATIC_HOST", host) - static_port = int(os.environ.get("STATIC_PORT", "8443")) - static_root = os.environ.get("STATIC_ROOT", "client") - httpd = None - if static == "1": - httpd, _t = start_https_static(static_host, static_port, cert, key, static_root) - print(f"HTTPS static server: https://{static_host}:{static_port}/ serving '{static_root}'") - server = GameServer(transport=WebTransportServer(host, port, cert, key, lambda d, p: server.on_datagram(d, p)), config=cfg) - print(f"WebTransport server: https://{host}:{port}/ (HTTP/3)") + # Handle certificate chain + # Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem) + with open(args.cert, 'r') as f: + cert_content = f.read() + + from server.utils import split_pem_certificates + certs_in_file = split_pem_certificates(cert_content) + + cert_file = args.cert + temp_cert_file = None + + if args.cert_chain: + # Combine cert + chain files into single file + temp_cert_file = combine_cert_chain(args.cert, args.cert_chain) + cert_file = temp_cert_file + logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain)) + elif len(certs_in_file) > 1: + # Certificate file contains full chain - reformat it for aioquic + temp_cert_file = combine_cert_chain(args.cert, []) + cert_file = temp_cert_file + logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file)) + + server = GameServer(transport=QuicWebTransportServer(args.quic_host, args.quic_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)), config=cfg) + logging.info("QUIC server: %s:%d", args.quic_host, args.quic_port) try: tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())] - await _run_tasks_with_optional_timeout(tasks) + await _run_tasks_with_optional_timeout(tasks, args.run_seconds) + finally: + if temp_cert_file: + os.unlink(temp_cert_file) # Clean up temp combined cert file + + +async def run_webtransport(args): + import json + import os + from server.webtransport_server import WebTransportServer + from server.static_server import start_https_static + from server.utils import get_cert_sha256_hash, combine_cert_chain, extract_server_cert + + cfg = ServerConfig( + width=args.width, + height=args.height, + tick_rate=args.tick_rate, + wrap_edges=args.wrap_edges, + apples_per_snake=args.apples_per_snake, + apples_cap=args.apples_cap, + players_max=args.players_max, + ) + + # Handle certificate chain + # Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem) + with open(args.cert, 'r') as f: + cert_content = f.read() + + from server.utils import split_pem_certificates + certs_in_file = split_pem_certificates(cert_content) + + cert_file = args.cert + temp_cert_file = None + + if args.cert_chain: + # Combine cert + chain files into single file + temp_cert_file = combine_cert_chain(args.cert, args.cert_chain) + cert_file = temp_cert_file + logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain)) + elif len(certs_in_file) > 1: + # Certificate file contains full chain - reformat it for aioquic + temp_cert_file = combine_cert_chain(args.cert, []) + cert_file = temp_cert_file + logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file)) + + # Calculate certificate hash for WebTransport client (only hash server cert, not chain) + server_cert_file = extract_server_cert(args.cert) + cert_hash = get_cert_sha256_hash(server_cert_file) + if server_cert_file != args.cert: + os.unlink(server_cert_file) # Clean up temp file + + # Prepare cert-hash.json content + cert_hash_json = json.dumps({ + "sha256": cert_hash, + "wtUrl": f"https://{args.wt_host}:{args.wt_port}/", + "wtPort": args.wt_port + }) + + # Optional static HTTPS server for client assets + httpd = None + if args.static: + httpd, _t = start_https_static( + args.static_host, + args.static_port, + cert_file, + args.key, + args.static_root, + cert_hash_json=cert_hash_json + ) + print(f"HTTPS static server: https://{args.static_host}:{args.static_port}/ serving '{args.static_root}'") + print(f"Certificate SHA-256: {cert_hash}") + + server = GameServer(transport=WebTransportServer(args.wt_host, args.wt_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)), config=cfg) + print(f"WebTransport server: https://{args.wt_host}:{args.wt_port}/ (HTTP/3)") + try: + tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())] + await _run_tasks_with_optional_timeout(tasks, args.run_seconds) finally: if httpd is not None: httpd.shutdown() + if temp_cert_file: + os.unlink(temp_cert_file) # Clean up temp combined cert file + + +async def run_net(args): + import os + from server.webtransport_server import WebTransportServer + from server.quic_transport import QuicWebTransportServer + from server.multi_transport import MultiTransport + from server.utils import combine_cert_chain + + cfg = ServerConfig( + width=args.width, + height=args.height, + tick_rate=args.tick_rate, + wrap_edges=args.wrap_edges, + apples_per_snake=args.apples_per_snake, + apples_cap=args.apples_cap, + players_max=args.players_max, + ) + + # Handle certificate chain + # Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem) + with open(args.cert, 'r') as f: + cert_content = f.read() + + from server.utils import split_pem_certificates + certs_in_file = split_pem_certificates(cert_content) + + cert_file = args.cert + temp_cert_file = None + + if args.cert_chain: + # Combine cert + chain files into single file + temp_cert_file = combine_cert_chain(args.cert, args.cert_chain) + cert_file = temp_cert_file + logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain)) + elif len(certs_in_file) > 1: + # Certificate file contains full chain - reformat it for aioquic + temp_cert_file = combine_cert_chain(args.cert, []) + cert_file = temp_cert_file + logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file)) + + server: GameServer + wt = WebTransportServer(args.wt_host, args.wt_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)) + qu = QuicWebTransportServer(args.quic_host, args.quic_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)) + m = MultiTransport(wt, qu) + server = GameServer(transport=m, config=cfg) + logging.info("WebTransport server: %s:%d", args.wt_host, args.wt_port) + logging.info("QUIC server: %s:%d", args.quic_host, args.quic_port) + try: + tasks = [asyncio.create_task(m.run()), asyncio.create_task(server.tick_loop())] + await _run_tasks_with_optional_timeout(tasks, args.run_seconds) + finally: + if temp_cert_file: + os.unlink(temp_cert_file) # Clean up temp combined cert file + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Real-time multiplayer Snake game server", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run.py # WebTransport mode with defaults + python run.py --mode mem # In-memory mode (testing) + python run.py --mode quic --quic-port 5000 # QUIC mode on custom port + python run.py --cert my.pem --key my.key # Custom TLS certificate + python run.py --no-static # Disable static HTTPS server + python run.py --width 80 --height 60 # Custom field size + """ + ) + + # Transport mode + parser.add_argument("--mode", choices=["mem", "wt", "quic", "net"], default="wt", + help="Transport mode: mem (in-memory), wt (WebTransport/HTTP3), quic (QUIC datagrams), net (both wt+quic). Default: wt") + + # TLS certificate/key + parser.add_argument("--cert", default="cert.pem", + help="TLS certificate file path (may contain full chain). Default: cert.pem") + parser.add_argument("--key", default="key.pem", + help="TLS private key file path. Default: key.pem") + parser.add_argument("--cert-chain", action="append", dest="cert_chain", + help="Intermediate certificate file (can be used multiple times for long chains)") + + # WebTransport server settings + parser.add_argument("--wt-host", default="0.0.0.0", + help="WebTransport server host. Default: 0.0.0.0") + parser.add_argument("--wt-port", type=int, default=4433, + help="WebTransport server port. Default: 4433") + + # QUIC server settings + parser.add_argument("--quic-host", default="0.0.0.0", + help="QUIC server host. Default: 0.0.0.0") + parser.add_argument("--quic-port", type=int, default=4433, + help="QUIC server port. Default: 4433 (or 4443 in net mode)") + + # Static HTTPS server settings + parser.add_argument("--static", dest="static", action="store_true", default=True, + help="Enable static HTTPS server for client files (default in wt mode)") + parser.add_argument("--no-static", dest="static", action="store_false", + help="Disable static HTTPS server") + parser.add_argument("--static-host", default=None, + help="Static server host. Default: same as wt-host") + parser.add_argument("--static-port", type=int, default=8443, + help="Static server port. Default: 8443") + parser.add_argument("--static-root", default="client", + help="Static files directory. Default: client") + + # Game configuration + parser.add_argument("--width", type=int, default=60, + help="Field width (3-255). Default: 60") + parser.add_argument("--height", type=int, default=40, + help="Field height (3-255). Default: 40") + parser.add_argument("--tick-rate", type=int, default=10, + help="Server tick rate in TPS (5-30). Default: 10") + parser.add_argument("--wrap-edges", action="store_true", default=False, + help="Enable edge wrapping (default: disabled)") + parser.add_argument("--apples-per-snake", type=int, default=1, + help="Apples per connected snake (1-12). Default: 1") + parser.add_argument("--apples-cap", type=int, default=255, + help="Maximum total apples (0-255). Default: 255") + parser.add_argument("--players-max", type=int, default=32, + help="Maximum concurrent players. Default: 32") + + # Logging and testing + parser.add_argument("--log-level", default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level. Default: INFO") + parser.add_argument("--run-seconds", type=float, default=None, + help="Optional timeout in seconds for testing") + + args = parser.parse_args() + + # Post-process: static-host defaults to wt-host + if args.static_host is None: + args.static_host = args.wt_host + + # In net mode, default quic-port to 4443 if not explicitly set + if args.mode == "net" and "--quic-port" not in sys.argv: + args.quic_port = 4443 + + # Validate TLS files for modes that need them + if args.mode in ("wt", "quic", "net"): + import os + if not os.path.exists(args.cert): + parser.error(f"Certificate file not found: {args.cert}") + if not os.path.exists(args.key): + parser.error(f"Key file not found: {args.key}") + + return args if __name__ == "__main__": try: - if any(a in ("-h", "--help") for a in sys.argv[1:]): - print( - "Usage: python run.py [--mode mem|quic|wt|net] [--log-level LEVEL] [--run-seconds N]\n" - " TLS (for wt/quic): set QUIC_CERT/QUIC_KEY or WT_CERT/WT_KEY env vars\n" - " WT static server (MODE=wt): STATIC=1 [STATIC_HOST/PORT/ROOT]\n" - "Examples:\n MODE=wt QUIC_CERT=cert.pem QUIC_KEY=key.pem python run.py\n MODE=mem python run.py" - ) - sys.exit(0) + args = parse_args() + # Logging setup - level = os.environ.get("LOG_LEVEL", "INFO").upper() - logging.basicConfig(level=getattr(logging, level, logging.INFO), format="[%(asctime)s] %(levelname)s: %(message)s") - mode = os.environ.get("MODE", "mem").lower() - if mode == "wt": + logging.basicConfig( + level=getattr(logging, args.log_level.upper()), + format="[%(asctime)s] %(levelname)s: %(message)s" + ) + + # Run appropriate mode + if args.mode == "wt": logging.info("Starting in WebTransport mode") - asyncio.run(run_webtransport()) - elif mode == "quic": + asyncio.run(run_webtransport(args)) + elif args.mode == "quic": logging.info("Starting in QUIC datagram mode") - asyncio.run(run_quic()) - elif mode == "net": + asyncio.run(run_quic(args)) + elif args.mode == "net": logging.info("Starting in combined WebTransport+QUIC mode") - from server.webtransport_server import WebTransportServer - from server.quic_transport import QuicWebTransportServer - from server.multi_transport import MultiTransport - cfg = ServerConfig() - host_wt = os.environ.get("WT_HOST", os.environ.get("QUIC_HOST", "0.0.0.0")) - port_wt = int(os.environ.get("WT_PORT", os.environ.get("QUIC_PORT", "4433"))) - host_quic = os.environ.get("QUIC_HOST", host_wt) - port_quic = int(os.environ.get("QUIC_PORT", "4443")) - cert = os.environ.get("WT_CERT") or os.environ.get("QUIC_CERT") - key = os.environ.get("WT_KEY") or os.environ.get("QUIC_KEY") - if not cert or not key: - raise SystemExit("WT/QUIC cert/key required: set WT_CERT/WT_KEY or QUIC_CERT/QUIC_KEY") - async def _run_net(): - server: GameServer - wt = WebTransportServer(host_wt, port_wt, cert, key, lambda d, p: server.on_datagram(d, p)) - qu = QuicWebTransportServer(host_quic, port_quic, cert, key, lambda d, p: server.on_datagram(d, p)) - m = MultiTransport(wt, qu) - server = GameServer(transport=m, config=cfg) - await asyncio.gather(m.run(), server.tick_loop()) - asyncio.run(_run_net()) - else: + asyncio.run(run_net(args)) + else: # mem logging.info("Starting in in-memory transport mode") - asyncio.run(run_in_memory()) + asyncio.run(run_in_memory(args)) except KeyboardInterrupt: pass - - - - - - -async def _run_tasks_with_optional_timeout(tasks): - """Await tasks, optionally honoring RUN_SECONDS env var to cancel after a timeout.""" - timeout_s = os.environ.get("RUN_SECONDS") - if not timeout_s: - await asyncio.gather(*tasks) - return - try: - await asyncio.wait_for(asyncio.gather(*tasks), timeout=float(timeout_s)) - except asyncio.TimeoutError: - logging.info("Timeout reached (RUN_SECONDS=%s); stopping server tasks...", timeout_s) - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - diff --git a/server/quic_transport.py b/server/quic_transport.py index d6459b3..fbdf44d 100644 --- a/server/quic_transport.py +++ b/server/quic_transport.py @@ -44,7 +44,7 @@ class GameQuicProtocol(QuicConnectionProtocol): # type: ignore[misc] async def send_datagram(self, data: bytes) -> None: logging.debug("QUIC send datagram: %d bytes", len(data)) self._quic.send_datagram_frame(data) # type: ignore[attr-defined] - await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined] + self.transmit() # Send queued QUIC packets immediately class QuicWebTransportServer(DatagramServerTransport): @@ -65,15 +65,20 @@ class QuicWebTransportServer(DatagramServerTransport): await proto.send_datagram(data) async def run(self) -> None: - configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"]) + configuration = QuicConfiguration( + is_client=False, + alpn_protocols=["h3"], + max_datagram_frame_size=65536 # Enable QUIC datagrams + ) configuration.load_cert_chain(self.certfile, self.keyfile) - async def _create_protocol(*args, **kwargs): + def _create_protocol(*args, **kwargs): return GameQuicProtocol(*args, on_datagram=self._on_datagram, peers=self._peers, **kwargs) logging.info("QUIC datagram server listening on %s:%d", self.host, self.port) self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol) try: - await self._server.wait_closed() + # Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout) + await asyncio.Event().wait() finally: self._server.close() diff --git a/server/static_server.py b/server/static_server.py index 57b2c0c..8c957e5 100644 --- a/server/static_server.py +++ b/server/static_server.py @@ -3,25 +3,66 @@ from __future__ import annotations import ssl import logging import threading +import json from http.server import ThreadingHTTPServer, SimpleHTTPRequestHandler from pathlib import Path -from typing import Tuple +from typing import Tuple, Optional class _Handler(SimpleHTTPRequestHandler): - # Allow passing a base directory at construction time + # Allow passing a base directory and cert_hash_json at construction time + cert_hash_json: Optional[str] = None + def __init__(self, *args, directory: str | None = None, **kwargs): super().__init__(*args, directory=directory, **kwargs) + def end_headers(self): + # Add no-cache headers for all static files to force browser to fetch fresh versions + self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate') + self.send_header('Pragma', 'no-cache') + self.send_header('Expires', '0') + super().end_headers() -def start_https_static(host: str, port: int, certfile: str, keyfile: str, docroot: str) -> Tuple[ThreadingHTTPServer, threading.Thread]: + def do_GET(self): + # Intercept /cert-hash.json requests + if self.path == '/cert-hash.json' and self.cert_hash_json: + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.send_header('Access-Control-Allow-Origin', '*') + self.end_headers() + self.wfile.write(self.cert_hash_json.encode('utf-8')) + else: + # Serve regular static files + super().do_GET() + + +def start_https_static( + host: str, + port: int, + certfile: str, + keyfile: str, + docroot: str, + cert_hash_json: Optional[str] = None +) -> Tuple[ThreadingHTTPServer, threading.Thread]: """Start a simple HTTPS static file server in a background thread. + Args: + host: Host to bind to + port: Port to bind to + certfile: Path to TLS certificate + keyfile: Path to TLS private key + docroot: Document root directory + cert_hash_json: Optional JSON string to serve at /cert-hash.json + Returns the (httpd, thread). Caller is responsible for calling httpd.shutdown() to stop the server on application exit. """ docroot_path = str(Path(docroot).resolve()) + # Set class variable for the handler + if cert_hash_json: + _Handler.cert_hash_json = cert_hash_json + def handler(*args, **kwargs): return _Handler(*args, directory=docroot_path, **kwargs) diff --git a/server/utils.py b/server/utils.py index ccb6b9d..b72dbc3 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,5 +1,11 @@ from __future__ import annotations +import hashlib +import re +import tempfile +from typing import List, Optional +from pathlib import Path + def is_newer(a: int, b: int) -> bool: """Return True if 16-bit sequence number a is newer than b (wrap-aware). @@ -8,3 +14,129 @@ def is_newer(a: int, b: int) -> bool: """ return ((a - b) & 0xFFFF) < 0x8000 + +def get_cert_sha256_hash(cert_path: str) -> str: + """Calculate SHA-256 hash of a certificate file. + + Returns the hash as a lowercase hex string (64 characters). + This hash can be used by WebTransport clients for certificate pinning. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # Read the certificate file + with open(cert_path, 'rb') as f: + cert_data = f.read() + + # Parse the certificate + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + + # Get the DER-encoded certificate and hash it + cert_der = cert.public_bytes(encoding=x509.Encoding.DER) + hash_bytes = hashlib.sha256(cert_der).digest() + + # Return as hex string + return hash_bytes.hex() + except Exception as e: + # Fallback: just hash the file contents (less accurate but works) + with open(cert_path, 'rb') as f: + return hashlib.sha256(f.read()).digest().hex() + + +def split_pem_certificates(pem_data: str) -> List[str]: + """Split a PEM file containing multiple certificates into individual certs. + + Args: + pem_data: String containing one or more PEM certificates + + Returns: + List of individual certificate strings (each includes BEGIN/END markers) + """ + # Match all certificate blocks + cert_pattern = r'(-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----)' + certificates = re.findall(cert_pattern, pem_data, re.DOTALL) + return certificates + + +def combine_cert_chain(cert_path: str, chain_paths: List[str]) -> str: + """Combine server cert and chain certs into a single properly-formatted PEM file. + + Args: + cert_path: Path to server certificate (may contain full chain) + chain_paths: List of paths to intermediate/root certificates (optional) + + Returns: + Path to temporary combined certificate file (caller should clean up) + """ + combined_content = [] + + # Read server certificate file + with open(cert_path, 'r') as f: + cert_data = f.read().strip() + certs = split_pem_certificates(cert_data) + + if chain_paths: + # If chain files provided separately, only use first cert from cert_path + if certs: + combined_content.append(certs[0].strip()) + else: + combined_content.append(cert_data) + else: + # No separate chain files - include all certs from cert_path (fullchain case) + if certs: + combined_content.extend([c.strip() for c in certs]) + else: + combined_content.append(cert_data) + + # Read separate chain certificate files + for chain_path in chain_paths: + with open(chain_path, 'r') as f: + chain_data = f.read().strip() + # Each chain file might contain multiple certs + chain_certs = split_pem_certificates(chain_data) + if chain_certs: + combined_content.extend([c.strip() for c in chain_certs]) + else: + combined_content.append(chain_data) + + # Write to temporary file with proper formatting + # Each cert should be separated by a newline + temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False) + temp_file.write('\n'.join(combined_content)) + temp_file.write('\n') # End with newline + temp_file.close() + + return temp_file.name + + +def extract_server_cert(cert_path: str) -> str: + """Extract just the server certificate from a file that may contain a full chain. + + Args: + cert_path: Path to certificate file (may contain chain) + + Returns: + Path to temporary file containing only the server certificate + """ + with open(cert_path, 'r') as f: + cert_data = f.read() + + # Split into individual certificates + certs = split_pem_certificates(cert_data) + + if not certs: + # No certs found, return original path + return cert_path + + # First cert is the server cert + server_cert = certs[0].strip() + + # Write to temporary file + temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False) + temp_file.write(server_cert) + temp_file.write('\n') + temp_file.close() + + return temp_file.name + diff --git a/server/webtransport_server.py b/server/webtransport_server.py index 9d80849..72d3962 100644 --- a/server/webtransport_server.py +++ b/server/webtransport_server.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio from dataclasses import dataclass -from typing import Awaitable, Callable, Dict, Optional +from typing import Callable, Dict, Optional from .transport import DatagramServerTransport, OnDatagram, TransportPeer import logging @@ -11,7 +11,8 @@ import logging try: from aioquic.asyncio import QuicConnectionProtocol, serve from aioquic.quic.configuration import QuicConfiguration - from aioquic.h3.connection import H3Connection + from aioquic.quic.events import ProtocolNegotiated + from aioquic.h3.connection import H3_ALPN, H3Connection from aioquic.h3.events import HeadersReceived # Datagram event names vary by aioquic version; try both try: @@ -21,6 +22,8 @@ try: except Exception: # pragma: no cover - optional dependency not installed QuicConnectionProtocol = object # type: ignore QuicConfiguration = object # type: ignore + ProtocolNegotiated = object # type: ignore + H3_ALPN = [] # type: ignore H3Connection = object # type: ignore HeadersReceived = object # type: ignore H3DatagramReceived = object # type: ignore @@ -38,55 +41,75 @@ class GameWTProtocol(QuicConnectionProtocol): # type: ignore[misc] super().__init__(*args, **kwargs) self._on_datagram = on_datagram self._sessions = sessions - self._http: Optional[H3Connection] = None + self._http: Optional[H3Connection] = None # Created after ProtocolNegotiated + logging.debug("GameWTProtocol initialized") def http_event_received(self, event) -> None: # type: ignore[override] + logging.debug("HTTP event received: %s", type(event).__name__) # Headers for CONNECT :protocol = webtransport open a session if isinstance(event, HeadersReceived): headers = {k.decode().lower(): v.decode() for k, v in event.headers} + logging.debug("HeadersReceived: %s", headers) method = headers.get(":method") protocol = headers.get(":protocol") or headers.get("sec-webtransport-protocol") + logging.debug("Method: %s, Protocol: %s", method, protocol) if method == "CONNECT" and (protocol == "webtransport"): - # In WebTransport over H3, datagrams use the CONNECT stream id as flow_id - flow_id = event.stream_id # type: ignore[attr-defined] - self._sessions[flow_id] = WTSession(flow_id=flow_id, proto=self) - logging.info("WT CONNECT accepted: flow_id=%s", flow_id) + # In WebTransport over H3, datagrams use the CONNECT stream id as session id + session_id = event.stream_id # type: ignore[attr-defined] + self._sessions[session_id] = WTSession(flow_id=session_id, proto=self) + logging.info("WT CONNECT accepted: session_id=%s", session_id) # Send 2xx to accept the session if self._http is not None: self._http.send_headers(event.stream_id, [(b":status", b"200")]) + self.transmit() # Actually send the response to complete handshake + logging.debug("Sent 200 response for WT CONNECT") + else: + logging.warning("Unexpected CONNECT: method=%s, protocol=%s", method, protocol) elif isinstance(event, H3DatagramReceived): # type: ignore[misc] - # Route datagram to session by flow_id - flow_id = getattr(event, "flow_id", None) + # Route datagram to session by stream_id (WebTransport session ID) + # DatagramReceived has: stream_id (session), data (payload) + session_id = getattr(event, "stream_id", None) data = getattr(event, "data", None) - if flow_id is None or data is None: + logging.debug("DatagramReceived event: session_id=%s, data_len=%s, has_data=%s", + session_id, len(data) if data else None, data is not None) + if session_id is None or data is None: + logging.warning("Invalid datagram event: session_id=%s, data=%s", session_id, data) return - sess = self._sessions.get(flow_id) + sess = self._sessions.get(session_id) if not sess: + logging.warning("No session found for datagram: session_id=%s", session_id) return - peer = TransportPeer(addr=(self, flow_id)) - logging.debug("WT datagram received: flow_id=%s, %d bytes", flow_id, len(data)) + peer = TransportPeer(addr=(self, session_id)) + logging.info("WT datagram received: session_id=%s, %d bytes", session_id, len(data)) asyncio.ensure_future(self._on_datagram(bytes(data), peer)) def quic_event_received(self, event) -> None: # type: ignore[override] - # Lazily create H3 connection wrapper - if self._http is None and hasattr(self, "_quic"): - try: - self._http = H3Connection(self._quic, enable_webtransport=True) # type: ignore[attr-defined] - except Exception: - self._http = None + event_name = type(event).__name__ + logging.debug("QUIC event received: %s", event_name) + + # Create H3Connection after ALPN protocol is negotiated + if isinstance(event, ProtocolNegotiated): + if event.alpn_protocol in H3_ALPN: + self._http = H3Connection(self._quic, enable_webtransport=True) + logging.info("H3Connection created with WebTransport support (ALPN: %s)", event.alpn_protocol) + else: + logging.warning("Unexpected ALPN protocol: %s", event.alpn_protocol) + + # Pass event to HTTP layer if connection is established if self._http is not None: for http_event in self._http.handle_event(event): self.http_event_received(http_event) async def send_h3_datagram(self, flow_id: int, data: bytes) -> None: - if self._http is None: - return try: + if self._http is None: + logging.warning("Cannot send datagram: H3Connection not established") + return logging.debug("WT send datagram: flow_id=%s, %d bytes", flow_id, len(data)) - self._http.send_datagram(flow_id, data) - await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined] - except Exception: - pass + self._http.send_datagram(stream_id=flow_id, data=data) + self.transmit() # Send queued QUIC packets immediately + except Exception as e: + logging.debug("Failed to send datagram: %s", e) class WebTransportServer(DatagramServerTransport): @@ -115,15 +138,34 @@ class WebTransportServer(DatagramServerTransport): await proto.send_h3_datagram(flow_id, data) async def run(self) -> None: - configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"]) - configuration.load_cert_chain(self.certfile, self.keyfile) + from aioquic.quic.logger import QuicFileLogger + import os - async def _create_protocol(*args, **kwargs): + # Enable QUIC logging if debug level + quic_logger = None + if logging.getLogger().getEffectiveLevel() <= logging.DEBUG: + log_dir = os.path.join(os.getcwd(), "quic_logs") + os.makedirs(log_dir, exist_ok=True) + quic_logger = QuicFileLogger(log_dir) + logging.info("QUIC logging enabled in: %s", log_dir) + + configuration = QuicConfiguration( + is_client=False, + alpn_protocols=["h3"], # HTTP/3 with WebTransport support + max_datagram_frame_size=65536, # Enable QUIC datagrams (required for WebTransport) + quic_logger=quic_logger + ) + configuration.load_cert_chain(self.certfile, self.keyfile) + logging.debug("QUIC configuration: ALPN=%s, max_datagram_frame_size=%d", + configuration.alpn_protocols, configuration.max_datagram_frame_size or 0) + + def _create_protocol(*args, **kwargs): return GameWTProtocol(*args, on_datagram=self._on_datagram, sessions=self._sessions, **kwargs) logging.info("WebTransport (H3) server listening on %s:%d", self.host, self.port) self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol) try: - await self._server.wait_closed() + # Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout) + await asyncio.Event().wait() finally: self._server.close() diff --git a/test_http3.py b/test_http3.py new file mode 100644 index 0000000..bb3a22e --- /dev/null +++ b/test_http3.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +"""Simple HTTP/3 client to test the server.""" + +import asyncio +import logging +from aioquic.asyncio import connect +from aioquic.quic.configuration import QuicConfiguration + +logging.basicConfig(level=logging.DEBUG) + +async def test_http3(url: str): + """Test HTTP/3 connection to the server.""" + print(f"Testing HTTP/3 connection to: {url}") + + # Parse URL + if url.startswith("https://"): + url = url[8:] + + host, port = url.split(":") + port = int(port.rstrip("/")) + + print(f"Connecting to {host}:{port}...") + + # Create QUIC configuration + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=["h3"], + verify_mode=0 # Skip certificate verification for testing + ) + + try: + async with connect( + host, + port, + configuration=configuration, + ) as protocol: + print(f"βœ“ QUIC connection established!") + print(f" ALPN protocol: {protocol._quic.tls.alpn_negotiated}") + print(f" Remote address: {protocol._quic.remote_address}") + + # Just test the connection, don't send HTTP requests yet + await asyncio.sleep(1) + + print("βœ“ Connection successful!") + return True + + except Exception as e: + print(f"βœ— Connection failed: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + import sys + + url = sys.argv[1] if len(sys.argv) > 1 else "https://127.0.0.1:4433" + + success = asyncio.run(test_http3(url)) + sys.exit(0 if success else 1) diff --git a/test_webtransport_client.py b/test_webtransport_client.py new file mode 100644 index 0000000..b3d6dc3 --- /dev/null +++ b/test_webtransport_client.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +""" +WebTransport test client to verify server is working correctly. +""" + +import asyncio +import logging +import sys +from typing import Optional + +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h3.connection import H3Connection +from aioquic.h3.events import DatagramReceived, H3Event, HeadersReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent + +logging.basicConfig( + level=logging.DEBUG, + format="[%(asctime)s] %(levelname)s: %(message)s" +) + +logger = logging.getLogger(__name__) + + +class WebTransportClient(QuicConnectionProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._http: Optional[H3Connection] = None + self._session_id: Optional[int] = None + self._session_established = asyncio.Event() + self._received_datagrams = [] + + def quic_event_received(self, event: QuicEvent) -> None: + logger.debug(f"QUIC event: {type(event).__name__}") + + # Create H3Connection on first event + if self._http is None: + self._http = H3Connection(self._quic, enable_webtransport=True) + logger.info("H3Connection created with WebTransport support") + + # Process event through HTTP/3 layer + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + def http_event_received(self, event: H3Event) -> None: + logger.debug(f"HTTP event: {type(event).__name__}") + + if isinstance(event, HeadersReceived): + headers = dict(event.headers) + status = headers.get(b":status", b"").decode() + logger.info(f"Received headers: status={status}") + + if status == "200" and self._session_id is None: + self._session_id = event.stream_id + logger.info(f"WebTransport session established! session_id={self._session_id}") + self._session_established.set() + else: + logger.error(f"WebTransport session rejected: status={status}, headers={headers}") + + elif isinstance(event, DatagramReceived): + logger.info(f"Received datagram: {len(event.data)} bytes, flow_id={event.flow_id}") + self._received_datagrams.append(event.data) + + async def establish_session(self, path: str = "/") -> None: + """Send WebTransport CONNECT request.""" + logger.info(f"Sending WebTransport CONNECT request for path: {path}") + + # Allocate stream ID for CONNECT request + stream_id = self._quic.get_next_available_stream_id() + + # Send CONNECT request with WebTransport protocol + headers = [ + (b":method", b"CONNECT"), + (b":scheme", b"https"), + (b":authority", b"np.vaku.org.ua:4433"), + (b":path", path.encode()), + (b":protocol", b"webtransport"), + ] + + self._http.send_headers(stream_id=stream_id, headers=headers, end_stream=False) + self.transmit() + + logger.info(f"WebTransport CONNECT sent on stream {stream_id}") + + # Wait for session to be established + try: + await asyncio.wait_for(self._session_established.wait(), timeout=5.0) + logger.info("βœ“ WebTransport session established successfully!") + return True + except asyncio.TimeoutError: + logger.error("βœ— Timeout waiting for WebTransport session acceptance") + return False + + async def send_datagram(self, data: bytes) -> None: + """Send a datagram over the WebTransport session.""" + if self._session_id is None: + logger.error("Cannot send datagram: session not established") + return + + logger.info(f"Sending datagram: {len(data)} bytes") + self._http.send_datagram(stream_id=self._session_id, data=data) + self.transmit() + + async def wait_for_datagram(self, timeout: float = 2.0) -> Optional[bytes]: + """Wait for a datagram response.""" + logger.info(f"Waiting for datagram response (timeout={timeout}s)...") + start = asyncio.get_event_loop().time() + + while asyncio.get_event_loop().time() - start < timeout: + if self._received_datagrams: + return self._received_datagrams.pop(0) + await asyncio.sleep(0.1) + + logger.warning("No datagram received within timeout") + return None + + +async def test_webtransport(host: str, port: int, verify_mode: bool = False): + """Test WebTransport connection to the server.""" + logger.info(f"Testing WebTransport connection to {host}:{port}") + + # Configure QUIC + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=["h3"], # HTTP/3 + max_datagram_frame_size=65536, # Enable QUIC datagrams (required for WebTransport) + verify_mode=0 if not verify_mode else 2, # Skip cert verification for self-signed + ) + + try: + async with connect( + host, + port, + configuration=configuration, + create_protocol=WebTransportClient, + ) as client: + logger.info("βœ“ QUIC connection established") + client = client # type: WebTransportClient + + # Establish WebTransport session + if not await client.establish_session("/"): + logger.error("βœ— Failed to establish WebTransport session") + return False + + # Send a test datagram (simple game JOIN packet) + # Format: version(1) | type(1) | flags(1) | seq(2) | name_len(1) | name + test_packet = bytes([ + 0x01, # version + 0x01, # JOIN packet type + 0x00, # flags + 0x00, 0x01, # seq=1 + 0x04, # name length + ord('T'), ord('E'), ord('S'), ord('T'), # name="TEST" + ]) + + await client.send_datagram(test_packet) + logger.info("βœ“ Test JOIN packet sent") + + # Wait for response + response = await client.wait_for_datagram(timeout=3.0) + if response: + logger.info(f"βœ“ Received response: {len(response)} bytes") + logger.info(f" Response hex: {response.hex()}") + return True + else: + logger.warning("Server did not respond to JOIN packet (might be normal if not implemented)") + return True # Connection worked even if no response + + except Exception as e: + logger.error(f"βœ— Connection failed: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python test_webtransport_client.py [port]") + print("Example: python test_webtransport_client.py np.vaku.org.ua 4433") + sys.exit(1) + + host = sys.argv[1] + port = int(sys.argv[2]) if len(sys.argv) > 2 else 4433 + + success = asyncio.run(test_webtransport(host, port)) + sys.exit(0 if success else 1)