mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 16:55:37 +02:00
⭐ feat: Migrate from WebSocket to libp2p for peer-to-peer connectivity (#286)
## Description Whew, some stuff is still not re-implemented, but it's working! Rabbit's gonna explode with the amount of changes I reckon 😅 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a peer-to-peer relay system using libp2p with enhanced stream forwarding, room state synchronization, and mDNS peer discovery. - Added decentralized room and participant management, metrics publishing, and safe, size-limited, concurrent message streaming with robust framing and callback dispatching. - Implemented asynchronous, callback-driven message handling over custom libp2p streams replacing WebSocket signaling. - **Improvements** - Migrated signaling and stream protocols from WebSocket to libp2p, improving reliability and scalability. - Simplified configuration and environment variables, removing deprecated flags and adding persistent data support. - Enhanced logging, error handling, and connection management for better observability and robustness. - Refined RTP header extension registration and NAT IP handling for improved WebRTC performance. - **Bug Fixes** - Improved ICE candidate buffering and SDP negotiation in WebRTC connections. - Fixed NAT IP and UDP port range configuration issues. - **Refactor** - Modularized codebase, reorganized relay and server logic, and removed deprecated WebSocket-based components. - Streamlined message structures, removed obsolete enums and message types, and simplified SafeMap concurrency. - Replaced WebSocket signaling with libp2p stream protocols in server and relay components. - **Chores** - Updated and cleaned dependencies across Go, Rust, and JavaScript packages. - Added `.gitignore` for persistent data directory in relay package. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com> Co-authored-by: Philipp Neumann <3daquawolf@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e67a8d2b32
commit
6e82eff9e2
@@ -2,12 +2,13 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
|
||||
"github.com/libp2p/go-reuseport"
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var globalWebRTCAPI *webrtc.API
|
||||
@@ -24,17 +25,9 @@ func InitWebRTCAPI() error {
|
||||
// Media engine
|
||||
mediaEngine := &webrtc.MediaEngine{}
|
||||
|
||||
// Register additional header extensions to reduce latency
|
||||
// Playout Delay
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return err
|
||||
// Register our extensions
|
||||
if err := RegisterExtensions(mediaEngine); err != nil {
|
||||
return fmt.Errorf("failed to register extensions: %w", err)
|
||||
}
|
||||
|
||||
// Default codecs cover most of our needs
|
||||
@@ -75,9 +68,10 @@ func InitWebRTCAPI() error {
|
||||
// New in v4, reduces CPU usage and latency when enabled
|
||||
settingEngine.EnableSCTPZeroChecksum(true)
|
||||
|
||||
nat11IPs := GetFlags().NAT11IPs
|
||||
if len(nat11IPs) > 0 {
|
||||
settingEngine.SetNAT1To1IPs(nat11IPs, webrtc.ICECandidateTypeHost)
|
||||
nat11IP := GetFlags().NAT11IP
|
||||
if len(nat11IP) > 0 {
|
||||
settingEngine.SetNAT1To1IPs([]string{nat11IP}, webrtc.ICECandidateTypeSrflx)
|
||||
slog.Info("Using NAT 1:1 IP for WebRTC", "nat11_ip", nat11IP)
|
||||
}
|
||||
|
||||
muxPort := GetFlags().UDPMuxPort
|
||||
@@ -85,7 +79,7 @@ func InitWebRTCAPI() error {
|
||||
// Use reuseport to allow multiple listeners on the same port
|
||||
pktListener, err := reuseport.ListenPacket("udp", ":"+strconv.Itoa(muxPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %w", err)
|
||||
return fmt.Errorf("failed to create WebRTC muxed UDP listener: %w", err)
|
||||
}
|
||||
|
||||
mux := ice.NewMultiUDPMuxDefault(ice.NewUDPMuxDefault(ice.UDPMuxParams{
|
||||
@@ -95,10 +89,13 @@ func InitWebRTCAPI() error {
|
||||
settingEngine.SetICEUDPMux(mux)
|
||||
}
|
||||
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
if flags.WebRTCUDPStart > 0 && flags.WebRTCUDPEnd > 0 && flags.WebRTCUDPStart < flags.WebRTCUDPEnd {
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("Using WebRTC UDP Port Range", "start", flags.WebRTCUDPStart, "end", flags.WebRTCUDPEnd)
|
||||
}
|
||||
|
||||
settingEngine.SetIncludeLoopbackCandidate(true) // Just in case
|
||||
@@ -109,11 +106,6 @@ func InitWebRTCAPI() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWebRTCAPI returns the global WebRTC API
|
||||
func GetWebRTCAPI() *webrtc.API {
|
||||
return globalWebRTCAPI
|
||||
}
|
||||
|
||||
// CreatePeerConnection sets up a new peer connection
|
||||
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
|
||||
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
|
||||
|
||||
@@ -1,19 +1,51 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
func NewULID() (ulid.ULID, error) {
|
||||
return ulid.New(ulid.Timestamp(time.Now()), ulid.Monotonic(rand.Reader, 0))
|
||||
}
|
||||
|
||||
// Helper function to generate PSK from token
|
||||
func GeneratePSKFromToken(token string) ([]byte, error) {
|
||||
// Simple hash-based PSK generation (32 bytes for libp2p)
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hash[:], nil
|
||||
// GenerateED25519Key generates a new ED25519 key
|
||||
func GenerateED25519Key() (ed25519.PrivateKey, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ED25519 key pair: %w", err)
|
||||
}
|
||||
return priv, nil
|
||||
}
|
||||
|
||||
// SaveED25519Key saves an ED25519 private key to a path as a binary file
|
||||
func SaveED25519Key(privateKey ed25519.PrivateKey, filePath string) error {
|
||||
if privateKey == nil {
|
||||
return errors.New("private key cannot be nil")
|
||||
}
|
||||
if len(privateKey) != ed25519.PrivateKeySize {
|
||||
return errors.New("private key must be exactly 64 bytes for ED25519")
|
||||
}
|
||||
if err := os.WriteFile(filePath, privateKey, 0600); err != nil {
|
||||
return fmt.Errorf("failed to save ED25519 key to %s: %w", filePath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadED25519Key loads an ED25519 private key binary file from a path
|
||||
func LoadED25519Key(filePath string) (ed25519.PrivateKey, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read ED25519 key from %s: %w", filePath, err)
|
||||
}
|
||||
if len(data) != ed25519.PrivateKeySize {
|
||||
return nil, fmt.Errorf("ED25519 key must be exactly %d bytes, got %d", ed25519.PrivateKeySize, len(data))
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,45 @@
|
||||
package common
|
||||
|
||||
import "github.com/pion/webrtc/v4"
|
||||
|
||||
const (
|
||||
ExtensionPlayoutDelay string = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay"
|
||||
)
|
||||
|
||||
// ExtensionMap maps URIs to their IDs based on registration order
|
||||
// IMPORTANT: This must match the order in which extensions are registered in common.go!
|
||||
var ExtensionMap = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
// ExtensionMap maps audio/video extension URIs to their IDs based on registration order
|
||||
var ExtensionMap = map[webrtc.RTPCodecType]map[string]uint8{}
|
||||
|
||||
func RegisterExtensions(mediaEngine *webrtc.MediaEngine) error {
|
||||
// Register additional header extensions to reduce latency
|
||||
// Playout Delay (Video)
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
// Playout Delay (Audio)
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register the extension IDs for both audio and video
|
||||
ExtensionMap[webrtc.RTPCodecTypeAudio] = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
}
|
||||
ExtensionMap[webrtc.RTPCodecTypeVideo] = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetExtension(codecType webrtc.RTPCodecType, extURI string) (uint8, bool) {
|
||||
cType, ok := ExtensionMap[codecType]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
extID, ok := cType[extURI]
|
||||
return extID, ok
|
||||
}
|
||||
|
||||
@@ -2,47 +2,43 @@ package common
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
var globalFlags *Flags
|
||||
|
||||
type Flags struct {
|
||||
Verbose bool // Log everything to console
|
||||
Debug bool // Enable debug mode, implies Verbose
|
||||
EndpointPort int // Port for HTTP/S and WS/S endpoint (TCP)
|
||||
MeshPort int // Port for Mesh connections (TCP)
|
||||
WebRTCUDPStart int // WebRTC UDP port range start - ignored if UDPMuxPort is set
|
||||
WebRTCUDPEnd int // WebRTC UDP port range end - ignored if UDPMuxPort is set
|
||||
STUNServer string // WebRTC STUN server
|
||||
UDPMuxPort int // WebRTC UDP mux port - if set, overrides UDP port range
|
||||
AutoAddLocalIP bool // Automatically add local IP to NAT 1 to 1 IPs
|
||||
NAT11IPs []string // WebRTC NAT 1 to 1 IP(s) - allows specifying host IP(s) if behind NAT
|
||||
TLSCert string // Path to TLS certificate
|
||||
TLSKey string // Path to TLS key
|
||||
ControlSecret string // Shared secret for this relay's control endpoint
|
||||
RegenIdentity bool // Remove old identity on startup and regenerate it
|
||||
Verbose bool // Log everything to console
|
||||
Debug bool // Enable debug mode, implies Verbose
|
||||
EndpointPort int // Port for HTTP/S and WS/S endpoint (TCP)
|
||||
WebRTCUDPStart int // WebRTC UDP port range start - ignored if UDPMuxPort is set
|
||||
WebRTCUDPEnd int // WebRTC UDP port range end - ignored if UDPMuxPort is set
|
||||
STUNServer string // WebRTC STUN server
|
||||
UDPMuxPort int // WebRTC UDP mux port - if set, overrides UDP port range
|
||||
AutoAddLocalIP bool // Automatically add local IP to NAT 1 to 1 IPs
|
||||
NAT11IP string // WebRTC NAT 1 to 1 IP - allows specifying IP of relay if behind NAT
|
||||
PersistDir string // Directory to save persistent data to
|
||||
}
|
||||
|
||||
func (flags *Flags) DebugLog() {
|
||||
slog.Info("Relay flags",
|
||||
slog.Debug("Relay flags",
|
||||
"regenIdentity", flags.RegenIdentity,
|
||||
"verbose", flags.Verbose,
|
||||
"debug", flags.Debug,
|
||||
"endpointPort", flags.EndpointPort,
|
||||
"meshPort", flags.MeshPort,
|
||||
"webrtcUDPStart", flags.WebRTCUDPStart,
|
||||
"webrtcUDPEnd", flags.WebRTCUDPEnd,
|
||||
"stunServer", flags.STUNServer,
|
||||
"webrtcUDPMux", flags.UDPMuxPort,
|
||||
"autoAddLocalIP", flags.AutoAddLocalIP,
|
||||
"webrtcNAT11IPs", strings.Join(flags.NAT11IPs, ","),
|
||||
"tlsCert", flags.TLSCert,
|
||||
"tlsKey", flags.TLSKey,
|
||||
"controlSecret", flags.ControlSecret,
|
||||
"webrtcNAT11IPs", flags.NAT11IP,
|
||||
"persistDir", flags.PersistDir,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -76,29 +72,25 @@ func InitFlags() {
|
||||
// Create Flags struct
|
||||
globalFlags = &Flags{}
|
||||
// Get flags
|
||||
flag.BoolVar(&globalFlags.RegenIdentity, "regenIdentity", getEnvAsBool("REGEN_IDENTITY", false), "Regenerate identity on startup")
|
||||
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
|
||||
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
|
||||
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
|
||||
flag.IntVar(&globalFlags.MeshPort, "meshPort", getEnvAsInt("MESH_PORT", 8089), "Mesh connections TCP port")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 0), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 0), "WebRTC UDP port range end")
|
||||
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
|
||||
flag.IntVar(&globalFlags.UDPMuxPort, "webrtcUDPMux", getEnvAsInt("WEBRTC_UDP_MUX", 8088), "WebRTC UDP mux port")
|
||||
flag.BoolVar(&globalFlags.AutoAddLocalIP, "autoAddLocalIP", getEnvAsBool("AUTO_ADD_LOCAL_IP", true), "Automatically add local IP to NAT 1 to 1 IPs")
|
||||
// String with comma separated IPs
|
||||
nat11IPs := ""
|
||||
flag.StringVar(&nat11IPs, "webrtcNAT11IPs", getEnvAsString("WEBRTC_NAT_IPS", ""), "WebRTC NAT 1 to 1 IP(s), comma delimited")
|
||||
flag.StringVar(&globalFlags.TLSCert, "tlsCert", getEnvAsString("TLS_CERT", ""), "Path to TLS certificate")
|
||||
flag.StringVar(&globalFlags.TLSKey, "tlsKey", getEnvAsString("TLS_KEY", ""), "Path to TLS key")
|
||||
flag.StringVar(&globalFlags.ControlSecret, "controlSecret", getEnvAsString("CONTROL_SECRET", ""), "Shared secret for control endpoint")
|
||||
nat11IP := ""
|
||||
flag.StringVar(&nat11IP, "webrtcNAT11IP", getEnvAsString("WEBRTC_NAT_IP", ""), "WebRTC NAT 1 to 1 IP")
|
||||
flag.StringVar(&globalFlags.PersistDir, "persistDir", getEnvAsString("PERSIST_DIR", "./persist-data"), "Directory to save persistent data to")
|
||||
// Parse flags
|
||||
flag.Parse()
|
||||
|
||||
// If debug is enabled, verbose is also enabled
|
||||
if globalFlags.Debug {
|
||||
globalFlags.Verbose = true
|
||||
// If Debug is enabled, set ControlSecret to 1234
|
||||
globalFlags.ControlSecret = "1234"
|
||||
}
|
||||
|
||||
// ICE STUN servers
|
||||
@@ -108,24 +100,11 @@ func InitFlags() {
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize NAT 1 to 1 IPs
|
||||
globalFlags.NAT11IPs = []string{}
|
||||
|
||||
// Get local IP
|
||||
if globalFlags.AutoAddLocalIP {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, getLocalIP())
|
||||
}
|
||||
|
||||
// Parse NAT 1 to 1 IPs from string
|
||||
if len(nat11IPs) > 0 {
|
||||
split := strings.Split(nat11IPs, ",")
|
||||
if len(split) > 0 {
|
||||
for _, ip := range split {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, ip)
|
||||
}
|
||||
} else {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, nat11IPs)
|
||||
}
|
||||
if len(nat11IP) > 0 {
|
||||
globalFlags.NAT11IP = nat11IP
|
||||
} else if globalFlags.AutoAddLocalIP {
|
||||
globalFlags.NAT11IP = getLocalIP()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
gen "relay/internal/proto"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type TimestampEntry struct {
|
||||
|
||||
175
packages/relay/internal/common/safebufio.go
Normal file
175
packages/relay/internal/common/safebufio.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// MaxSize is the maximum allowed data size (1MB)
|
||||
const MaxSize = 1024 * 1024
|
||||
|
||||
// SafeBufioRW wraps a bufio.ReadWriter for sending and receiving JSON and protobufs safely
|
||||
type SafeBufioRW struct {
|
||||
brw *bufio.ReadWriter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSafeBufioRW(brw *bufio.ReadWriter) *SafeBufioRW {
|
||||
return &SafeBufioRW{brw: brw}
|
||||
}
|
||||
|
||||
// SendJSON serializes the given data as JSON and sends it with a 4-byte length prefix
|
||||
func (bu *SafeBufioRW) SendJSON(data interface{}) error {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(jsonData) > MaxSize {
|
||||
return errors.New("JSON data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Write the 4-byte length prefix
|
||||
if err = binary.Write(bu.brw, binary.BigEndian, uint32(len(jsonData))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the JSON data
|
||||
if _, err = bu.brw.Write(jsonData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
return bu.brw.Flush()
|
||||
}
|
||||
|
||||
// ReceiveJSON reads a 4-byte length prefix, then reads and unmarshals the JSON
|
||||
func (bu *SafeBufioRW) ReceiveJSON(dest interface{}) error {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return errors.New("received JSON data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the JSON data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(data, dest)
|
||||
}
|
||||
|
||||
// Receive reads a 4-byte length prefix, then reads the raw data
|
||||
func (bu *SafeBufioRW) Receive() ([]byte, error) {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return nil, errors.New("received data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the raw data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// SendProto serializes the given protobuf message and sends it with a 4-byte length prefix
|
||||
func (bu *SafeBufioRW) SendProto(msg proto.Message) error {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
protoData, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(protoData) > MaxSize {
|
||||
return errors.New("protobuf data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Write the 4-byte length prefix
|
||||
if err = binary.Write(bu.brw, binary.BigEndian, uint32(len(protoData))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the Protobuf data
|
||||
if _, err := bu.brw.Write(protoData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
return bu.brw.Flush()
|
||||
}
|
||||
|
||||
// ReceiveProto reads a 4-byte length prefix, then reads and unmarshals the protobuf
|
||||
func (bu *SafeBufioRW) ReceiveProto(msg proto.Message) error {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return errors.New("received Protobuf data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the Protobuf data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return proto.Unmarshal(data, msg)
|
||||
}
|
||||
|
||||
// Write writes raw data to the underlying buffer
|
||||
func (bu *SafeBufioRW) Write(data []byte) (int, error) {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
if len(data) > MaxSize {
|
||||
return 0, errors.New("data exceeds maximum size")
|
||||
}
|
||||
|
||||
n, err := bu.brw.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
if err = bu.brw.Flush(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,18 +1,11 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
ErrValueNotPointer = errors.New("value is not a pointer")
|
||||
ErrFieldNotFound = errors.New("field not found")
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
)
|
||||
|
||||
// SafeMap is a generic thread-safe map with its own mutex
|
||||
type SafeMap[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
@@ -34,6 +27,14 @@ func (sm *SafeMap[K, V]) Get(key K) (V, bool) {
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Has checks if a key exists in the map
|
||||
func (sm *SafeMap[K, V]) Has(key K) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
_, ok := sm.m[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Set adds or updates a value in the map
|
||||
func (sm *SafeMap[K, V]) Set(key K, value V) {
|
||||
sm.mu.Lock()
|
||||
@@ -66,36 +67,31 @@ func (sm *SafeMap[K, V]) Copy() map[K]V {
|
||||
return copied
|
||||
}
|
||||
|
||||
// Update updates a specific field in the value data
|
||||
func (sm *SafeMap[K, V]) Update(key K, fieldName string, newValue any) error {
|
||||
// Range iterates over the map and applies a function to each key-value pair
|
||||
func (sm *SafeMap[K, V]) Range(f func(K, V) bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
for k, v := range sm.m {
|
||||
if !f(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return json.Marshal(sm.m)
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) UnmarshalJSON(data []byte) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
v, ok := sm.m[key]
|
||||
if !ok {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
// Use reflect to update the field
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return ErrValueNotPointer
|
||||
}
|
||||
|
||||
rv = rv.Elem()
|
||||
// Check if the field exists
|
||||
field := rv.FieldByName(fieldName)
|
||||
if !field.IsValid() || !field.CanSet() {
|
||||
return ErrFieldNotFound
|
||||
}
|
||||
|
||||
newRV := reflect.ValueOf(newValue)
|
||||
if newRV.Type() != field.Type() {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
|
||||
field.Set(newRV)
|
||||
sm.m[key] = v
|
||||
|
||||
return nil
|
||||
return json.Unmarshal(data, &sm.m)
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) String() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return fmt.Sprintf("%+v", sm.m)
|
||||
}
|
||||
Reference in New Issue
Block a user