from __future__ import annotations

import base64
import hashlib
import os
import re
import tempfile
from typing import List, Optional
from pathlib import Path

# PEM armour markers for an X.509 certificate block.
_PEM_HEADER = '-----BEGIN CERTIFICATE-----'
_PEM_FOOTER = '-----END CERTIFICATE-----'

# One certificate block, BEGIN through END marker inclusive. Non-greedy with
# DOTALL so several concatenated blocks (a full chain) match separately.
_PEM_CERT_RE = re.compile(
    r'(-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----)', re.DOTALL
)


def is_newer(a: int, b: int) -> bool:
    """Return True if 16-bit sequence number *a* is newer than *b* (wrap-aware).

    Uses a half-range window on unsigned 16-bit arithmetic: *a* counts as
    newer when ``(a - b) mod 2**16`` is below ``0x8000``, so e.g. ``0`` is
    newer than ``0xFFFF``.

    NOTE(review): a difference of 0 is below ``0x8000``, so
    ``is_newer(x, x)`` is True (effectively "newer or equal"), and a
    difference of exactly ``0x8000`` is "not newer" in both directions.
    Confirm callers expect this.
    """
    return ((a - b) & 0xFFFF) < 0x8000


def get_cert_sha256_hash(cert_path: str) -> str:
    """Calculate the SHA-256 hash of a certificate file.

    The digest is taken over the DER encoding of the first certificate in the
    file, which is the form WebTransport clients use for certificate pinning.
    PEM is base64 armour around the DER bytes, so the DER is recovered by
    decoding the first ``BEGIN/END CERTIFICATE`` block; no third-party
    package is needed.

    Args:
        cert_path: Path to a PEM file (may contain a full chain; only the
            first certificate is hashed) or to a raw DER certificate.

    Returns:
        The hash as a lowercase hex string (64 characters).

    Raises:
        OSError: If the file cannot be read.
    """
    with open(cert_path, 'rb') as f:
        raw = f.read()

    # latin-1 maps every byte to a code point, so decoding never fails and we
    # can look for PEM markers even when the file is binary DER.
    certs = split_pem_certificates(raw.decode('latin-1'))
    if not certs:
        # No PEM block: assume the file already holds DER bytes.
        return hashlib.sha256(raw).hexdigest()

    body = certs[0][len(_PEM_HEADER):-len(_PEM_FOOTER)]
    try:
        # b64decode (validate=False) skips the line breaks inside the body.
        der = base64.b64decode(body)
    except ValueError:
        # Malformed/non-ASCII base64 (binascii.Error is a ValueError):
        # best-effort fallback of hashing the file contents as-is.
        der = raw
    return hashlib.sha256(der).hexdigest()


def split_pem_certificates(pem_data: str) -> List[str]:
    """Split a PEM file containing multiple certificates into individual certs.

    Args:
        pem_data: String containing one or more PEM certificates.

    Returns:
        List of individual certificate strings, in file order. Each includes
        its BEGIN/END markers; text outside the blocks is dropped. Empty if
        no certificate block is found.
    """
    return _PEM_CERT_RE.findall(pem_data)


def _write_temp_pem(content: str) -> str:
    """Write *content* to a new ``.pem`` temp file and return its path.

    The file is not deleted automatically; the caller owns it. If writing
    fails, the partially-written file is removed before re-raising.
    """
    tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False)
    try:
        with tmp:
            tmp.write(content)
    except BaseException:
        os.unlink(tmp.name)
        raise
    return tmp.name


def combine_cert_chain(
    cert_path: str, chain_paths: Optional[List[str]] = None
) -> str:
    """Combine server cert and chain certs into a single properly-formatted PEM file.

    Args:
        cert_path: Path to server certificate (may contain full chain).
        chain_paths: Paths to intermediate/root certificate files. ``None``
            or empty means no separate chain files.

    Returns:
        Path to a temporary combined certificate file (caller should clean
        up). Certificates are newline-separated and the file ends with a
        newline.
    """
    chain_paths = chain_paths or []
    combined_content: List[str] = []

    with open(cert_path, 'r') as f:
        cert_data = f.read().strip()
    certs = split_pem_certificates(cert_data)

    if not certs:
        # No PEM blocks recognised: pass the file content through verbatim.
        combined_content.append(cert_data)
    elif chain_paths:
        # Chain supplied separately: keep only the leaf certificate so that
        # intermediates bundled in cert_path are not duplicated.
        combined_content.append(certs[0])
    else:
        # Fullchain case: keep every certificate from cert_path.
        combined_content.extend(certs)

    for chain_path in chain_paths:
        with open(chain_path, 'r') as f:
            chain_data = f.read().strip()
        # Each chain file might contain multiple certs.
        chain_certs = split_pem_certificates(chain_data)
        if chain_certs:
            combined_content.extend(chain_certs)
        else:
            combined_content.append(chain_data)

    return _write_temp_pem('\n'.join(combined_content) + '\n')


def extract_server_cert(cert_path: str) -> str:
    """Extract just the server certificate from a file that may contain a full chain.

    Args:
        cert_path: Path to certificate file (may contain chain).

    Returns:
        Path to a temporary file containing only the first (server)
        certificate followed by a newline; the caller should delete it.

        Caveat: if no PEM certificate block is found, *cert_path* itself is
        returned unchanged. Callers must not delete the result in that case
        (check ``result != cert_path`` first).
    """
    with open(cert_path, 'r') as f:
        cert_data = f.read()

    certs = split_pem_certificates(cert_data)
    if not certs:
        return cert_path

    # First cert is the server (leaf) cert.
    return _write_temp_pem(certs[0] + '\n')