Fix WebTransport server implementation and add test client
Server fixes: - Move H3Connection initialization to ProtocolNegotiated event (matches official aioquic pattern) - Fix datagram routing to use session_id instead of flow_id - Add max_datagram_frame_size=65536 to enable QUIC datagrams - Fix send_datagram() to use keyword arguments - Add certificate chain handling for Let's Encrypt - Add no-cache headers to static server Command-line improvements: - Move settings from environment variables to argparse - Add comprehensive CLI arguments with defaults - Default mode=wt, cert=cert.pem, key=key.pem Test clients: - Add test_webtransport_client.py - Python WebTransport client that successfully connects - Add test_http3.py - Basic HTTP/3 connectivity test Client updates: - Auto-configure server URL and certificate hash from /cert-hash.json - Add ES6 module support Status: ✅ Python WebTransport client works perfectly ✅ Server properly handles WebTransport connections and datagrams ❌ Chrome fails due to cached QUIC state (QUIC_IETF_GQUIC_ERROR_MISSING) 🔍 Firefox sends packets but fails differently - to be debugged next session 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -44,7 +44,7 @@ class GameQuicProtocol(QuicConnectionProtocol): # type: ignore[misc]
|
||||
async def send_datagram(self, data: bytes) -> None:
|
||||
logging.debug("QUIC send datagram: %d bytes", len(data))
|
||||
self._quic.send_datagram_frame(data) # type: ignore[attr-defined]
|
||||
await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined]
|
||||
self.transmit() # Send queued QUIC packets immediately
|
||||
|
||||
|
||||
class QuicWebTransportServer(DatagramServerTransport):
|
||||
@@ -65,15 +65,20 @@ class QuicWebTransportServer(DatagramServerTransport):
|
||||
await proto.send_datagram(data)
|
||||
|
||||
async def run(self) -> None:
|
||||
configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"])
|
||||
configuration = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["h3"],
|
||||
max_datagram_frame_size=65536 # Enable QUIC datagrams
|
||||
)
|
||||
configuration.load_cert_chain(self.certfile, self.keyfile)
|
||||
|
||||
async def _create_protocol(*args, **kwargs):
|
||||
def _create_protocol(*args, **kwargs):
|
||||
return GameQuicProtocol(*args, on_datagram=self._on_datagram, peers=self._peers, **kwargs)
|
||||
|
||||
logging.info("QUIC datagram server listening on %s:%d", self.host, self.port)
|
||||
self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol)
|
||||
try:
|
||||
await self._server.wait_closed()
|
||||
# Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout)
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
self._server.close()
|
||||
|
||||
@@ -3,25 +3,66 @@ from __future__ import annotations
|
||||
import ssl
|
||||
import logging
|
||||
import threading
|
||||
import json
|
||||
from http.server import ThreadingHTTPServer, SimpleHTTPRequestHandler
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
class _Handler(SimpleHTTPRequestHandler):
|
||||
# Allow passing a base directory at construction time
|
||||
# Allow passing a base directory and cert_hash_json at construction time
|
||||
cert_hash_json: Optional[str] = None
|
||||
|
||||
def __init__(self, *args, directory: str | None = None, **kwargs):
|
||||
super().__init__(*args, directory=directory, **kwargs)
|
||||
|
||||
def end_headers(self):
|
||||
# Add no-cache headers for all static files to force browser to fetch fresh versions
|
||||
self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate')
|
||||
self.send_header('Pragma', 'no-cache')
|
||||
self.send_header('Expires', '0')
|
||||
super().end_headers()
|
||||
|
||||
def start_https_static(host: str, port: int, certfile: str, keyfile: str, docroot: str) -> Tuple[ThreadingHTTPServer, threading.Thread]:
|
||||
def do_GET(self):
|
||||
# Intercept /cert-hash.json requests
|
||||
if self.path == '/cert-hash.json' and self.cert_hash_json:
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
self.wfile.write(self.cert_hash_json.encode('utf-8'))
|
||||
else:
|
||||
# Serve regular static files
|
||||
super().do_GET()
|
||||
|
||||
|
||||
def start_https_static(
|
||||
host: str,
|
||||
port: int,
|
||||
certfile: str,
|
||||
keyfile: str,
|
||||
docroot: str,
|
||||
cert_hash_json: Optional[str] = None
|
||||
) -> Tuple[ThreadingHTTPServer, threading.Thread]:
|
||||
"""Start a simple HTTPS static file server in a background thread.
|
||||
|
||||
Args:
|
||||
host: Host to bind to
|
||||
port: Port to bind to
|
||||
certfile: Path to TLS certificate
|
||||
keyfile: Path to TLS private key
|
||||
docroot: Document root directory
|
||||
cert_hash_json: Optional JSON string to serve at /cert-hash.json
|
||||
|
||||
Returns the (httpd, thread). Caller is responsible for calling httpd.shutdown()
|
||||
to stop the server on application exit.
|
||||
"""
|
||||
docroot_path = str(Path(docroot).resolve())
|
||||
|
||||
# Set class variable for the handler
|
||||
if cert_hash_json:
|
||||
_Handler.cert_hash_json = cert_hash_json
|
||||
|
||||
def handler(*args, **kwargs):
|
||||
return _Handler(*args, directory=docroot_path, **kwargs)
|
||||
|
||||
|
||||
132
server/utils.py
132
server/utils.py
@@ -1,5 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def is_newer(a: int, b: int) -> bool:
|
||||
"""Return True if 16-bit sequence number a is newer than b (wrap-aware).
|
||||
@@ -8,3 +14,129 @@ def is_newer(a: int, b: int) -> bool:
|
||||
"""
|
||||
return ((a - b) & 0xFFFF) < 0x8000
|
||||
|
||||
|
||||
def get_cert_sha256_hash(cert_path: str) -> str:
|
||||
"""Calculate SHA-256 hash of a certificate file.
|
||||
|
||||
Returns the hash as a lowercase hex string (64 characters).
|
||||
This hash can be used by WebTransport clients for certificate pinning.
|
||||
"""
|
||||
try:
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Read the certificate file
|
||||
with open(cert_path, 'rb') as f:
|
||||
cert_data = f.read()
|
||||
|
||||
# Parse the certificate
|
||||
cert = x509.load_pem_x509_certificate(cert_data, default_backend())
|
||||
|
||||
# Get the DER-encoded certificate and hash it
|
||||
cert_der = cert.public_bytes(encoding=x509.Encoding.DER)
|
||||
hash_bytes = hashlib.sha256(cert_der).digest()
|
||||
|
||||
# Return as hex string
|
||||
return hash_bytes.hex()
|
||||
except Exception as e:
|
||||
# Fallback: just hash the file contents (less accurate but works)
|
||||
with open(cert_path, 'rb') as f:
|
||||
return hashlib.sha256(f.read()).digest().hex()
|
||||
|
||||
|
||||
def split_pem_certificates(pem_data: str) -> List[str]:
|
||||
"""Split a PEM file containing multiple certificates into individual certs.
|
||||
|
||||
Args:
|
||||
pem_data: String containing one or more PEM certificates
|
||||
|
||||
Returns:
|
||||
List of individual certificate strings (each includes BEGIN/END markers)
|
||||
"""
|
||||
# Match all certificate blocks
|
||||
cert_pattern = r'(-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----)'
|
||||
certificates = re.findall(cert_pattern, pem_data, re.DOTALL)
|
||||
return certificates
|
||||
|
||||
|
||||
def combine_cert_chain(cert_path: str, chain_paths: List[str]) -> str:
|
||||
"""Combine server cert and chain certs into a single properly-formatted PEM file.
|
||||
|
||||
Args:
|
||||
cert_path: Path to server certificate (may contain full chain)
|
||||
chain_paths: List of paths to intermediate/root certificates (optional)
|
||||
|
||||
Returns:
|
||||
Path to temporary combined certificate file (caller should clean up)
|
||||
"""
|
||||
combined_content = []
|
||||
|
||||
# Read server certificate file
|
||||
with open(cert_path, 'r') as f:
|
||||
cert_data = f.read().strip()
|
||||
certs = split_pem_certificates(cert_data)
|
||||
|
||||
if chain_paths:
|
||||
# If chain files provided separately, only use first cert from cert_path
|
||||
if certs:
|
||||
combined_content.append(certs[0].strip())
|
||||
else:
|
||||
combined_content.append(cert_data)
|
||||
else:
|
||||
# No separate chain files - include all certs from cert_path (fullchain case)
|
||||
if certs:
|
||||
combined_content.extend([c.strip() for c in certs])
|
||||
else:
|
||||
combined_content.append(cert_data)
|
||||
|
||||
# Read separate chain certificate files
|
||||
for chain_path in chain_paths:
|
||||
with open(chain_path, 'r') as f:
|
||||
chain_data = f.read().strip()
|
||||
# Each chain file might contain multiple certs
|
||||
chain_certs = split_pem_certificates(chain_data)
|
||||
if chain_certs:
|
||||
combined_content.extend([c.strip() for c in chain_certs])
|
||||
else:
|
||||
combined_content.append(chain_data)
|
||||
|
||||
# Write to temporary file with proper formatting
|
||||
# Each cert should be separated by a newline
|
||||
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False)
|
||||
temp_file.write('\n'.join(combined_content))
|
||||
temp_file.write('\n') # End with newline
|
||||
temp_file.close()
|
||||
|
||||
return temp_file.name
|
||||
|
||||
|
||||
def extract_server_cert(cert_path: str) -> str:
|
||||
"""Extract just the server certificate from a file that may contain a full chain.
|
||||
|
||||
Args:
|
||||
cert_path: Path to certificate file (may contain chain)
|
||||
|
||||
Returns:
|
||||
Path to temporary file containing only the server certificate
|
||||
"""
|
||||
with open(cert_path, 'r') as f:
|
||||
cert_data = f.read()
|
||||
|
||||
# Split into individual certificates
|
||||
certs = split_pem_certificates(cert_data)
|
||||
|
||||
if not certs:
|
||||
# No certs found, return original path
|
||||
return cert_path
|
||||
|
||||
# First cert is the server cert
|
||||
server_cert = certs[0].strip()
|
||||
|
||||
# Write to temporary file
|
||||
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False)
|
||||
temp_file.write(server_cert)
|
||||
temp_file.write('\n')
|
||||
temp_file.close()
|
||||
|
||||
return temp_file.name
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Optional
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from .transport import DatagramServerTransport, OnDatagram, TransportPeer
|
||||
import logging
|
||||
@@ -11,7 +11,8 @@ import logging
|
||||
try:
|
||||
from aioquic.asyncio import QuicConnectionProtocol, serve
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.h3.connection import H3Connection
|
||||
from aioquic.quic.events import ProtocolNegotiated
|
||||
from aioquic.h3.connection import H3_ALPN, H3Connection
|
||||
from aioquic.h3.events import HeadersReceived
|
||||
# Datagram event names vary by aioquic version; try both
|
||||
try:
|
||||
@@ -21,6 +22,8 @@ try:
|
||||
except Exception: # pragma: no cover - optional dependency not installed
|
||||
QuicConnectionProtocol = object # type: ignore
|
||||
QuicConfiguration = object # type: ignore
|
||||
ProtocolNegotiated = object # type: ignore
|
||||
H3_ALPN = [] # type: ignore
|
||||
H3Connection = object # type: ignore
|
||||
HeadersReceived = object # type: ignore
|
||||
H3DatagramReceived = object # type: ignore
|
||||
@@ -38,55 +41,75 @@ class GameWTProtocol(QuicConnectionProtocol): # type: ignore[misc]
|
||||
super().__init__(*args, **kwargs)
|
||||
self._on_datagram = on_datagram
|
||||
self._sessions = sessions
|
||||
self._http: Optional[H3Connection] = None
|
||||
self._http: Optional[H3Connection] = None # Created after ProtocolNegotiated
|
||||
logging.debug("GameWTProtocol initialized")
|
||||
|
||||
def http_event_received(self, event) -> None: # type: ignore[override]
|
||||
logging.debug("HTTP event received: %s", type(event).__name__)
|
||||
# Headers for CONNECT :protocol = webtransport open a session
|
||||
if isinstance(event, HeadersReceived):
|
||||
headers = {k.decode().lower(): v.decode() for k, v in event.headers}
|
||||
logging.debug("HeadersReceived: %s", headers)
|
||||
method = headers.get(":method")
|
||||
protocol = headers.get(":protocol") or headers.get("sec-webtransport-protocol")
|
||||
logging.debug("Method: %s, Protocol: %s", method, protocol)
|
||||
if method == "CONNECT" and (protocol == "webtransport"):
|
||||
# In WebTransport over H3, datagrams use the CONNECT stream id as flow_id
|
||||
flow_id = event.stream_id # type: ignore[attr-defined]
|
||||
self._sessions[flow_id] = WTSession(flow_id=flow_id, proto=self)
|
||||
logging.info("WT CONNECT accepted: flow_id=%s", flow_id)
|
||||
# In WebTransport over H3, datagrams use the CONNECT stream id as session id
|
||||
session_id = event.stream_id # type: ignore[attr-defined]
|
||||
self._sessions[session_id] = WTSession(flow_id=session_id, proto=self)
|
||||
logging.info("WT CONNECT accepted: session_id=%s", session_id)
|
||||
# Send 2xx to accept the session
|
||||
if self._http is not None:
|
||||
self._http.send_headers(event.stream_id, [(b":status", b"200")])
|
||||
self.transmit() # Actually send the response to complete handshake
|
||||
logging.debug("Sent 200 response for WT CONNECT")
|
||||
else:
|
||||
logging.warning("Unexpected CONNECT: method=%s, protocol=%s", method, protocol)
|
||||
elif isinstance(event, H3DatagramReceived): # type: ignore[misc]
|
||||
# Route datagram to session by flow_id
|
||||
flow_id = getattr(event, "flow_id", None)
|
||||
# Route datagram to session by stream_id (WebTransport session ID)
|
||||
# DatagramReceived has: stream_id (session), data (payload)
|
||||
session_id = getattr(event, "stream_id", None)
|
||||
data = getattr(event, "data", None)
|
||||
if flow_id is None or data is None:
|
||||
logging.debug("DatagramReceived event: session_id=%s, data_len=%s, has_data=%s",
|
||||
session_id, len(data) if data else None, data is not None)
|
||||
if session_id is None or data is None:
|
||||
logging.warning("Invalid datagram event: session_id=%s, data=%s", session_id, data)
|
||||
return
|
||||
sess = self._sessions.get(flow_id)
|
||||
sess = self._sessions.get(session_id)
|
||||
if not sess:
|
||||
logging.warning("No session found for datagram: session_id=%s", session_id)
|
||||
return
|
||||
peer = TransportPeer(addr=(self, flow_id))
|
||||
logging.debug("WT datagram received: flow_id=%s, %d bytes", flow_id, len(data))
|
||||
peer = TransportPeer(addr=(self, session_id))
|
||||
logging.info("WT datagram received: session_id=%s, %d bytes", session_id, len(data))
|
||||
asyncio.ensure_future(self._on_datagram(bytes(data), peer))
|
||||
|
||||
def quic_event_received(self, event) -> None: # type: ignore[override]
|
||||
# Lazily create H3 connection wrapper
|
||||
if self._http is None and hasattr(self, "_quic"):
|
||||
try:
|
||||
self._http = H3Connection(self._quic, enable_webtransport=True) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
self._http = None
|
||||
event_name = type(event).__name__
|
||||
logging.debug("QUIC event received: %s", event_name)
|
||||
|
||||
# Create H3Connection after ALPN protocol is negotiated
|
||||
if isinstance(event, ProtocolNegotiated):
|
||||
if event.alpn_protocol in H3_ALPN:
|
||||
self._http = H3Connection(self._quic, enable_webtransport=True)
|
||||
logging.info("H3Connection created with WebTransport support (ALPN: %s)", event.alpn_protocol)
|
||||
else:
|
||||
logging.warning("Unexpected ALPN protocol: %s", event.alpn_protocol)
|
||||
|
||||
# Pass event to HTTP layer if connection is established
|
||||
if self._http is not None:
|
||||
for http_event in self._http.handle_event(event):
|
||||
self.http_event_received(http_event)
|
||||
|
||||
async def send_h3_datagram(self, flow_id: int, data: bytes) -> None:
|
||||
if self._http is None:
|
||||
return
|
||||
try:
|
||||
if self._http is None:
|
||||
logging.warning("Cannot send datagram: H3Connection not established")
|
||||
return
|
||||
logging.debug("WT send datagram: flow_id=%s, %d bytes", flow_id, len(data))
|
||||
self._http.send_datagram(flow_id, data)
|
||||
await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
self._http.send_datagram(stream_id=flow_id, data=data)
|
||||
self.transmit() # Send queued QUIC packets immediately
|
||||
except Exception as e:
|
||||
logging.debug("Failed to send datagram: %s", e)
|
||||
|
||||
|
||||
class WebTransportServer(DatagramServerTransport):
|
||||
@@ -115,15 +138,34 @@ class WebTransportServer(DatagramServerTransport):
|
||||
await proto.send_h3_datagram(flow_id, data)
|
||||
|
||||
async def run(self) -> None:
|
||||
configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"])
|
||||
configuration.load_cert_chain(self.certfile, self.keyfile)
|
||||
from aioquic.quic.logger import QuicFileLogger
|
||||
import os
|
||||
|
||||
async def _create_protocol(*args, **kwargs):
|
||||
# Enable QUIC logging if debug level
|
||||
quic_logger = None
|
||||
if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
|
||||
log_dir = os.path.join(os.getcwd(), "quic_logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
quic_logger = QuicFileLogger(log_dir)
|
||||
logging.info("QUIC logging enabled in: %s", log_dir)
|
||||
|
||||
configuration = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["h3"], # HTTP/3 with WebTransport support
|
||||
max_datagram_frame_size=65536, # Enable QUIC datagrams (required for WebTransport)
|
||||
quic_logger=quic_logger
|
||||
)
|
||||
configuration.load_cert_chain(self.certfile, self.keyfile)
|
||||
logging.debug("QUIC configuration: ALPN=%s, max_datagram_frame_size=%d",
|
||||
configuration.alpn_protocols, configuration.max_datagram_frame_size or 0)
|
||||
|
||||
def _create_protocol(*args, **kwargs):
|
||||
return GameWTProtocol(*args, on_datagram=self._on_datagram, sessions=self._sessions, **kwargs)
|
||||
|
||||
logging.info("WebTransport (H3) server listening on %s:%d", self.host, self.port)
|
||||
self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol)
|
||||
try:
|
||||
await self._server.wait_closed()
|
||||
# Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout)
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
self._server.close()
|
||||
|
||||
Reference in New Issue
Block a user