from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Iterable, List, Sequence, Tuple


class PacketType(IntEnum):
    """Packet type identifiers; written as the second byte of the header."""

    JOIN = 0
    JOIN_ACK = 1
    JOIN_DENY = 2
    INPUT = 3
    INPUT_BROADCAST = 4
    STATE_DELTA = 5
    STATE_FULL = 6
    PART = 7
    CONFIG_UPDATE = 8
    PING = 9
    PONG = 10
    ERROR = 11


class BodyTLV(IntEnum):
    """TLV type codes used by ``pack_body_tlv`` for snake-body encodings."""

    BODY_2BIT = 0x00
    BODY_RLE = 0x01
    BODY_2BIT_CHUNK = 0x10
    BODY_RLE_CHUNK = 0x11


class Direction(IntEnum):
    """Movement direction; every value fits in two bits."""

    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3


def quic_varint_encode(value: int) -> bytes:
    """Encode QUIC varint (RFC 9000).

    The two most significant bits of the first byte select the total length
    (1, 2, 4 or 8 bytes); the remaining bits hold the value in big-endian
    order.

    Raises:
        ValueError: if ``value`` is negative or does not fit in 62 bits.
    """
    if value < 0:
        raise ValueError("varint must be non-negative")
    if value <= 0x3F:
        # 6 bits, 1 byte, 00xx xxxx
        return bytes([value & 0x3F])
    if value <= 0x3FFF:
        # 14 bits, 2 bytes, 01xx xxxx
        v = 0x4000 | value
        return v.to_bytes(2, "big")
    if value <= 0x3FFFFFFF:
        # 30 bits, 4 bytes, 10xx xxxx
        v = 0x80000000 | value
        return v.to_bytes(4, "big")
    if value <= 0x3FFFFFFFFFFFFFFF:
        # 62 bits, 8 bytes, 11xx xxxx
        v = 0xC000000000000000 | value
        return v.to_bytes(8, "big")
    raise ValueError("varint too large")


def quic_varint_decode(buf: bytes, offset: int = 0) -> Tuple[int, int]:
    """Decode QUIC varint starting at offset. Returns (value, next_offset).

    Raises:
        IndexError: if ``offset`` does not address a byte of ``buf``.
        ValueError: if the length announced by the first byte runs past the
            end of ``buf`` (truncated varint).
    """
    first = buf[offset]
    # The 2-bit prefix 0..3 maps to an encoded length of 1, 2, 4 or 8 bytes.
    length = 1 << (first >> 6)
    end = offset + length
    if end > len(buf):
        # Without this check a short slice would be decoded silently and the
        # returned offset would point beyond the buffer.
        raise ValueError("buffer too small for varint")
    # Strip the two prefix bits: 6, 14, 30 or 62 value bits remain.
    mask = (1 << (8 * length - 2)) - 1
    return (int.from_bytes(buf[offset:end], "big") & mask, end)


def pack_header(version: int, ptype: PacketType, flags: int, seq: int, tick: int | None) -> bytes:
    """Pack common header fields.

    Layout:
    - ver: u8
    - type: u8
    - flags: u8
    - seq: u16 (network order)
    - tick: optional u16

    The tick field is emitted only when ``tick`` is not ``None``.

    Raises:
        ValueError: if ``version``, ``flags``, ``seq`` or ``tick`` is outside
            the range of its field.
    """
    if not (0 <= version <= 255):
        raise ValueError("version out of range")
    if not (0 <= flags <= 255):
        raise ValueError("flags out of range")
    if not (0 <= seq <= 0xFFFF):
        raise ValueError("seq out of range")
    parts = bytearray()
    parts.append(version & 0xFF)
    parts.append(int(ptype) & 0xFF)
    parts.append(flags & 0xFF)
    parts.extend(seq.to_bytes(2, "big"))
    if tick is not None:
        if not (0 <= tick <= 0xFFFF):
            raise ValueError("tick out of range")
        parts.extend(tick.to_bytes(2, "big"))
    return bytes(parts)


