"""Tests for UDP protocol with sequence numbers.""" import pytest from src.shared.udp_protocol import UDPProtocol, SequenceTracker from src.shared.binary_codec import MessageType class TestSequenceNumbers: """Test sequence number wrapping and validation.""" def test_is_seq_newer_basic(self): """Test basic sequence comparison.""" assert UDPProtocol.is_seq_newer(1, 0) == True assert UDPProtocol.is_seq_newer(100, 50) == True assert UDPProtocol.is_seq_newer(0, 0) == False # Duplicate assert UDPProtocol.is_seq_newer(50, 100) == False # Old def test_is_seq_newer_wrapping(self): """Test sequence wrapping around UINT32_MAX.""" # Near wrapping boundary last_seq = 0xFFFFFFFF - 5 # UINT32_MAX - 5 = 4294967290 # Small increments should work assert UDPProtocol.is_seq_newer(0xFFFFFFFF - 4, last_seq) == True assert UDPProtocol.is_seq_newer(0xFFFFFFFF - 3, last_seq) == True assert UDPProtocol.is_seq_newer(0xFFFFFFFF, last_seq) == True # Wrapped around assert UDPProtocol.is_seq_newer(0, last_seq) == True assert UDPProtocol.is_seq_newer(1, last_seq) == True assert UDPProtocol.is_seq_newer(5, last_seq) == True assert UDPProtocol.is_seq_newer(10, last_seq) == True # Old packets (before wrap) assert UDPProtocol.is_seq_newer(0xFFFFFFFF - 10, last_seq) == False def test_is_seq_newer_window(self): """Test window size enforcement.""" last_seq = 1000 window = 100 # Within window assert UDPProtocol.is_seq_newer(1050, last_seq, window) == True assert UDPProtocol.is_seq_newer(1100, last_seq, window) == True # Exactly at window boundary assert UDPProtocol.is_seq_newer(1101, last_seq, window) == False # Too far ahead assert UDPProtocol.is_seq_newer(1200, last_seq, window) == False def test_sequence_wraparound_multiple_times(self): """Test multiple wraparounds.""" # Start near max last_seq = 0xFFFFFFFF - 2 # Increment through wrap assert UDPProtocol.is_seq_newer(0xFFFFFFFF - 1, last_seq) == True last_seq = 0xFFFFFFFF - 1 assert UDPProtocol.is_seq_newer(0xFFFFFFFF, last_seq) == True last_seq = 0xFFFFFFFF assert UDPProtocol.is_seq_newer(0, last_seq) == True last_seq = 0 assert UDPProtocol.is_seq_newer(1, last_seq) == True class TestSequenceTracker: """Test SequenceTracker class.""" def test_basic_tracking(self): """Test basic sequence tracking.""" tracker = SequenceTracker() assert tracker.should_accept(1) == True assert tracker.should_accept(2) == True assert tracker.should_accept(3) == True # Duplicate assert tracker.should_accept(3) == False # Old assert tracker.should_accept(2) == False def test_reordering_within_window(self): """Test packet reordering within window.""" tracker = SequenceTracker() # Receive out of order assert tracker.should_accept(5) == True assert tracker.should_accept(3) == False # Older, reject assert tracker.should_accept(6) == True assert tracker.should_accept(4) == False # Older, reject assert tracker.should_accept(7) == True def test_wrapping_tracking(self): """Test tracking through wraparound.""" tracker = SequenceTracker() tracker.last_seq = 0xFFFFFFFF - 5 # Accept packets through wrap assert tracker.should_accept(0xFFFFFFFF - 4) == True assert tracker.should_accept(0xFFFFFFFF - 3) == True assert tracker.should_accept(0xFFFFFFFF) == True assert tracker.should_accept(0) == True assert tracker.should_accept(1) == True def test_cleanup(self): """Test sequence set cleanup.""" tracker = SequenceTracker() # Add many sequences for i in range(1, 1500): tracker.should_accept(i) # Should have cleaned up old sequences assert len(tracker.received_seqs) <= 1000 class TestUDPPackets: """Test UDP packet creation and parsing.""" def test_create_and_parse_packet(self): """Test packet creation and parsing round-trip.""" seq_num = 12345 msg_type = MessageType.PARTIAL_STATE_UPDATE update_id = 678 payload = b"test payload data" # Create packet packet = UDPProtocol.create_packet(seq_num, msg_type, update_id, payload, compress=False) # Parse packet result = UDPProtocol.parse_packet(packet) assert result is not None parsed_seq, parsed_type, parsed_id, parsed_payload = result assert parsed_seq == seq_num assert parsed_type == msg_type assert parsed_id == update_id assert parsed_payload == payload def test_packet_compression(self): """Test packet compression.""" seq_num = 100 msg_type = MessageType.GAME_META_UPDATE update_id = 200 payload = b"x" * 500 # Compressible payload # Create with compression packet_compressed = UDPProtocol.create_packet(seq_num, msg_type, update_id, payload, compress=True) # Create without compression packet_uncompressed = UDPProtocol.create_packet(seq_num, msg_type, update_id, payload, compress=False) # Compressed should be smaller assert len(packet_compressed) < len(packet_uncompressed) # Both should parse correctly result = UDPProtocol.parse_packet(packet_compressed) assert result is not None _, _, _, parsed_payload = result assert parsed_payload == payload def test_update_id_wrapping(self): """Test update ID wrapping.""" assert UDPProtocol.next_update_id(0xFFFF) == 0 assert UDPProtocol.next_update_id(0xFFFE) == 0xFFFF assert UDPProtocol.next_update_id(0) == 1 def test_sequence_wrapping(self): """Test sequence number wrapping.""" assert UDPProtocol.next_sequence(0xFFFFFFFF) == 0 assert UDPProtocol.next_sequence(0xFFFFFFFE) == 0xFFFFFFFF assert UDPProtocol.next_sequence(0) == 1 class TestEdgeCases: """Test edge cases and error handling.""" def test_invalid_packet(self): """Test parsing invalid packet.""" # Too short assert UDPProtocol.parse_packet(b"short") is None # Empty assert UDPProtocol.parse_packet(b"") is None def test_corrupted_compression(self): """Test handling corrupted compressed data.""" seq_num = 100 msg_type = 0x81 # Compression flag set update_id = 200 # Create packet header with invalid compressed payload import struct header = struct.pack('>IBH', seq_num, msg_type, update_id) packet = header + b"invalid compressed data" # Should return None due to decompression failure result = UDPProtocol.parse_packet(packet) assert result is None def test_large_sequence_gap(self): """Test very large sequence gaps.""" tracker = SequenceTracker() tracker.last_seq = 100 # Very large gap (suspicious) assert tracker.should_accept(2000) == False # But within window is ok assert tracker.should_accept(1100) == True if __name__ == "__main__": pytest.main([__file__, "-v"])