from __future__ import annotations import asyncio from dataclasses import dataclass from typing import Dict, List, Optional, Tuple from collections import deque from .config import ServerConfig from .model import GameState, PlayerSession, Snake, Coord from .protocol import ( Direction, InputEvent, PacketType, build_join_ack, build_join_deny, build_config_update, build_input_broadcast, build_state_full, build_state_full_body, build_body_tlv_for_dirs, build_body_2bit_chunk, build_state_delta, build_state_delta_body, build_part, SnakeDelta, pack_header, parse_input, parse_join, ) from .transport import DatagramServerTransport, TransportPeer @dataclass class ServerRuntime: config: ServerConfig state: GameState seq: int = 0 version: int = 1 update_id: int = 0 def next_seq(self) -> int: self.seq = (self.seq + 1) & 0xFFFF return self.seq def next_update_id(self) -> int: self.update_id = (self.update_id + 1) & 0xFFFF return self.update_id class GameServer: def __init__(self, transport: DatagramServerTransport, config: ServerConfig): self.transport = transport self.runtime = ServerRuntime(config=config, state=GameState(config.width, config.height)) self.sessions: Dict[int, PlayerSession] = {} self._config_update_interval_ticks = 50 # periodic resend async def on_datagram(self, data: bytes, peer: TransportPeer) -> None: # Minimal header parse if len(data) < 5: return ver = data[0] ptype = PacketType(data[1]) flags = data[2] # seq = int.from_bytes(data[3:5], 'big') # currently unused off = 5 if ptype == PacketType.JOIN: await self._handle_join(data, off, peer) elif ptype == PacketType.INPUT: await self._handle_input(data, off, peer) else: # ignore others in skeleton return async def broadcast_config_update(self) -> None: r = self.runtime payload = build_config_update( version=r.version, seq=r.next_seq(), tick=r.state.tick & 0xFFFF, tick_rate=r.config.tick_rate, wrap_edges=r.config.wrap_edges, apples_per_snake=r.config.apples_per_snake, apples_cap=r.config.apples_cap, ) # Broadcast to all sessions (requires real peer handles) for session in list(self.sessions.values()): await self.transport.send(payload, TransportPeer(session.peer)) async def relay_input_broadcast( self, *, from_player_id: int, input_seq: int, base_tick: int, events: List[InputEvent], apply_at_tick: int | None, ) -> None: r = self.runtime payload = build_input_broadcast( version=r.version, seq=r.next_seq(), tick=r.state.tick & 0xFFFF, player_id=from_player_id, input_seq=input_seq, base_tick=base_tick & 0xFFFF, events=events, apply_at_tick=apply_at_tick, ) for session in list(self.sessions.values()): if session.player_id == from_player_id: continue await self.transport.send(payload, TransportPeer(session.peer)) async def tick_loop(self) -> None: r = self.runtime tick_duration = 1.0 / max(1, r.config.tick_rate) next_cfg_resend = self._config_update_interval_ticks while True: start = asyncio.get_event_loop().time() # process inputs, update snakes, collisions, apples, deltas await self._simulate_tick() r.state.tick = (r.state.tick + 1) & 0xFFFFFFFF next_cfg_resend -= 1 if next_cfg_resend <= 0: await self.broadcast_config_update() next_cfg_resend = self._config_update_interval_ticks elapsed = asyncio.get_event_loop().time() - start await asyncio.sleep(max(0.0, tick_duration - elapsed)) # --- Simulation --- def _consume_input_for_snake(self, s: Snake) -> None: # Consume at most one input; skip 180° turns when length>1 while s.input_buf: nd = s.input_buf[0] # 180-degree check if s.length > 1 and ((int(nd) ^ int(s.direction)) == 2): s.input_buf.popleft() continue # Accept s.direction = nd s.input_buf.popleft() break def _step_from(self, x: int, y: int, d: Direction) -> Tuple[int, int, bool]: dx = 1 if d == Direction.RIGHT else -1 if d == Direction.LEFT else 0 dy = 1 if d == Direction.DOWN else -1 if d == Direction.UP else 0 nx, ny = x + dx, y + dy st = self.runtime.state wrap = self.runtime.config.wrap_edges wrapped = False if wrap: if nx < 0: nx = st.width - 1; wrapped = True elif nx >= st.width: nx = 0; wrapped = True if ny < 0: ny = st.height - 1; wrapped = True elif ny >= st.height: ny = 0; wrapped = True return nx, ny, wrapped async def _simulate_tick(self) -> None: st = self.runtime.state cfg = self.runtime.config apples_eaten: List[Coord] = [] apples_before = set(st.apples) changes: List[SnakeDelta] = [] # Prepare snapshot of tails to allow moving into own tail when it vacates tails: Dict[int, Coord] = {} for sid, s in st.snakes.items(): if s.length > 1: tails[sid] = s.body[-1] # Process snakes for sid, s in st.snakes.items(): # Consume one input if available self._consume_input_for_snake(s) hx, hy = s.body[0] nx, ny, wrapped = self._step_from(hx, hy, s.direction) # Check wall if no wrap obstacle = False if not self.runtime.config.wrap_edges and not st.in_bounds(nx, ny): obstacle = True else: # Bounds correction already handled by _step_from # Occupancy check occ = st.occupancy.get((nx, ny)) if occ is not None: # Allow moving into own tail if it will vacate (not growing) own_tail_ok = (s.length > 1 and (nx, ny) == tails.get(sid)) if not own_tail_ok: obstacle = True if obstacle: s.blocked = True # shrink tail by 1 (to min 1) if s.length > 1: tx, ty = s.body.pop() st.occupancy.pop((tx, ty), None) changes.append( SnakeDelta( snake_id=sid, head_moved=False, tail_removed=(s.length > 1), grew=False, blocked=True, new_head_x=hx, new_head_y=hy, direction=s.direction, ) ) continue # Move or grow s.blocked = False will_grow = (nx, ny) in st.apples # Add new head s.body.appendleft((nx, ny)) st.occupancy[(nx, ny)] = (sid, 0) if will_grow: # eat apple; no tail removal this tick st.apples.remove((nx, ny)) apples_eaten.append((nx, ny)) changes.append( SnakeDelta( snake_id=sid, head_moved=True, tail_removed=False, grew=True, blocked=False, new_head_x=nx, new_head_y=ny, direction=s.direction, ) ) else: # normal move: remove tail (unless length==0) tx, ty = s.body.pop() if (tx, ty) in st.occupancy: st.occupancy.pop((tx, ty), None) changes.append( SnakeDelta( snake_id=sid, head_moved=True, tail_removed=True, grew=False, blocked=False, new_head_x=nx, new_head_y=ny, direction=s.direction, ) ) # Replenish apples to target self._ensure_apples() apples_after = set(st.apples) apples_added = sorted(list(apples_after - apples_before)) apples_removed = sorted(list(apples_before - apples_after)) # Broadcast a basic delta (currently full snakes + apples as delta) update_id = self.runtime.next_update_id() # Serialize full delta body once, then partition if large full_body = build_state_delta_body( update_id=update_id, changes=changes, apples_added=apples_added, apples_removed=apples_removed, ) MTU = 1200 # soft limit for payload to avoid fragmentation if len(full_body) <= MTU: packet = pack_header(self.runtime.version, PacketType.STATE_DELTA, 0, self.runtime.next_seq(), st.tick & 0xFFFF) + full_body for session in list(self.sessions.values()): await self.transport.send(packet, TransportPeer(session.peer)) else: # Partition by splitting snake changes across parts; include apples only in the first part parts: List[bytes] = [] remaining = list(changes) idx = 0 part_index = 0 while remaining: chunk: List[SnakeDelta] = [] # greedy pack until size would exceed MTU # start with apples only for first part add_ap = apples_added if part_index == 0 else [] rem_ap = apples_removed if part_index == 0 else [] # Try adding changes one by one for i, ch in enumerate(remaining): tmp_body = build_state_delta_body(update_id=update_id, changes=chunk + [ch], apples_added=add_ap, apples_removed=rem_ap) if len(tmp_body) > MTU and chunk: break if len(tmp_body) > MTU: # even a single change + apples doesn't fit; force single chunk.append(ch) i += 1 break chunk.append(ch) # Remove chunked items remaining = remaining[len(chunk) :] # Build chunk body chunk_body = build_state_delta_body(update_id=update_id, changes=chunk, apples_added=add_ap, apples_removed=rem_ap) parts.append(chunk_body) part_index += 1 # Emit PART packets total = len(parts) for i, body in enumerate(parts): pkt = build_part( version=self.runtime.version, seq=self.runtime.next_seq(), tick=st.tick & 0xFFFF, update_id=update_id, part_index=i, parts_total=total, inner_type=PacketType.STATE_DELTA, chunk_payload=body, ) for session in list(self.sessions.values()): await self.transport.send(pkt, TransportPeer(session.peer)) # --- Join / Spawn --- def _allocate_player_id(self) -> Optional[int]: used = {s.player_id for s in self.sessions.values()} for pid in range(self.runtime.config.players_max): if pid not in used: return pid return None def _choose_color_id(self) -> int: used = {s.color_id for s in self.sessions.values()} for cid in range(32): if cid not in used: return cid return 0 def _neighbors(self, x: int, y: int) -> List[Tuple[Direction, Coord]]: return [ (Direction.UP, (x, y - 1)), (Direction.RIGHT, (x + 1, y)), (Direction.DOWN, (x, y + 1)), (Direction.LEFT, (x - 1, y)), ] def _find_spawn(self) -> Optional[Snake]: st = self.runtime.state # Try to find a 3-cell straight strip for y in range(st.height): for x in range(st.width): if not st.cell_free(x, y): continue for d, (nx, ny) in self._neighbors(x, y): # check two cells in direction x2, y2 = nx, ny x3, y3 = nx + (1 if d == Direction.RIGHT else -1 if d == Direction.LEFT else 0), ny + (1 if d == Direction.DOWN else -1 if d == Direction.UP else 0) if st.in_bounds(x2, y2) and st.in_bounds(x3, y3) and st.cell_free(x2, y2) and st.cell_free(x3, y3): body = [ (x, y), (x2, y2), (x3, y3), ] return Snake(snake_id=-1, head=(x, y), direction=d, body=deque(body)) # Fallback: any single free cell for y in range(st.height): for x in range(st.width): if st.cell_free(x, y): return Snake(snake_id=-1, head=(x, y), direction=Direction.RIGHT, body=deque([(x, y)])) return None def _ensure_apples(self) -> None: st = self.runtime.state cfg = self.runtime.config import random random.seed(0xC0FFEE) target = 3 if not self.sessions else min(cfg.apples_cap, max(0, len(self.sessions) * cfg.apples_per_snake)) # grow apples up to target while len(st.apples) < target: x = random.randrange(st.width) y = random.randrange(st.height) if st.cell_free(x, y) and (x, y) not in st.apples: st.apples.append((x, y)) # shrink if too many if len(st.apples) > target: st.apples = st.apples[:target] async def _handle_join(self, buf: bytes, off: int, peer: TransportPeer) -> None: name, preferred, off2 = parse_join(buf, off) # enforce name <=16 bytes utf-8 name = name.encode("utf-8")[:16].decode("utf-8", errors="ignore") pid = self._allocate_player_id() if pid is None: payload = build_join_deny(version=self.runtime.version, seq=self.runtime.next_seq(), reason="Server full") await self.transport.send(payload, peer) return # Spawn snake snake = self._find_spawn() if snake is None: payload = build_join_deny(version=self.runtime.version, seq=self.runtime.next_seq(), reason="No free cell, please wait") await self.transport.send(payload, peer) return # Register session and snake color_id = preferred if (preferred is not None) else self._choose_color_id() session = PlayerSession(player_id=pid, name=name, color_id=color_id, peer=peer.addr) self.sessions[pid] = session snake.snake_id = pid self.runtime.state.snakes[pid] = snake self.runtime.state.occupy_snake(snake) self._ensure_apples() # Send join_ack cfg = self.runtime.config ack = build_join_ack( version=self.runtime.version, seq=self.runtime.next_seq(), player_id=pid, color_id=color_id, width=cfg.width, height=cfg.height, tick_rate=cfg.tick_rate, wrap_edges=cfg.wrap_edges, apples_per_snake=cfg.apples_per_snake, apples_cap=cfg.apples_cap, compression_mode=0 if cfg.compression_mode == "none" else 1, ) await self.transport.send(ack, peer) # Send initial state_full (partitioned if needed) snakes_dirs: List[Tuple[int, int, int, int, List[Direction]]] = [] for s in self.runtime.state.snakes.values(): dirs: List[Direction] = [] # build directions from consecutive coords: from head toward tail coords = list(s.body) for i in range(len(coords) - 1): x0, y0 = coords[i] x1, y1 = coords[i + 1] if x1 == x0 and y1 == y0 - 1: dirs.append(Direction.UP) elif x1 == x0 + 1 and y1 == y0: dirs.append(Direction.RIGHT) elif x1 == x0 and y1 == y0 + 1: dirs.append(Direction.DOWN) elif x1 == x0 - 1 and y1 == y0: dirs.append(Direction.LEFT) hx, hy = coords[0] snakes_dirs.append((s.snake_id, s.length, hx, hy, dirs)) # Build body and partition across parts if needed body = build_state_full_body(snakes=snakes_dirs, apples=self.runtime.state.apples) MTU = 1200 if len(body) <= MTU: full = pack_header( self.runtime.version, PacketType.STATE_FULL, 0, self.runtime.next_seq(), self.runtime.state.tick & 0xFFFF ) + body await self.transport.send(full, peer) else: # Partition by packing whole snakes first, apples only in first part; chunk a single oversized snake using 2-bit chunks update_id = self.runtime.next_update_id() parts: List[bytes] = [] # Prepare apples buffer for first part def encode_apples(apples: List[Coord]) -> bytes: from .protocol import quic_varint_encode b = bytearray() b.extend(quic_varint_encode(len(apples))) for ax, ay in apples: b.append(ax & 0xFF) b.append(ay & 0xFF) return bytes(b) apples_encoded = encode_apples(self.runtime.state.apples) # Cursor over snakes, building per part bodies remaining = list(snakes_dirs) first = True while remaining: # Start part body with snakes count placeholder; we'll rebuild as we pack part_snakes: List[bytes] = [] packed_snakes = 0 budget = MTU # Reserve apples size on first part apples_this = apples_encoded if first else encode_apples([]) budget -= len(apples_this) # Greedily add snakes i = 0 while i < len(remaining): sid, slen, hx, hy, dirs = remaining[i] # Build this snake record with chosen TLV tlv = build_body_tlv_for_dirs(dirs) record = bytearray() # one snake: id u8, len u16, head x y, TLV record.append(sid & 0xFF) record.extend(int(slen & 0xFFFF).to_bytes(2, "big")) record.append(hx & 0xFF) record.append(hy & 0xFF) record.extend(tlv) if len(record) <= budget: part_snakes.append(bytes(record)) budget -= len(record) packed_snakes += 1 i += 1 continue # If single snake doesn't fit and no snakes packed yet, chunk this snake if packed_snakes == 0: # Use 2-bit chunking; compute directions chunk size to fit budget minus fixed headers (id/len/head + TLV type/len + chunk header) # Fixed snake header = 1+2+1+1 = 5 bytes; TLV type/len (worst-case 2 bytes for small values) + chunk header 4 bytes overhead = 5 + 2 + 4 dir_capacity_bytes = max(0, budget - overhead) if dir_capacity_bytes <= 0: break # Convert capacity bytes to number of directions (each 2 bits → 4 dirs per byte) dir_capacity = dir_capacity_bytes * 4 if dir_capacity <= 0: break # Emit at least one direction dirs_in_chunk = max(1, min(dir_capacity, len(dirs))) tlv_chunk = build_body_2bit_chunk(dirs, start_index=0, dirs_in_chunk=dirs_in_chunk) record = bytearray() record.append(sid & 0xFF) record.extend(int(slen & 0xFFFF).to_bytes(2, "big")) record.append(hx & 0xFF) record.append(hy & 0xFF) record.extend(tlv_chunk) if len(record) <= budget: part_snakes.append(bytes(record)) budget -= len(record) # Replace remaining snake with truncated version (advance dirs) remaining[i] = (sid, slen, hx, hy, list(dirs)[dirs_in_chunk:]) packed_snakes += 1 else: break else: break # Build part body: snakes_count + snakes + apples body_part = bytearray() from .protocol import quic_varint_encode body_part.extend(quic_varint_encode(len(part_snakes))) for rec in part_snakes: body_part.extend(rec) body_part.extend(apples_this) parts.append(bytes(body_part)) # Advance remaining list remaining = remaining[packed_snakes:] first = False # Emit part packets total = len(parts) for i, chunk_body in enumerate(parts): pkt = build_part( version=self.runtime.version, seq=self.runtime.next_seq(), tick=self.runtime.state.tick & 0xFFFF, update_id=update_id, part_index=i, parts_total=total, inner_type=PacketType.STATE_FULL, chunk_payload=chunk_body, ) await self.transport.send(pkt, peer) async def _handle_input(self, buf: bytes, off: int, peer: TransportPeer) -> None: try: ack_seq, input_seq, base_tick, events, off2 = parse_input(buf, off) except Exception: return # Find player by peer player_id = None for pid, sess in self.sessions.items(): if sess.peer is peer.addr: player_id = pid break if player_id is None: return # Relay to others immediately for prediction await self.relay_input_broadcast( from_player_id=player_id, input_seq=input_seq, base_tick=base_tick, events=events, apply_at_tick=None, ) async def main() -> None: from .transport import InMemoryTransport cfg = ServerConfig() server = GameServer(transport=InMemoryTransport(lambda d, p: server.on_datagram(d, p)), config=cfg) # In-memory transport never returns; run tick loop in parallel await asyncio.gather(server.transport.run(), server.tick_loop()) if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: pass