def unpack_header(buf: bytes, expect_tick: bool) -> Tuple[int, PacketType, int, int, int | None, int]:
    """Unpack header; returns (ver, type, flags, seq, tick, next_offset).

    ``tick`` is ``None`` and ``next_offset`` is 5 unless ``expect_tick`` is
    true, in which case two more bytes are read and ``next_offset`` is 7.

    Raises:
        ValueError: if ``buf`` is too short, or if the type byte is not a
            known ``PacketType`` value.
    """
    if len(buf) < 5:
        raise ValueError("buffer too small for header")
    ver = buf[0]
    ptype = PacketType(buf[1])
    flags = buf[2]
    seq = int.from_bytes(buf[3:5], "big")
    off = 5
    tick = None
    if expect_tick:
        if len(buf) < 7:
            raise ValueError("buffer too small for tick")
        tick = int.from_bytes(buf[5:7], "big")
        off = 7
    return ver, ptype, flags, seq, tick, off


# Message builders/parsers (subset)


def build_config_update(
    *,
    version: int,
    seq: int,
    tick: int,
    tick_rate: int,
    wrap_edges: bool,
    apples_per_snake: int,
    apples_cap: int,
) -> bytes:
    """Build a CONFIG_UPDATE packet: header with tick plus four u8 fields.

    Body values are masked to their low 8 bits rather than range-checked.
    """
    header = pack_header(version, PacketType.CONFIG_UPDATE, 0, seq, tick)
    body = bytearray()
    body.append(tick_rate & 0xFF)  # u8
    body.append(1 if wrap_edges else 0)  # bool u8
    body.append(apples_per_snake & 0xFF)  # u8
    body.append(apples_cap & 0xFF)  # u8
    return header + bytes(body)


@dataclass
class InputEvent:
    """A single direction change scheduled relative to a base tick."""

    rel_tick_offset: int  # relative to base_tick
    direction: Direction


def build_input_broadcast(
    *,
    version: int,
    seq: int,
    tick: int,
    player_id: int,
    input_seq: int,
    base_tick: int,
    events: Sequence[InputEvent],
    apply_at_tick: int | None = None,
) -> bytes:
    """Build an INPUT_BROADCAST packet.

    Body layout: player_id (u8), input_seq (u16), base_tick (u16), event
    count (varint), then per event a varint offset and a direction byte,
    followed by a presence byte and, when present, apply_at_tick (u16).

    Raises:
        ValueError: if an event offset is negative or too large for a varint.
    """
    header = pack_header(version, PacketType.INPUT_BROADCAST, 0, seq, tick)
    body = bytearray()
    body.append(player_id & 0xFF)
    body.extend(int(input_seq & 0xFFFF).to_bytes(2, "big"))
    body.extend(int(base_tick & 0xFFFF).to_bytes(2, "big"))
    # number of events as QUIC varint
    body.extend(quic_varint_encode(len(events)))
    for ev in events:
        # rel offset as QUIC varint, direction as u8 (low 2 bits)
        body.extend(quic_varint_encode(int(ev.rel_tick_offset)))
        body.append(int(ev.direction) & 0x03)
    # Optional absolute apply_at_tick presence flag + value
    if apply_at_tick is None:
        body.append(0)
    else:
        body.append(1)
        body.extend(int(apply_at_tick & 0xFFFF).to_bytes(2, "big"))
    return header + bytes(body)


def pack_body_tlv(t: BodyTLV, payload: bytes) -> bytes:
    """Return ``payload`` prefixed by its type and length, both as varints."""
    return quic_varint_encode(int(t)) + quic_varint_encode(len(payload)) + payload


def bitpack_2bit_directions(directions: Iterable[Direction]) -> bytes:
    """Pack directions four per byte, first direction in the lowest bits."""
    out = bytearray()
    acc = 0
    bits = 0
    for d in directions:
        acc |= (int(d) & 0x03) << bits
        bits += 2
        if bits >= 8:
            out.append(acc & 0xFF)
            acc = acc >> 8
            bits -= 8
    if bits:
        out.append(acc & 0xFF)  # zero-padded high bits
    return bytes(out)


