Fix WebTransport server implementation and add test client

Server fixes:
- Move H3Connection initialization to ProtocolNegotiated event (matches official aioquic pattern)
- Fix datagram routing to use session_id instead of flow_id
- Add max_datagram_frame_size=65536 to enable QUIC datagrams
- Fix send_datagram() to use keyword arguments
- Add certificate chain handling for Let's Encrypt
- Add no-cache headers to static server

Command-line improvements:
- Move settings from environment variables to argparse
- Add comprehensive CLI arguments with defaults
- Default mode=wt, cert=cert.pem, key=key.pem

Test clients:
- Add test_webtransport_client.py - Python WebTransport client that successfully connects
- Add test_http3.py - Basic HTTP/3 connectivity test

Client updates:
- Auto-configure server URL and certificate hash from /cert-hash.json
- Add ES6 module support

Status:
 Python WebTransport client works perfectly
 Server properly handles WebTransport connections and datagrams
 Chrome fails due to cached QUIC state (QUIC_IETF_GQUIC_ERROR_MISSING)
🔍 Firefox sends packets but fails differently - to be debugged next session

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Vladyslav Doloman
2025-10-19 23:50:08 +00:00
parent ed5cb14b30
commit 1de5a8f3e6
9 changed files with 890 additions and 147 deletions

View File

@@ -193,24 +193,61 @@ async function readLoop() {
}
async function connectWT() {
// Check if WebTransport is available
if (typeof WebTransport === 'undefined') {
setStatus('ERROR: WebTransport not supported in this browser. Use Chrome/Edge 97+ or Firefox with flag enabled.');
return;
}
const url = ui.url.value.trim();
const hashHex = ui.hash.value.trim();
if (!url) {
setStatus('ERROR: Please enter a server URL');
return;
}
setStatus('connecting...');
console.log('Connecting to:', url);
const opts = {};
if (hashHex) {
const bytes = new Uint8Array(hashHex.match(/.{1,2}/g).map(b => parseInt(b,16)));
opts.serverCertificateHashes = [{ algorithm: 'sha-256', value: bytes }];
try {
const bytes = new Uint8Array(hashHex.match(/.{1,2}/g).map(b => parseInt(b,16)));
opts.serverCertificateHashes = [{ algorithm: 'sha-256', value: bytes }];
console.log('Using certificate hash:', hashHex);
} catch (err) {
setStatus('ERROR: Invalid certificate hash format');
return;
}
} else {
console.warn('No certificate hash provided - connection may fail for self-signed certs');
}
try {
const wt = new WebTransport(url, opts);
await wt.ready;
transport = wt;
writer = wt.datagrams.writable.getWriter();
setStatus('connected to server');
console.log('WebTransport connected successfully');
readLoop();
requestAnimationFrame(render);
// Send JOIN immediately (spectator → player upon space handled below)
const pkt = buildJoin(1, nextSeq(), ui.name.value.trim());
await writer.write(pkt);
} catch (err) {
console.error('WebTransport connection failed:', err);
if (err.message.includes('certificate')) {
setStatus('ERROR: Certificate validation failed. Check cert hash or use valid CA cert.');
} else if (err.message.includes('net::')) {
setStatus('ERROR: Network error - is server running on ' + url + '?');
} else {
setStatus('ERROR: ' + err.message);
}
}
const wt = new WebTransport(url, opts);
await wt.ready;
transport = wt;
writer = wt.datagrams.writable.getWriter();
setStatus('connected');
readLoop();
requestAnimationFrame(render);
// Send JOIN immediately (spectator → player upon space handled below)
const pkt = buildJoin(1, nextSeq(), ui.name.value.trim());
await writer.write(pkt);
}
function dirFromKey(e) {
@@ -229,7 +266,47 @@ async function onKey(e) {
try { await writer.write(pkt); } catch (err) { /* ignore */ }
}
// Auto-configure on page load
async function autoConfigureFromServer() {
try {
// Try to fetch cert hash from the static server
const response = await fetch('/cert-hash.json');
if (!response.ok) {
setStatus('Auto-config unavailable, please enter manually');
return;
}
const config = await response.json();
// Auto-populate fields
if (config.wtUrl) {
// Use the hostname from browser but with the WT port from server
const hostname = window.location.hostname || '127.0.0.1';
const wtPort = config.wtPort || 4433;
ui.url.value = `https://${hostname}:${wtPort}/`;
}
if (config.sha256) {
ui.hash.value = config.sha256;
}
setStatus('Auto-configured from server');
console.log('Auto-configured:', config);
} catch (err) {
console.warn('Auto-config failed:', err);
// Fallback: just populate URL from browser location
const hostname = window.location.hostname || '127.0.0.1';
ui.url.value = `https://${hostname}:4433/`;
setStatus('Manual configuration required');
}
}
ui.connect.onclick = () => { connectWT().catch(e => setStatus('connect failed: ' + e)); };
window.addEventListener('keydown', onKey);
window.addEventListener('resize', render);
// Auto-configure when page loads
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', autoConfigureFromServer);
} else {
autoConfigureFromServer();
}

View File

