"""Tests for binary codec."""

import pytest

from src.shared.binary_codec import BinaryCodec, FieldType, FieldID
from src.shared.models import Position, Snake


def _coords(positions):
    """Return ``(x, y)`` tuples for *positions*.

    Comparing two such lists checks length and element order in a single
    assertion, and pytest prints a readable diff when they differ.
    """
    return [(p.x, p.y) for p in positions]


class TestPositionEncoding:
    """Test position encoding/decoding."""

    @pytest.mark.parametrize(
        "pos",
        [
            Position(0, 0),
            Position(39, 29),  # Max grid position
            Position(20, 15),
            Position(1, 1),
        ],
    )
    def test_encode_decode_position(self, pos):
        """Test position round-trip."""
        encoded = BinaryCodec.encode_position(pos)
        decoded = BinaryCodec.decode_position(encoded)
        assert (decoded.x, decoded.y) == (pos.x, pos.y)

    def test_packed_positions(self):
        """Test packed position encoding."""
        positions = [
            Position(0, 0),
            Position(1, 0),
            Position(2, 0),
            Position(3, 0),
            Position(3, 1),
        ]

        packed = BinaryCodec.encode_packed_positions(positions)
        decoded = BinaryCodec.decode_packed_positions(packed, len(positions))

        assert _coords(decoded) == _coords(positions)

    def test_delta_encoding(self):
        """Test delta position encoding."""
        # Snake body (adjacent positions)
        positions = [
            Position(5, 5),
            Position(6, 5),  # Right
            Position(7, 5),  # Right
            Position(7, 6),  # Down
            Position(7, 7),  # Down
            Position(6, 7),  # Left
            Position(6, 6),  # Up
        ]

        encoded = BinaryCodec.encode_delta_positions(positions)
        decoded = BinaryCodec.decode_delta_positions(encoded, len(positions))

        assert _coords(decoded) == _coords(positions)


class TestVarint:
    """Test variable-length integer encoding."""

    @pytest.mark.parametrize("value", [0, 1, 10, 127])
    def test_small_values(self, value):
        """Test small varint values."""
        encoded = BinaryCodec.encode_varint(value)
        decoded, offset = BinaryCodec.decode_varint(encoded, 0)
        assert decoded == value
        assert offset == len(encoded)

    @pytest.mark.parametrize("value", [128, 255, 1000, 10000, 65535, 1000000])
    def test_large_values(self, value):
        """Test large varint values."""
        encoded = BinaryCodec.encode_varint(value)
        decoded, offset = BinaryCodec.decode_varint(encoded, 0)
        assert decoded == value
        # The decoder must consume exactly the multi-byte encoding, as it
        # does for single-byte values.
        assert offset == len(encoded)

    @pytest.mark.parametrize(
        "value, size",
        [
            # Small values use 1 byte
            (0, 1),
            (127, 1),
            # Values >= 128 use 2+ bytes
            (128, 2),
            (255, 2),
            (16383, 2),
            (16384, 3),
        ],
    )
    def test_varint_size(self, value, size):
        """Test varint encoding size at the byte-count boundaries."""
        assert len(BinaryCodec.encode_varint(value)) == size


class TestStringEncoding:
    """Test string encoding."""

    @pytest.mark.parametrize("s", ["", "A", "Alice", "Player123"])
    def test_short_strings(self, s):
        """Test encoding short strings."""
        encoded = BinaryCodec.encode_string_16(s)
        decoded, _ = BinaryCodec.decode_string_16(encoded)
        assert decoded == s

    def test_max_length_string(self):
        """Test 16-character string."""
        s = "VeryLongUsername"
        assert len(s) == 16

        encoded = BinaryCodec.encode_string_16(s)
        decoded, _ = BinaryCodec.decode_string_16(encoded)
        assert decoded == s

    def test_truncation(self):
        """Test string truncation."""
        s = "ThisIsAVeryLongUsernameThatExceedsLimit"
        encoded = BinaryCodec.encode_string_16(s)
        decoded, _ = BinaryCodec.decode_string_16(encoded)

        # Should be truncated to at most 16 chars, keeping the start of the
        # original (ASCII, so char and byte truncation agree).
        assert len(decoded) <= 16
        assert s.startswith(decoded)

    @pytest.mark.parametrize("s", ["Hello世界", "Café", "🎮Player"])
    def test_unicode_strings(self, s):
        """Test Unicode string encoding."""
        # Each of these is at most 11 UTF-8 bytes, well inside the 16-unit
        # limit, so it must round-trip unchanged.
        encoded = BinaryCodec.encode_string_16(s)
        decoded, _ = BinaryCodec.decode_string_16(encoded)
        assert decoded == s