def rle_pack_directions(directions: Sequence[Direction]) -> bytes:
    """Pack directions as runs: dir (u8) + count (QUIC varint >=1).

    An empty sequence packs to empty bytes.
    """
    if not directions:
        return b""
    out = bytearray()
    run_dir = directions[0]
    run_len = 1
    for d in directions[1:]:
        if d == run_dir:
            run_len += 1
        else:
            out.append(int(run_dir) & 0xFF)
            out.extend(quic_varint_encode(run_len))
            run_dir = d
            run_len = 1
    # Flush the final run, which the loop never emits.
    out.append(int(run_dir) & 0xFF)
    out.extend(quic_varint_encode(run_len))
    return bytes(out)


def build_body_tlv_for_dirs(directions: Sequence[Direction]) -> bytes:
    """Choose the more compact of 2-bit stream vs RLE for given directions and return its TLV.

    On a tie the 2-bit encoding is returned.
    """
    # 2-bit body size
    two_bit_payload = bitpack_2bit_directions(directions)
    two_bit = pack_body_tlv(BodyTLV.BODY_2BIT, two_bit_payload)
    # RLE
    rle_payload = rle_pack_directions(directions)
    rle = pack_body_tlv(BodyTLV.BODY_RLE, rle_payload)
    return two_bit if len(two_bit) <= len(rle) else rle


# Join / Ack / Deny


def build_join(name_utf8: bytes, preferred_color_id: int | None = None) -> bytes:
    """Client-side JOIN builder; not implemented here."""
    # Client-side helper (not used in server)
    raise NotImplementedError


def parse_join(buf: bytes, offset: int) -> Tuple[str, int | None, int]:
    """Parse a JOIN body starting at ``offset``.

    Layout: name length (varint), name bytes (UTF-8), then an optional
    preferred-colour byte that is read only if any bytes remain.

    Returns:
        ``(name, preferred_color_id_or_None, next_offset)``. Undecodable
        bytes in the name are dropped rather than rejected.

    Raises:
        ValueError: if the declared name length runs past the end of ``buf``
            or the length varint is truncated.
        IndexError: if ``offset`` does not address a byte of ``buf``.
    """
    name_len, off = quic_varint_decode(buf, offset)
    end = off + name_len
    if end > len(buf):
        # Slicing would otherwise shorten the name silently and leave the
        # returned offset pointing beyond the buffer.
        raise ValueError("buffer too small for name")
    name_b = buf[off:end]
    off = end
    preferred = None
    if off < len(buf):
        preferred = buf[off]
        off += 1
    name = name_b.decode("utf-8", errors="ignore")
    return name, preferred, off


def build_join_ack(
    *,
    version: int,
    seq: int,
    player_id: int,
    color_id: int,
    width: int,
    height: int,
    tick_rate: int,
    wrap_edges: bool,
    apples_per_snake: int,
    apples_cap: int,
    compression_mode: int,
) -> bytes:
    """Build a JOIN_ACK packet: header without tick plus nine u8 fields.

    Body values are masked to their low 8 bits rather than range-checked.
    """
    header = pack_header(version, PacketType.JOIN_ACK, 0, seq, None)
    body = bytearray()
    body.append(player_id & 0xFF)
    body.append(color_id & 0xFF)
    body.append(width & 0xFF)
    body.append(height & 0xFF)
    body.append(tick_rate & 0xFF)
    body.append(1 if wrap_edges else 0)
    body.append(apples_per_snake & 0xFF)
    body.append(apples_cap & 0xFF)
    body.append(compression_mode & 0xFF)
    return header + bytes(body)


def build_join_deny(*, version: int, seq: int, reason: str) -> bytes:
    """Build a JOIN_DENY packet: header without tick, then a varint-length-prefixed reason.

    The reason is limited to at most 64 bytes of UTF-8. Truncation happens
    on a character boundary so the payload always remains valid UTF-8.
    """
    header = pack_header(version, PacketType.JOIN_DENY, 0, seq, None)
    rb = reason.encode("utf-8")[:64]
    # A byte-level cut can split a multi-byte character; round-tripping with
    # errors="ignore" drops only that incomplete trailing sequence.
    rb = rb.decode("utf-8", errors="ignore").encode("utf-8")
    return header + quic_varint_encode(len(rb)) + rb