@@ -15,16 +15,15 @@
</head>
<body>
<div id="ui">
<label>Server URL (WebTransport): <input id="url" value="https://localhost:4433/"/></label>
<label>SHA-256 Cert Hash (hex, optional for self-signed): <input id="hash" placeholder="e.g., aabbcc..."/></label>
<label>Server URL (WebTransport): <input id="url" value="https://localhost:4433/" placeholder="Auto-configured..."/></label>
<label>Certificate Hash (SHA-256, auto-configured): <input id="hash" placeholder="Auto-configured from server..."/></label>
<label>Name: <input id="name" value="guest"/></label>
<button id="connect">Connect</button>
<span id="status"></span>
</div>
<div id="overlay">press space to join</div>
<canvas id="view"></canvas>
<script src="protocol.js"></script>
<script src="client.js"></script>
<script type="module" src="client.js"></script>
</body>
</html>

390
run.py
View File

@@ -1,5 +1,5 @@
import asyncio
import os
import argparse
import logging
import sys
@@ -7,132 +7,332 @@ from server.server import GameServer
from server.config import ServerConfig
async def _run_tasks_with_optional_timeout(tasks):
"""Await tasks, optionally honoring RUN_SECONDS env var to cancel after a timeout."""
timeout_s = os.environ.get("RUN_SECONDS")
async def _run_tasks_with_optional_timeout(tasks, timeout_s=None):
"""Await tasks, optionally with a timeout to cancel after specified seconds."""
if not timeout_s:
await asyncio.gather(*tasks)
return
try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=float(timeout_s))
except asyncio.TimeoutError:
logging.info("Timeout reached (RUN_SECONDS=%s); stopping server tasks...", timeout_s)
logging.info("Timeout reached (%s seconds); stopping server tasks...", timeout_s)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
async def run_in_memory():
async def run_in_memory(args):
from server.transport import InMemoryTransport
cfg = ServerConfig()
cfg = ServerConfig(
width=args.width,
height=args.height,
tick_rate=args.tick_rate,
wrap_edges=args.wrap_edges,
apples_per_snake=args.apples_per_snake,
apples_cap=args.apples_cap,
players_max=args.players_max,
)
server = GameServer(transport=InMemoryTransport(lambda d, p: server.on_datagram(d, p)), config=cfg)
tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())]
await _run_tasks_with_optional_timeout(tasks)
await _run_tasks_with_optional_timeout(tasks, args.run_seconds)
async def run_quic():
async def run_quic(args):
import os
from server.quic_transport import QuicWebTransportServer
cfg = ServerConfig()
host = os.environ.get("QUIC_HOST", "0.0.0.0")
port = int(os.environ.get("QUIC_PORT", "4433"))
cert = os.environ["QUIC_CERT"]
key = os.environ["QUIC_KEY"]
server = GameServer(transport=QuicWebTransportServer(host, port, cert, key, lambda d, p: server.on_datagram(d, p)), config=cfg)
tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())]
await _run_tasks_with_optional_timeout(tasks)
from server.utils import combine_cert_chain
cfg = ServerConfig(
width=args.width,
height=args.height,
tick_rate=args.tick_rate,
wrap_edges=args.wrap_edges,
apples_per_snake=args.apples_per_snake,
apples_cap=args.apples_cap,
players_max=args.players_max,
)
async def run_webtransport():
from server.webtransport_server import WebTransportServer
from server.static_server import start_https_static
cfg = ServerConfig()
host = os.environ.get("WT_HOST", os.environ.get("QUIC_HOST", "0.0.0.0"))
port = int(os.environ.get("WT_PORT", os.environ.get("QUIC_PORT", "4433")))
cert = os.environ.get("WT_CERT") or os.environ["QUIC_CERT"]
key = os.environ.get("WT_KEY") or os.environ["QUIC_KEY"]
# Optional static HTTPS server for client assets
static = os.environ.get("STATIC", "1")
static_host = os.environ.get("STATIC_HOST", host)
static_port = int(os.environ.get("STATIC_PORT", "8443"))
static_root = os.environ.get("STATIC_ROOT", "client")
httpd = None
if static == "1":
httpd, _t = start_https_static(static_host, static_port, cert, key, static_root)
print(f"HTTPS static server: https://{static_host}:{static_port}/ serving '{static_root}'")
server = GameServer(transport=WebTransportServer(host, port, cert, key, lambda d, p: server.on_datagram(d, p)), config=cfg)
print(f"WebTransport server: https://{host}:{port}/ (HTTP/3)")
# Handle certificate chain
# Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem)
with open(args.cert, 'r') as f:
cert_content = f.read()
from server.utils import split_pem_certificates
certs_in_file = split_pem_certificates(cert_content)
cert_file = args.cert
temp_cert_file = None
if args.cert_chain:
# Combine cert + chain files into single file
temp_cert_file = combine_cert_chain(args.cert, args.cert_chain)
cert_file = temp_cert_file
logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain))
elif len(certs_in_file) > 1:
# Certificate file contains full chain - reformat it for aioquic
temp_cert_file = combine_cert_chain(args.cert, [])
cert_file = temp_cert_file
logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file))
server = GameServer(transport=QuicWebTransportServer(args.quic_host, args.quic_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)), config=cfg)
logging.info("QUIC server: %s:%d", args.quic_host, args.quic_port)
try:
tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())]
await _run_tasks_with_optional_timeout(tasks)
await _run_tasks_with_optional_timeout(tasks, args.run_seconds)
finally:
if temp_cert_file:
os.unlink(temp_cert_file) # Clean up temp combined cert file
async def run_webtransport(args):
import json
import os
from server.webtransport_server import WebTransportServer
from server.static_server import start_https_static
from server.utils import get_cert_sha256_hash, combine_cert_chain, extract_server_cert
cfg = ServerConfig(
width=args.width,
height=args.height,
tick_rate=args.tick_rate,
wrap_edges=args.wrap_edges,
apples_per_snake=args.apples_per_snake,
apples_cap=args.apples_cap,
players_max=args.players_max,
)
# Handle certificate chain
# Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem)
with open(args.cert, 'r') as f:
cert_content = f.read()
from server.utils import split_pem_certificates
certs_in_file = split_pem_certificates(cert_content)
cert_file = args.cert
temp_cert_file = None
if args.cert_chain:
# Combine cert + chain files into single file
temp_cert_file = combine_cert_chain(args.cert, args.cert_chain)
cert_file = temp_cert_file
logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain))
elif len(certs_in_file) > 1:
# Certificate file contains full chain - reformat it for aioquic
temp_cert_file = combine_cert_chain(args.cert, [])
cert_file = temp_cert_file
logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file))
# Calculate certificate hash for WebTransport client (only hash server cert, not chain)
server_cert_file = extract_server_cert(args.cert)
cert_hash = get_cert_sha256_hash(server_cert_file)
if server_cert_file != args.cert:
os.unlink(server_cert_file) # Clean up temp file
# Prepare cert-hash.json content
cert_hash_json = json.dumps({
"sha256": cert_hash,
"wtUrl": f"https://{args.wt_host}:{args.wt_port}/",
"wtPort": args.wt_port
})
# Optional static HTTPS server for client assets
httpd = None
if args.static:
httpd, _t = start_https_static(
args.static_host,
args.static_port,
cert_file,
args.key,
args.static_root,
cert_hash_json=cert_hash_json
)
print(f"HTTPS static server: https://{args.static_host}:{args.static_port}/ serving '{args.static_root}'")
print(f"Certificate SHA-256: {cert_hash}")
server = GameServer(transport=WebTransportServer(args.wt_host, args.wt_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p)), config=cfg)
print(f"WebTransport server: https://{args.wt_host}:{args.wt_port}/ (HTTP/3)")
try:
tasks = [asyncio.create_task(server.transport.run()), asyncio.create_task(server.tick_loop())]
await _run_tasks_with_optional_timeout(tasks, args.run_seconds)
finally:
if httpd is not None:
httpd.shutdown()
if temp_cert_file:
os.unlink(temp_cert_file) # Clean up temp combined cert file
async def run_net(args):
import os
from server.webtransport_server import WebTransportServer
from server.quic_transport import QuicWebTransportServer
from server.multi_transport import MultiTransport
from server.utils import combine_cert_chain
cfg = ServerConfig(
width=args.width,
height=args.height,
tick_rate=args.tick_rate,
wrap_edges=args.wrap_edges,
apples_per_snake=args.apples_per_snake,
apples_cap=args.apples_cap,
players_max=args.players_max,
)
# Handle certificate chain
# Check if cert file contains multiple certificates (like Let's Encrypt fullchain.pem)
with open(args.cert, 'r') as f:
cert_content = f.read()
from server.utils import split_pem_certificates
certs_in_file = split_pem_certificates(cert_content)
cert_file = args.cert
temp_cert_file = None
if args.cert_chain:
# Combine cert + chain files into single file
temp_cert_file = combine_cert_chain(args.cert, args.cert_chain)
cert_file = temp_cert_file
logging.info("Combined certificate with %d chain file(s)", len(args.cert_chain))
elif len(certs_in_file) > 1:
# Certificate file contains full chain - reformat it for aioquic
temp_cert_file = combine_cert_chain(args.cert, [])
cert_file = temp_cert_file
logging.info("Reformatted certificate chain from fullchain file (%d certs)", len(certs_in_file))
server: GameServer
wt = WebTransportServer(args.wt_host, args.wt_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p))
qu = QuicWebTransportServer(args.quic_host, args.quic_port, cert_file, args.key, lambda d, p: server.on_datagram(d, p))
m = MultiTransport(wt, qu)
server = GameServer(transport=m, config=cfg)
logging.info("WebTransport server: %s:%d", args.wt_host, args.wt_port)
logging.info("QUIC server: %s:%d", args.quic_host, args.quic_port)
try:
tasks = [asyncio.create_task(m.run()), asyncio.create_task(server.tick_loop())]
await _run_tasks_with_optional_timeout(tasks, args.run_seconds)
finally:
if temp_cert_file:
os.unlink(temp_cert_file) # Clean up temp combined cert file
def parse_args():
parser = argparse.ArgumentParser(
description="Real-time multiplayer Snake game server",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python run.py # WebTransport mode with defaults
python run.py --mode mem # In-memory mode (testing)
python run.py --mode quic --quic-port 5000 # QUIC mode on custom port
python run.py --cert my.pem --key my.key # Custom TLS certificate
python run.py --no-static # Disable static HTTPS server
python run.py --width 80 --height 60 # Custom field size
"""
)
# Transport mode
parser.add_argument("--mode", choices=["mem", "wt", "quic", "net"], default="wt",
help="Transport mode: mem (in-memory), wt (WebTransport/HTTP3), quic (QUIC datagrams), net (both wt+quic). Default: wt")
# TLS certificate/key
parser.add_argument("--cert", default="cert.pem",
help="TLS certificate file path (may contain full chain). Default: cert.pem")
parser.add_argument("--key", default="key.pem",
help="TLS private key file path. Default: key.pem")
parser.add_argument("--cert-chain", action="append", dest="cert_chain",
help="Intermediate certificate file (can be used multiple times for long chains)")
# WebTransport server settings
parser.add_argument("--wt-host", default="0.0.0.0",
help="WebTransport server host. Default: 0.0.0.0")
parser.add_argument("--wt-port", type=int, default=4433,
help="WebTransport server port. Default: 4433")
# QUIC server settings
parser.add_argument("--quic-host", default="0.0.0.0",
help="QUIC server host. Default: 0.0.0.0")
parser.add_argument("--quic-port", type=int, default=4433,
help="QUIC server port. Default: 4433 (or 4443 in net mode)")
# Static HTTPS server settings
parser.add_argument("--static", dest="static", action="store_true", default=True,
help="Enable static HTTPS server for client files (default in wt mode)")
parser.add_argument("--no-static", dest="static", action="store_false",
help="Disable static HTTPS server")
parser.add_argument("--static-host", default=None,
help="Static server host. Default: same as wt-host")
parser.add_argument("--static-port", type=int, default=8443,
help="Static server port. Default: 8443")
parser.add_argument("--static-root", default="client",
help="Static files directory. Default: client")
# Game configuration
parser.add_argument("--width", type=int, default=60,
help="Field width (3-255). Default: 60")
parser.add_argument("--height", type=int, default=40,
help="Field height (3-255). Default: 40")
parser.add_argument("--tick-rate", type=int, default=10,
help="Server tick rate in TPS (5-30). Default: 10")
parser.add_argument("--wrap-edges", action="store_true", default=False,
help="Enable edge wrapping (default: disabled)")
parser.add_argument("--apples-per-snake", type=int, default=1,
help="Apples per connected snake (1-12). Default: 1")
parser.add_argument("--apples-cap", type=int, default=255,
help="Maximum total apples (0-255). Default: 255")
parser.add_argument("--players-max", type=int, default=32,
help="Maximum concurrent players. Default: 32")
# Logging and testing
parser.add_argument("--log-level", default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level. Default: INFO")
parser.add_argument("--run-seconds", type=float, default=None,
help="Optional timeout in seconds for testing")
args = parser.parse_args()
# Post-process: static-host defaults to wt-host
if args.static_host is None:
args.static_host = args.wt_host
# In net mode, default quic-port to 4443 if not explicitly set
if args.mode == "net" and "--quic-port" not in sys.argv:
args.quic_port = 4443
# Validate TLS files for modes that need them
if args.mode in ("wt", "quic", "net"):
import os
if not os.path.exists(args.cert):
parser.error(f"Certificate file not found: {args.cert}")
if not os.path.exists(args.key):
parser.error(f"Key file not found: {args.key}")
return args
if __name__ == "__main__":
try:
if any(a in ("-h", "--help") for a in sys.argv[1:]):
print(
"Usage: python run.py [--mode mem|quic|wt|net] [--log-level LEVEL] [--run-seconds N]\n"
" TLS (for wt/quic): set QUIC_CERT/QUIC_KEY or WT_CERT/WT_KEY env vars\n"
" WT static server (MODE=wt): STATIC=1 [STATIC_HOST/PORT/ROOT]\n"
"Examples:\n MODE=wt QUIC_CERT=cert.pem QUIC_KEY=key.pem python run.py\n MODE=mem python run.py"
)
sys.exit(0)
args = parse_args()
# Logging setup
level = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=getattr(logging, level, logging.INFO), format="[%(asctime)s] %(levelname)s: %(message)s")
mode = os.environ.get("MODE", "mem").lower()
if mode == "wt":
logging.basicConfig(
level=getattr(logging, args.log_level.upper()),
format="[%(asctime)s] %(levelname)s: %(message)s"
)
# Run appropriate mode
if args.mode == "wt":
logging.info("Starting in WebTransport mode")
asyncio.run(run_webtransport())
elif mode == "quic":
asyncio.run(run_webtransport(args))
elif args.mode == "quic":
logging.info("Starting in QUIC datagram mode")
asyncio.run(run_quic())
elif mode == "net":
asyncio.run(run_quic(args))
elif args.mode == "net":
logging.info("Starting in combined WebTransport+QUIC mode")
from server.webtransport_server import WebTransportServer
from server.quic_transport import QuicWebTransportServer
from server.multi_transport import MultiTransport
cfg = ServerConfig()
host_wt = os.environ.get("WT_HOST", os.environ.get("QUIC_HOST", "0.0.0.0"))
port_wt = int(os.environ.get("WT_PORT", os.environ.get("QUIC_PORT", "4433")))
host_quic = os.environ.get("QUIC_HOST", host_wt)
port_quic = int(os.environ.get("QUIC_PORT", "4443"))
cert = os.environ.get("WT_CERT") or os.environ.get("QUIC_CERT")
key = os.environ.get("WT_KEY") or os.environ.get("QUIC_KEY")
if not cert or not key:
raise SystemExit("WT/QUIC cert/key required: set WT_CERT/WT_KEY or QUIC_CERT/QUIC_KEY")
async def _run_net():
server: GameServer
wt = WebTransportServer(host_wt, port_wt, cert, key, lambda d, p: server.on_datagram(d, p))
qu = QuicWebTransportServer(host_quic, port_quic, cert, key, lambda d, p: server.on_datagram(d, p))
m = MultiTransport(wt, qu)
server = GameServer(transport=m, config=cfg)
await asyncio.gather(m.run(), server.tick_loop())
asyncio.run(_run_net())
else:
asyncio.run(run_net(args))
else: # mem
logging.info("Starting in in-memory transport mode")
asyncio.run(run_in_memory())
asyncio.run(run_in_memory(args))
except KeyboardInterrupt:
pass
async def _run_tasks_with_optional_timeout(tasks):
"""Await tasks, optionally honoring RUN_SECONDS env var to cancel after a timeout."""
timeout_s = os.environ.get("RUN_SECONDS")
if not timeout_s:
await asyncio.gather(*tasks)
return
try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=float(timeout_s))
except asyncio.TimeoutError:
logging.info("Timeout reached (RUN_SECONDS=%s); stopping server tasks...", timeout_s)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)

View File

@@ -44,7 +44,7 @@ class GameQuicProtocol(QuicConnectionProtocol): # type: ignore[misc]
async def send_datagram(self, data: bytes) -> None:
logging.debug("QUIC send datagram: %d bytes", len(data))
self._quic.send_datagram_frame(data) # type: ignore[attr-defined]
await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined]
self.transmit() # Send queued QUIC packets immediately
class QuicWebTransportServer(DatagramServerTransport):
@@ -65,15 +65,20 @@ class QuicWebTransportServer(DatagramServerTransport):
await proto.send_datagram(data)
async def run(self) -> None:
configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"])
configuration = QuicConfiguration(
is_client=False,
alpn_protocols=["h3"],
max_datagram_frame_size=65536 # Enable QUIC datagrams
)
configuration.load_cert_chain(self.certfile, self.keyfile)
async def _create_protocol(*args, **kwargs):
def _create_protocol(*args, **kwargs):
return GameQuicProtocol(*args, on_datagram=self._on_datagram, peers=self._peers, **kwargs)
logging.info("QUIC datagram server listening on %s:%d", self.host, self.port)
self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol)
try:
await self._server.wait_closed()
# Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout)
await asyncio.Event().wait()
finally:
self._server.close()

View File

@@ -3,25 +3,66 @@ from __future__ import annotations
import ssl
import logging
import threading
import json
from http.server import ThreadingHTTPServer, SimpleHTTPRequestHandler
from pathlib import Path
from typing import Tuple
from typing import Tuple, Optional
class _Handler(SimpleHTTPRequestHandler):
# Allow passing a base directory at construction time
# Allow passing a base directory and cert_hash_json at construction time
cert_hash_json: Optional[str] = None
def __init__(self, *args, directory: str | None = None, **kwargs):
super().__init__(*args, directory=directory, **kwargs)
def end_headers(self):
# Add no-cache headers for all static files to force browser to fetch fresh versions
self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate')
self.send_header('Pragma', 'no-cache')
self.send_header('Expires', '0')
super().end_headers()
def start_https_static(host: str, port: int, certfile: str, keyfile: str, docroot: str) -> Tuple[ThreadingHTTPServer, threading.Thread]:
def do_GET(self):
# Intercept /cert-hash.json requests
if self.path == '/cert-hash.json' and self.cert_hash_json:
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
self.wfile.write(self.cert_hash_json.encode('utf-8'))
else:
# Serve regular static files
super().do_GET()
def start_https_static(
host: str,
port: int,
certfile: str,
keyfile: str,
docroot: str,
cert_hash_json: Optional[str] = None
) -> Tuple[ThreadingHTTPServer, threading.Thread]:
"""Start a simple HTTPS static file server in a background thread.
Args:
host: Host to bind to
port: Port to bind to
certfile: Path to TLS certificate
keyfile: Path to TLS private key
docroot: Document root directory
cert_hash_json: Optional JSON string to serve at /cert-hash.json
Returns the (httpd, thread). Caller is responsible for calling httpd.shutdown()
to stop the server on application exit.
"""
docroot_path = str(Path(docroot).resolve())
# Set class variable for the handler
if cert_hash_json:
_Handler.cert_hash_json = cert_hash_json
def handler(*args, **kwargs):
return _Handler(*args, directory=docroot_path, **kwargs)

View File

@@ -1,5 +1,11 @@
from __future__ import annotations
import hashlib
import re
import tempfile
from typing import List, Optional
from pathlib import Path
def is_newer(a: int, b: int) -> bool:
"""Return True if 16-bit sequence number a is newer than b (wrap-aware).
@@ -8,3 +14,129 @@ def is_newer(a: int, b: int) -> bool:
"""
return ((a - b) & 0xFFFF) < 0x8000
def get_cert_sha256_hash(cert_path: str) -> str:
"""Calculate SHA-256 hash of a certificate file.
Returns the hash as a lowercase hex string (64 characters).
This hash can be used by WebTransport clients for certificate pinning.
"""
try:
from cryptography import x509
from cryptography.hazmat.backends import default_backend
# Read the certificate file
with open(cert_path, 'rb') as f:
cert_data = f.read()
# Parse the certificate
cert = x509.load_pem_x509_certificate(cert_data, default_backend())
# Get the DER-encoded certificate and hash it
cert_der = cert.public_bytes(encoding=x509.Encoding.DER)
hash_bytes = hashlib.sha256(cert_der).digest()
# Return as hex string
return hash_bytes.hex()
except Exception as e:
# Fallback: just hash the file contents (less accurate but works)
with open(cert_path, 'rb') as f:
return hashlib.sha256(f.read()).digest().hex()
def split_pem_certificates(pem_data: str) -> List[str]:
"""Split a PEM file containing multiple certificates into individual certs.
Args:
pem_data: String containing one or more PEM certificates
Returns:
List of individual certificate strings (each includes BEGIN/END markers)
"""
# Match all certificate blocks
cert_pattern = r'(-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----)'
certificates = re.findall(cert_pattern, pem_data, re.DOTALL)
return certificates
def combine_cert_chain(cert_path: str, chain_paths: List[str]) -> str:
"""Combine server cert and chain certs into a single properly-formatted PEM file.
Args:
cert_path: Path to server certificate (may contain full chain)
chain_paths: List of paths to intermediate/root certificates (optional)
Returns:
Path to temporary combined certificate file (caller should clean up)
"""
combined_content = []
# Read server certificate file
with open(cert_path, 'r') as f:
cert_data = f.read().strip()
certs = split_pem_certificates(cert_data)
if chain_paths:
# If chain files provided separately, only use first cert from cert_path
if certs:
combined_content.append(certs[0].strip())
else:
combined_content.append(cert_data)
else:
# No separate chain files - include all certs from cert_path (fullchain case)
if certs:
combined_content.extend([c.strip() for c in certs])
else:
combined_content.append(cert_data)
# Read separate chain certificate files
for chain_path in chain_paths:
with open(chain_path, 'r') as f:
chain_data = f.read().strip()
# Each chain file might contain multiple certs
chain_certs = split_pem_certificates(chain_data)
if chain_certs:
combined_content.extend([c.strip() for c in chain_certs])
else:
combined_content.append(chain_data)
# Write to temporary file with proper formatting
# Each cert should be separated by a newline
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False)
temp_file.write('\n'.join(combined_content))
temp_file.write('\n') # End with newline
temp_file.close()
return temp_file.name
def extract_server_cert(cert_path: str) -> str:
"""Extract just the server certificate from a file that may contain a full chain.
Args:
cert_path: Path to certificate file (may contain chain)
Returns:
Path to temporary file containing only the server certificate
"""
with open(cert_path, 'r') as f:
cert_data = f.read()
# Split into individual certificates
certs = split_pem_certificates(cert_data)
if not certs:
# No certs found, return original path
return cert_path
# First cert is the server cert
server_cert = certs[0].strip()
# Write to temporary file
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False)
temp_file.write(server_cert)
temp_file.write('\n')
temp_file.close()
return temp_file.name

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Optional
from typing import Callable, Dict, Optional
from .transport import DatagramServerTransport, OnDatagram, TransportPeer
import logging
@@ -11,7 +11,8 @@ import logging
try:
from aioquic.asyncio import QuicConnectionProtocol, serve
from aioquic.quic.configuration import QuicConfiguration
from aioquic.h3.connection import H3Connection
from aioquic.quic.events import ProtocolNegotiated
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import HeadersReceived
# Datagram event names vary by aioquic version; try both
try:
@@ -21,6 +22,8 @@ try:
except Exception: # pragma: no cover - optional dependency not installed
QuicConnectionProtocol = object # type: ignore
QuicConfiguration = object # type: ignore
ProtocolNegotiated = object # type: ignore
H3_ALPN = [] # type: ignore
H3Connection = object # type: ignore
HeadersReceived = object # type: ignore
H3DatagramReceived = object # type: ignore
@@ -38,55 +41,75 @@ class GameWTProtocol(QuicConnectionProtocol): # type: ignore[misc]
super().__init__(*args, **kwargs)
self._on_datagram = on_datagram
self._sessions = sessions
self._http: Optional[H3Connection] = None
self._http: Optional[H3Connection] = None # Created after ProtocolNegotiated
logging.debug("GameWTProtocol initialized")
def http_event_received(self, event) -> None: # type: ignore[override]
logging.debug("HTTP event received: %s", type(event).__name__)
# Headers for CONNECT :protocol = webtransport open a session
if isinstance(event, HeadersReceived):
headers = {k.decode().lower(): v.decode() for k, v in event.headers}
logging.debug("HeadersReceived: %s", headers)
method = headers.get(":method")
protocol = headers.get(":protocol") or headers.get("sec-webtransport-protocol")
logging.debug("Method: %s, Protocol: %s", method, protocol)
if method == "CONNECT" and (protocol == "webtransport"):
# In WebTransport over H3, datagrams use the CONNECT stream id as flow_id
flow_id = event.stream_id # type: ignore[attr-defined]
self._sessions[flow_id] = WTSession(flow_id=flow_id, proto=self)
logging.info("WT CONNECT accepted: flow_id=%s", flow_id)
# In WebTransport over H3, datagrams use the CONNECT stream id as session id
session_id = event.stream_id # type: ignore[attr-defined]
self._sessions[session_id] = WTSession(flow_id=session_id, proto=self)
logging.info("WT CONNECT accepted: session_id=%s", session_id)
# Send 2xx to accept the session
if self._http is not None:
self._http.send_headers(event.stream_id, [(b":status", b"200")])
self.transmit() # Actually send the response to complete handshake
logging.debug("Sent 200 response for WT CONNECT")
else:
logging.warning("Unexpected CONNECT: method=%s, protocol=%s", method, protocol)
elif isinstance(event, H3DatagramReceived): # type: ignore[misc]
# Route datagram to session by flow_id
flow_id = getattr(event, "flow_id", None)
# Route datagram to session by stream_id (WebTransport session ID)
# DatagramReceived has: stream_id (session), data (payload)
session_id = getattr(event, "stream_id", None)
data = getattr(event, "data", None)
if flow_id is None or data is None:
logging.debug("DatagramReceived event: session_id=%s, data_len=%s, has_data=%s",
session_id, len(data) if data else None, data is not None)
if session_id is None or data is None:
logging.warning("Invalid datagram event: session_id=%s, data=%s", session_id, data)
return
sess = self._sessions.get(flow_id)
sess = self._sessions.get(session_id)
if not sess:
logging.warning("No session found for datagram: session_id=%s", session_id)
return
peer = TransportPeer(addr=(self, flow_id))
logging.debug("WT datagram received: flow_id=%s, %d bytes", flow_id, len(data))
peer = TransportPeer(addr=(self, session_id))
logging.info("WT datagram received: session_id=%s, %d bytes", session_id, len(data))
asyncio.ensure_future(self._on_datagram(bytes(data), peer))
def quic_event_received(self, event) -> None: # type: ignore[override]
# Lazily create H3 connection wrapper
if self._http is None and hasattr(self, "_quic"):
try:
self._http = H3Connection(self._quic, enable_webtransport=True) # type: ignore[attr-defined]
except Exception:
self._http = None
event_name = type(event).__name__
logging.debug("QUIC event received: %s", event_name)
# Create H3Connection after ALPN protocol is negotiated
if isinstance(event, ProtocolNegotiated):
if event.alpn_protocol in H3_ALPN:
self._http = H3Connection(self._quic, enable_webtransport=True)
logging.info("H3Connection created with WebTransport support (ALPN: %s)", event.alpn_protocol)
else:
logging.warning("Unexpected ALPN protocol: %s", event.alpn_protocol)
# Pass event to HTTP layer if connection is established
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
async def send_h3_datagram(self, flow_id: int, data: bytes) -> None:
if self._http is None:
return
try:
if self._http is None:
logging.warning("Cannot send datagram: H3Connection not established")
return
logging.debug("WT send datagram: flow_id=%s, %d bytes", flow_id, len(data))
self._http.send_datagram(flow_id, data)
await self._loop.run_in_executor(None, self.transmit) # type: ignore[attr-defined]
except Exception:
pass
self._http.send_datagram(stream_id=flow_id, data=data)
self.transmit() # Send queued QUIC packets immediately
except Exception as e:
logging.debug("Failed to send datagram: %s", e)
class WebTransportServer(DatagramServerTransport):
@@ -115,15 +138,34 @@ class WebTransportServer(DatagramServerTransport):
await proto.send_h3_datagram(flow_id, data)
async def run(self) -> None:
configuration = QuicConfiguration(is_client=False, alpn_protocols=["h3"])
configuration.load_cert_chain(self.certfile, self.keyfile)
from aioquic.quic.logger import QuicFileLogger
import os
async def _create_protocol(*args, **kwargs):
# Enable QUIC logging if debug level
quic_logger = None
if logging.getLogger().getEffectiveLevel() <= logging.DEBUG:
log_dir = os.path.join(os.getcwd(), "quic_logs")
os.makedirs(log_dir, exist_ok=True)
quic_logger = QuicFileLogger(log_dir)
logging.info("QUIC logging enabled in: %s", log_dir)
configuration = QuicConfiguration(
is_client=False,
alpn_protocols=["h3"], # HTTP/3 with WebTransport support
max_datagram_frame_size=65536, # Enable QUIC datagrams (required for WebTransport)
quic_logger=quic_logger
)
configuration.load_cert_chain(self.certfile, self.keyfile)
logging.debug("QUIC configuration: ALPN=%s, max_datagram_frame_size=%d",
configuration.alpn_protocols, configuration.max_datagram_frame_size or 0)
def _create_protocol(*args, **kwargs):
return GameWTProtocol(*args, on_datagram=self._on_datagram, sessions=self._sessions, **kwargs)
logging.info("WebTransport (H3) server listening on %s:%d", self.host, self.port)
self._server = await serve(self.host, self.port, configuration=configuration, create_protocol=_create_protocol)
try:
await self._server.wait_closed()
# Wait indefinitely until cancelled (e.g., by KeyboardInterrupt or timeout)
await asyncio.Event().wait()
finally:
self._server.close()

59
test_http3.py Normal file
View File

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""Simple HTTP/3 client to test the server."""
import asyncio
import logging
from aioquic.asyncio import connect
from aioquic.quic.configuration import QuicConfiguration
logging.basicConfig(level=logging.DEBUG)
async def test_http3(url: str):
"""Test HTTP/3 connection to the server."""
print(f"Testing HTTP/3 connection to: {url}")
# Parse URL
if url.startswith("https://"):
url = url[8:]
host, port = url.split(":")
port = int(port.rstrip("/"))
print(f"Connecting to {host}:{port}...")
# Create QUIC configuration
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=["h3"],
verify_mode=0 # Skip certificate verification for testing
)
try:
async with connect(
host,
port,
configuration=configuration,
) as protocol:
print(f"✓ QUIC connection established!")
print(f" ALPN protocol: {protocol._quic.tls.alpn_negotiated}")
print(f" Remote address: {protocol._quic.remote_address}")
# Just test the connection, don't send HTTP requests yet
await asyncio.sleep(1)
print("✓ Connection successful!")
return True
except Exception as e:
print(f"✗ Connection failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
import sys
url = sys.argv[1] if len(sys.argv) > 1 else "https://127.0.0.1:4433"
success = asyncio.run(test_http3(url))
sys.exit(0 if success else 1)

188
test_webtransport_client.py Normal file
View File

@@ -0,0 +1,188 @@
#!/usr/bin/env python3
"""
WebTransport test client to verify server is working correctly.
"""
import asyncio
import logging
import sys
from typing import Optional
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DatagramReceived, H3Event, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent
logging.basicConfig(
level=logging.DEBUG,
format="[%(asctime)s] %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
class WebTransportClient(QuicConnectionProtocol):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._http: Optional[H3Connection] = None
self._session_id: Optional[int] = None
self._session_established = asyncio.Event()
self._received_datagrams = []
def quic_event_received(self, event: QuicEvent) -> None:
logger.debug(f"QUIC event: {type(event).__name__}")
# Create H3Connection on first event
if self._http is None:
self._http = H3Connection(self._quic, enable_webtransport=True)
logger.info("H3Connection created with WebTransport support")
# Process event through HTTP/3 layer
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
def http_event_received(self, event: H3Event) -> None:
logger.debug(f"HTTP event: {type(event).__name__}")
if isinstance(event, HeadersReceived):
headers = dict(event.headers)
status = headers.get(b":status", b"").decode()
logger.info(f"Received headers: status={status}")
if status == "200" and self._session_id is None:
self._session_id = event.stream_id
logger.info(f"WebTransport session established! session_id={self._session_id}")
self._session_established.set()
else:
logger.error(f"WebTransport session rejected: status={status}, headers={headers}")
elif isinstance(event, DatagramReceived):
logger.info(f"Received datagram: {len(event.data)} bytes, flow_id={event.flow_id}")
self._received_datagrams.append(event.data)
async def establish_session(self, path: str = "/") -> None:
"""Send WebTransport CONNECT request."""
logger.info(f"Sending WebTransport CONNECT request for path: {path}")
# Allocate stream ID for CONNECT request
stream_id = self._quic.get_next_available_stream_id()
# Send CONNECT request with WebTransport protocol
headers = [
(b":method", b"CONNECT"),
(b":scheme", b"https"),
(b":authority", b"np.vaku.org.ua:4433"),
(b":path", path.encode()),
(b":protocol", b"webtransport"),
]
self._http.send_headers(stream_id=stream_id, headers=headers, end_stream=False)
self.transmit()
logger.info(f"WebTransport CONNECT sent on stream {stream_id}")
# Wait for session to be established
try:
await asyncio.wait_for(self._session_established.wait(), timeout=5.0)
logger.info("✓ WebTransport session established successfully!")
return True
except asyncio.TimeoutError:
logger.error("✗ Timeout waiting for WebTransport session acceptance")
return False
async def send_datagram(self, data: bytes) -> None:
"""Send a datagram over the WebTransport session."""
if self._session_id is None:
logger.error("Cannot send datagram: session not established")
return
logger.info(f"Sending datagram: {len(data)} bytes")
self._http.send_datagram(stream_id=self._session_id, data=data)
self.transmit()
async def wait_for_datagram(self, timeout: float = 2.0) -> Optional[bytes]:
"""Wait for a datagram response."""
logger.info(f"Waiting for datagram response (timeout={timeout}s)...")
start = asyncio.get_event_loop().time()
while asyncio.get_event_loop().time() - start < timeout:
if self._received_datagrams:
return self._received_datagrams.pop(0)
await asyncio.sleep(0.1)
logger.warning("No datagram received within timeout")
return None
async def test_webtransport(host: str, port: int, verify_mode: bool = False):
"""Test WebTransport connection to the server."""
logger.info(f"Testing WebTransport connection to {host}:{port}")
# Configure QUIC
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=["h3"], # HTTP/3
max_datagram_frame_size=65536, # Enable QUIC datagrams (required for WebTransport)
verify_mode=0 if not verify_mode else 2, # Skip cert verification for self-signed
)
try:
async with connect(
host,
port,
configuration=configuration,
create_protocol=WebTransportClient,
) as client:
logger.info("✓ QUIC connection established")
client = client # type: WebTransportClient
# Establish WebTransport session
if not await client.establish_session("/"):
logger.error("✗ Failed to establish WebTransport session")
return False
# Send a test datagram (simple game JOIN packet)
# Format: version(1) | type(1) | flags(1) | seq(2) | name_len(1) | name
test_packet = bytes([
0x01, # version
0x01, # JOIN packet type
0x00, # flags
0x00, 0x01, # seq=1
0x04, # name length
ord('T'), ord('E'), ord('S'), ord('T'), # name="TEST"
])
await client.send_datagram(test_packet)
logger.info("✓ Test JOIN packet sent")
# Wait for response
response = await client.wait_for_datagram(timeout=3.0)
if response:
logger.info(f"✓ Received response: {len(response)} bytes")
logger.info(f" Response hex: {response.hex()}")
return True
else:
logger.warning("Server did not respond to JOIN packet (might be normal if not implemented)")
return True # Connection worked even if no response
except Exception as e:
logger.error(f"✗ Connection failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python test_webtransport_client.py <host> [port]")
print("Example: python test_webtransport_client.py np.vaku.org.ua 4433")
sys.exit(1)
host = sys.argv[1]
port = int(sys.argv[2]) if len(sys.argv) > 2 else 4433
success = asyncio.run(test_webtransport(host, port))
sys.exit(0 if success else 1)