class TestPlayerIdHash:
    """Test player ID hashing."""

    def test_consistent_hashing(self):
        """Test hash consistency."""
        player_id = "550e8400-e29b-41d4-a716-446655440000"
        hash1 = BinaryCodec.player_id_hash(player_id)
        hash2 = BinaryCodec.player_id_hash(player_id)
        assert hash1 == hash2

    def test_different_ids(self):
        """Test different IDs produce different hashes."""
        id1 = "550e8400-e29b-41d4-a716-446655440000"
        id2 = "550e8400-e29b-41d4-a716-446655440001"
        hash1 = BinaryCodec.player_id_hash(id1)
        hash2 = BinaryCodec.player_id_hash(id2)
        assert hash1 != hash2

    def test_hash_range(self):
        """Test hash is within uint32 range."""
        for i in range(100):
            hash_val = BinaryCodec.player_id_hash(f"player_{i}")
            assert 0 <= hash_val <= 0xFFFFFFFF


class TestCompression:
    """Test compression/decompression."""

    def test_compress_decompress(self):
        """Test compression round-trip."""
        data = b"This is test data that should compress well. " * 10
        compressed = BinaryCodec.compress(data)
        decompressed = BinaryCodec.decompress(compressed)

        assert decompressed == data
        assert len(compressed) < len(data)  # Repetitive data should shrink

    def test_small_data(self):
        """Test compression of small data."""
        # Small data might not compress well; only the round-trip is checked.
        data = b"short"
        compressed = BinaryCodec.compress(data)
        assert BinaryCodec.decompress(compressed) == data

    def test_empty_data(self):
        """Test compression of empty data."""
        data = b""
        compressed = BinaryCodec.compress(data)
        assert BinaryCodec.decompress(compressed) == data


class TestPayloadDecoding:
    """Test payload field decoding."""

    def test_simple_payload(self):
        """Test decoding simple payload."""
        # Manually construct payload: [version][field count], then per field
        # [field id][field type][varint length][value bytes].
        payload = bytearray()
        payload.append(BinaryCodec.VERSION)  # Version
        payload.append(2)  # Field count

        # Field 1: UPDATE_ID (uint16)
        payload.append(FieldID.UPDATE_ID)
        payload.append(FieldType.UINT16)
        payload.extend(BinaryCodec.encode_varint(2))  # Length
        payload.extend(b"\x00\x64")  # Value: 100, big-endian

        # Field 2: GAME_RUNNING (uint8)
        payload.append(FieldID.GAME_RUNNING)
        payload.append(FieldType.UINT8)
        payload.extend(BinaryCodec.encode_varint(1))  # Length
        payload.append(1)  # True

        fields = self._decode_payload(bytes(payload))

        # The header declares exactly two fields, so exactly two must decode.
        assert len(fields) == 2

        by_id = {field_id: (field_type, data) for field_id, field_type, data in fields}
        assert by_id[FieldID.UPDATE_ID] == (FieldType.UINT16, b"\x00\x64")
        assert by_id[FieldID.GAME_RUNNING] == (FieldType.UINT8, b"\x01")

    @staticmethod
    def _decode_payload(payload: bytes):
        """Split *payload* into ``(field_id, field_type, field_data)`` tuples.

        Byte 0 is the version (not inspected here) and byte 1 the field
        count; each field is an id byte, a type byte, a varint length, then
        ``length`` data bytes. Decoding stops early, without raising, when
        the payload is truncated.
        """
        if len(payload) < 2:
            return []

        field_count = payload[1]  # payload[0] is the version byte
        fields = []
        offset = 2
        for _ in range(field_count):
            if offset + 2 > len(payload):
                break
            field_id, field_type = payload[offset], payload[offset + 1]
            length, offset = BinaryCodec.decode_varint(payload, offset + 2)
            if offset + length > len(payload):
                break
            fields.append((field_id, field_type, payload[offset:offset + length]))
            offset += length

        return fields


class TestEdgeCases:
    """Test edge cases."""

    def test_max_grid_position(self):
        """Test maximum grid position (39, 29)."""
        pos = Position(39, 29)
        encoded = BinaryCodec.encode_position(pos)
        decoded = BinaryCodec.decode_position(encoded)

        assert decoded.x == 39
        assert decoded.y == 29

    def test_empty_position_list(self):
        """Test empty position list."""
        encoded = BinaryCodec.encode_packed_positions([])
        decoded = BinaryCodec.decode_packed_positions(encoded, 0)

        assert decoded == []

    def test_single_position(self):
        """Test single position."""
        positions = [Position(10, 10)]
        encoded = BinaryCodec.encode_delta_positions(positions)
        decoded = BinaryCodec.decode_delta_positions(encoded, 1)

        assert _coords(decoded) == [(10, 10)]


if __name__ == "__main__":
    pytest.main([__file__, "-v"])