# Input (client -> server)


def parse_input(buf: bytes, offset: int) -> Tuple[int, int, int, List[InputEvent], int]:
    """Parse an INPUT body starting at ``offset``.

    Layout: ack_seq (u16), input_seq (u16), base_tick (u16), event count
    (varint), then per event a varint offset and a direction byte of which
    only the low two bits are used.

    Returns:
        ``(ack_seq, input_seq, base_tick, events, next_offset)``.

    Raises:
        IndexError: if the buffer ends before a varint or direction byte.
        ValueError: if a varint is truncated.
    """
    ack_seq = int.from_bytes(buf[offset : offset + 2], "big")
    offset += 2
    input_seq = int.from_bytes(buf[offset : offset + 2], "big")
    offset += 2
    base_tick = int.from_bytes(buf[offset : offset + 2], "big")
    offset += 2
    n_ev, offset = quic_varint_decode(buf, offset)
    events: List[InputEvent] = []
    for _ in range(n_ev):
        rel, offset = quic_varint_decode(buf, offset)
        d = Direction(buf[offset] & 0x03)
        offset += 1
        events.append(InputEvent(rel_tick_offset=int(rel), direction=d))
    return ack_seq, input_seq, base_tick, events, offset


# State snapshot (server -> client)


def build_state_full(
    *,
    version: int,
    seq: int,
    tick: int,
    snakes: Sequence[Tuple[int, int, int, int, Sequence[Direction]]],
    apples: Sequence[Tuple[int, int]],
) -> bytes:
    """Build a minimal state_full: per-snake header + body TLV; apples list.

    The body TLV of each snake is whichever of the 2-bit and RLE encodings
    is shorter (see ``build_body_tlv_for_dirs``).

    snakes: sequence of (snake_id, len, head_x, head_y, body_dirs_from_head)
    apples: sequence of (x, y)
    """
    header = pack_header(version, PacketType.STATE_FULL, 0, seq, tick & 0xFFFF)
    body = build_state_full_body(snakes=snakes, apples=apples)
    return header + body


def build_state_full_body(
    *, snakes: Sequence[Tuple[int, int, int, int, Sequence[Direction]]], apples: Sequence[Tuple[int, int]]
) -> bytes:
    """Build the STATE_FULL body (no packet header).

    Layout: snake count (varint); per snake id (u8), length (u16), head x
    (u8), head y (u8) and a body TLV; then apple count (varint) and per
    apple x (u8), y (u8). Values are masked to their field width.
    """
    body = bytearray()
    # snakes count
    body.extend(quic_varint_encode(len(snakes)))
    for sid, slen, hx, hy, dirs in snakes:
        body.append(sid & 0xFF)
        body.extend(int(slen & 0xFFFF).to_bytes(2, "big"))
        body.append(hx & 0xFF)
        body.append(hy & 0xFF)
        tlv = build_body_tlv_for_dirs(dirs)
        body.extend(tlv)
    # apples
    body.extend(quic_varint_encode(len(apples)))
    for ax, ay in apples:
        body.append(ax & 0xFF)
        body.append(ay & 0xFF)
    return bytes(body)


def build_body_2bit_chunk(directions: Sequence[Direction], start_index: int, dirs_in_chunk: int) -> bytes:
    """Build chunk TLV for a sub-range of directions (2-bit).

    Payload: start_index (u16), dirs_in_chunk (u16), then the 2-bit packed
    slice. ``dirs_in_chunk`` is written as passed, even when the slice
    holds fewer directions because it reaches the end of ``directions``.
    """
    sub = directions[start_index : start_index + dirs_in_chunk]
    payload = bytearray()
    payload.extend(int(start_index & 0xFFFF).to_bytes(2, "big"))
    payload.extend(int(dirs_in_chunk & 0xFFFF).to_bytes(2, "big"))
    payload.extend(bitpack_2bit_directions(sub))
    return pack_body_tlv(BodyTLV.BODY_2BIT_CHUNK, bytes(payload))


def estimate_body_2bit_bytes(snake_len: int) -> int:
    """Return the bytes needed for ``snake_len - 1`` two-bit entries, rounded up."""
    used_bits = max(0, (snake_len - 1) * 2)
    return (used_bits + 7) // 8


# --- State Delta ---


@dataclass
class SnakeDelta:
    """Per-snake change record serialized by ``build_state_delta_body``."""

    snake_id: int
    head_moved: bool
    tail_removed: bool
    grew: bool
    blocked: bool
    # Head coordinates are serialized only when head_moved is true.
    new_head_x: int = 0
    new_head_y: int = 0
    direction: Direction = Direction.RIGHT


def build_state_delta_body(
    *,
    update_id: int,
    changes: Sequence[SnakeDelta],
    apples_added: Sequence[Tuple[int, int]],
    apples_removed: Sequence[Tuple[int, int]]
) -> bytes:
    """Build the STATE_DELTA body (no packet header).

    Layout: update_id (u16); change count (varint); per change snake id
    (u8), flags (u8), direction (u8) and, only if the head moved, new head
    x and y (u8 each); then the added and the removed apple lists, each a
    varint count followed by x, y byte pairs.

    Flag bits: 0 = head_moved, 1 = tail_removed, 2 = grew, 3 = blocked.
    """
    body = bytearray()
    body.extend(int(update_id & 0xFFFF).to_bytes(2, "big"))
    # snakes
    body.extend(quic_varint_encode(len(changes)))
    for ch in changes:
        body.append(ch.snake_id & 0xFF)
        flags = (
            (1 if ch.head_moved else 0)
            | ((1 if ch.tail_removed else 0) << 1)
            | ((1 if ch.grew else 0) << 2)
            | ((1 if ch.blocked else 0) << 3)
        )
        body.append(flags & 0xFF)
        body.append(int(ch.direction) & 0x03)
        if ch.head_moved:
            body.append(ch.new_head_x & 0xFF)
            body.append(ch.new_head_y & 0xFF)
    # apples added
    body.extend(quic_varint_encode(len(apples_added)))
    for ax, ay in apples_added:
        body.append(ax & 0xFF)
        body.append(ay & 0xFF)
    # apples removed
    body.extend(quic_varint_encode(len(apples_removed)))
    for rx, ry in apples_removed:
        body.append(rx & 0xFF)
        body.append(ry & 0xFF)
    return bytes(body)


def build_state_delta(
    *,
    version: int,
    seq: int,
    tick: int,
    update_id: int,
    changes: Sequence[SnakeDelta],
    apples_added: Sequence[Tuple[int, int]],
    apples_removed: Sequence[Tuple[int, int]]
) -> bytes:
    """Build a STATE_DELTA packet: header (tick masked to 16 bits) plus delta body."""
    header = pack_header(version, PacketType.STATE_DELTA, 0, seq, tick & 0xFFFF)
    body = build_state_delta_body(
        update_id=update_id,
        changes=changes,
        apples_added=apples_added,
        apples_removed=apples_removed,
    )
    return header + body


def build_part(
    *,
    version: int,
    seq: int,
    tick: int,
    update_id: int,
    part_index: int,
    parts_total: int,
    inner_type: PacketType,
    chunk_payload: bytes,
) -> bytes:
    """Build a PART packet.

    Body layout: update_id (u16), part_index (u8), parts_total (u8),
    inner_type (u8), then ``chunk_payload`` verbatim.
    """
    header = pack_header(version, PacketType.PART, 0, seq, tick & 0xFFFF)
    body = bytearray()
    body.extend(int(update_id & 0xFFFF).to_bytes(2, "big"))
    body.append(part_index & 0xFF)
    body.append(parts_total & 0xFF)
    body.append(int(inner_type) & 0xFF)
    # include the chunk payload bytes
    body.extend(chunk_payload)
    return header + bytes(body)