mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
⭐ feat(maitred): Update maitred - hookup to the API (#198)
## Description We are attempting to hookup maitred to the API Maitred duties will be: - [ ] Hookup to the API - [ ] Wait for signal (from the API) to start Steam - [ ] Stop signal to stop the gaming session, clean up Steam... and maybe do the backup ## Summary by CodeRabbit - **New Features** - Introduced Docker-based deployment configurations for both the main and relay applications. - Added new API endpoints enabling real-time machine messaging and enhanced IoT operations. - Expanded database schema and actor types to support improved machine tracking. - **Improvements** - Enhanced real-time communication and relay management with streamlined room handling. - Upgraded dependencies, logging, and error handling for greater stability and performance. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com> Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
139
packages/relay/internal/common/common.go
Normal file
139
packages/relay/internal/common/common.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/libp2p/go-reuseport"
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var globalWebRTCAPI *webrtc.API
|
||||
var globalWebRTCConfig = webrtc.Configuration{
|
||||
ICETransportPolicy: webrtc.ICETransportPolicyAll,
|
||||
BundlePolicy: webrtc.BundlePolicyBalanced,
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
|
||||
}
|
||||
|
||||
func InitWebRTCAPI() error {
|
||||
var err error
|
||||
flags := GetFlags()
|
||||
|
||||
// Media engine
|
||||
mediaEngine := &webrtc.MediaEngine{}
|
||||
|
||||
// Register additional header extensions to reduce latency
|
||||
// Playout Delay
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Default codecs cover most of our needs
|
||||
err = mediaEngine.RegisterDefaultCodecs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add H.265 for special cases
|
||||
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
|
||||
for _, codec := range []webrtc.RTPCodecParameters{
|
||||
{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH265, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback},
|
||||
PayloadType: 48,
|
||||
},
|
||||
{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeRTX, ClockRate: 90000, SDPFmtpLine: "apt=48"},
|
||||
PayloadType: 49,
|
||||
},
|
||||
} {
|
||||
if err = mediaEngine.RegisterCodec(codec, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Interceptor registry
|
||||
interceptorRegistry := &interceptor.Registry{}
|
||||
|
||||
// Use default set
|
||||
err = webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setting engine
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
|
||||
// New in v4, reduces CPU usage and latency when enabled
|
||||
settingEngine.EnableSCTPZeroChecksum(true)
|
||||
|
||||
nat11IPs := GetFlags().NAT11IPs
|
||||
if len(nat11IPs) > 0 {
|
||||
settingEngine.SetNAT1To1IPs(nat11IPs, webrtc.ICECandidateTypeHost)
|
||||
}
|
||||
|
||||
muxPort := GetFlags().UDPMuxPort
|
||||
if muxPort > 0 {
|
||||
// Use reuseport to allow multiple listeners on the same port
|
||||
pktListener, err := reuseport.ListenPacket("udp", ":"+strconv.Itoa(muxPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %w", err)
|
||||
}
|
||||
|
||||
mux := ice.NewMultiUDPMuxDefault(ice.NewUDPMuxDefault(ice.UDPMuxParams{
|
||||
UDPConn: pktListener,
|
||||
}))
|
||||
slog.Info("Using UDP Mux for WebRTC", "port", muxPort)
|
||||
settingEngine.SetICEUDPMux(mux)
|
||||
}
|
||||
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingEngine.SetIncludeLoopbackCandidate(true) // Just in case
|
||||
|
||||
// Create a new API object with our customized settings
|
||||
globalWebRTCAPI = webrtc.NewAPI(webrtc.WithMediaEngine(mediaEngine), webrtc.WithSettingEngine(settingEngine), webrtc.WithInterceptorRegistry(interceptorRegistry))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWebRTCAPI returns the global WebRTC API
|
||||
func GetWebRTCAPI() *webrtc.API {
|
||||
return globalWebRTCAPI
|
||||
}
|
||||
|
||||
// CreatePeerConnection sets up a new peer connection
|
||||
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
|
||||
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log connection state changes and handle failed/disconnected connections
|
||||
pc.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
|
||||
// Close PeerConnection in cases
|
||||
if connectionState == webrtc.PeerConnectionStateFailed ||
|
||||
connectionState == webrtc.PeerConnectionStateDisconnected ||
|
||||
connectionState == webrtc.PeerConnectionStateClosed {
|
||||
err = pc.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close PeerConnection", "err", err)
|
||||
}
|
||||
onClose()
|
||||
}
|
||||
})
|
||||
|
||||
return pc, nil
|
||||
}
|
||||
19
packages/relay/internal/common/crypto.go
Normal file
19
packages/relay/internal/common/crypto.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewULID() (ulid.ULID, error) {
|
||||
return ulid.New(ulid.Timestamp(time.Now()), ulid.Monotonic(rand.Reader, 0))
|
||||
}
|
||||
|
||||
// Helper function to generate PSK from token
|
||||
func GeneratePSKFromToken(token string) ([]byte, error) {
|
||||
// Simple hash-based PSK generation (32 bytes for libp2p)
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hash[:], nil
|
||||
}
|
||||
11
packages/relay/internal/common/extensions.go
Normal file
11
packages/relay/internal/common/extensions.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package common
|
||||
|
||||
const (
|
||||
ExtensionPlayoutDelay string = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay"
|
||||
)
|
||||
|
||||
// ExtensionMap maps URIs to their IDs based on registration order
|
||||
// IMPORTANT: This must match the order in which extensions are registered in common.go!
|
||||
var ExtensionMap = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
}
|
||||
150
packages/relay/internal/common/flags.go
Normal file
150
packages/relay/internal/common/flags.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var globalFlags *Flags
|
||||
|
||||
type Flags struct {
|
||||
Verbose bool // Log everything to console
|
||||
Debug bool // Enable debug mode, implies Verbose
|
||||
EndpointPort int // Port for HTTP/S and WS/S endpoint (TCP)
|
||||
MeshPort int // Port for Mesh connections (TCP)
|
||||
WebRTCUDPStart int // WebRTC UDP port range start - ignored if UDPMuxPort is set
|
||||
WebRTCUDPEnd int // WebRTC UDP port range end - ignored if UDPMuxPort is set
|
||||
STUNServer string // WebRTC STUN server
|
||||
UDPMuxPort int // WebRTC UDP mux port - if set, overrides UDP port range
|
||||
AutoAddLocalIP bool // Automatically add local IP to NAT 1 to 1 IPs
|
||||
NAT11IPs []string // WebRTC NAT 1 to 1 IP(s) - allows specifying host IP(s) if behind NAT
|
||||
TLSCert string // Path to TLS certificate
|
||||
TLSKey string // Path to TLS key
|
||||
ControlSecret string // Shared secret for this relay's control endpoint
|
||||
}
|
||||
|
||||
func (flags *Flags) DebugLog() {
|
||||
slog.Info("Relay flags",
|
||||
"verbose", flags.Verbose,
|
||||
"debug", flags.Debug,
|
||||
"endpointPort", flags.EndpointPort,
|
||||
"meshPort", flags.MeshPort,
|
||||
"webrtcUDPStart", flags.WebRTCUDPStart,
|
||||
"webrtcUDPEnd", flags.WebRTCUDPEnd,
|
||||
"stunServer", flags.STUNServer,
|
||||
"webrtcUDPMux", flags.UDPMuxPort,
|
||||
"autoAddLocalIP", flags.AutoAddLocalIP,
|
||||
"webrtcNAT11IPs", strings.Join(flags.NAT11IPs, ","),
|
||||
"tlsCert", flags.TLSCert,
|
||||
"tlsKey", flags.TLSKey,
|
||||
"controlSecret", flags.ControlSecret,
|
||||
)
|
||||
}
|
||||
|
||||
func getEnvAsInt(name string, defaultVal int) int {
|
||||
valueStr := os.Getenv(name)
|
||||
if value, err := strconv.Atoi(valueStr); err != nil {
|
||||
return defaultVal
|
||||
} else {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvAsBool(name string, defaultVal bool) bool {
|
||||
valueStr := os.Getenv(name)
|
||||
val, err := strconv.ParseBool(valueStr)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func getEnvAsString(name string, defaultVal string) string {
|
||||
valueStr := os.Getenv(name)
|
||||
if len(valueStr) == 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return valueStr
|
||||
}
|
||||
|
||||
func InitFlags() {
|
||||
// Create Flags struct
|
||||
globalFlags = &Flags{}
|
||||
// Get flags
|
||||
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
|
||||
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
|
||||
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
|
||||
flag.IntVar(&globalFlags.MeshPort, "meshPort", getEnvAsInt("MESH_PORT", 8089), "Mesh connections TCP port")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
|
||||
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
|
||||
flag.IntVar(&globalFlags.UDPMuxPort, "webrtcUDPMux", getEnvAsInt("WEBRTC_UDP_MUX", 8088), "WebRTC UDP mux port")
|
||||
flag.BoolVar(&globalFlags.AutoAddLocalIP, "autoAddLocalIP", getEnvAsBool("AUTO_ADD_LOCAL_IP", true), "Automatically add local IP to NAT 1 to 1 IPs")
|
||||
// String with comma separated IPs
|
||||
nat11IPs := ""
|
||||
flag.StringVar(&nat11IPs, "webrtcNAT11IPs", getEnvAsString("WEBRTC_NAT_IPS", ""), "WebRTC NAT 1 to 1 IP(s), comma delimited")
|
||||
flag.StringVar(&globalFlags.TLSCert, "tlsCert", getEnvAsString("TLS_CERT", ""), "Path to TLS certificate")
|
||||
flag.StringVar(&globalFlags.TLSKey, "tlsKey", getEnvAsString("TLS_KEY", ""), "Path to TLS key")
|
||||
flag.StringVar(&globalFlags.ControlSecret, "controlSecret", getEnvAsString("CONTROL_SECRET", ""), "Shared secret for control endpoint")
|
||||
// Parse flags
|
||||
flag.Parse()
|
||||
|
||||
// If debug is enabled, verbose is also enabled
|
||||
if globalFlags.Debug {
|
||||
globalFlags.Verbose = true
|
||||
// If Debug is enabled, set ControlSecret to 1234
|
||||
globalFlags.ControlSecret = "1234"
|
||||
}
|
||||
|
||||
// ICE STUN servers
|
||||
globalWebRTCConfig.ICEServers = []webrtc.ICEServer{
|
||||
{
|
||||
URLs: []string{"stun:" + globalFlags.STUNServer},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize NAT 1 to 1 IPs
|
||||
globalFlags.NAT11IPs = []string{}
|
||||
|
||||
// Get local IP
|
||||
if globalFlags.AutoAddLocalIP {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, getLocalIP())
|
||||
}
|
||||
|
||||
// Parse NAT 1 to 1 IPs from string
|
||||
if len(nat11IPs) > 0 {
|
||||
split := strings.Split(nat11IPs, ",")
|
||||
if len(split) > 0 {
|
||||
for _, ip := range split {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, ip)
|
||||
}
|
||||
} else {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, nat11IPs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetFlags() *Flags {
|
||||
return globalFlags
|
||||
}
|
||||
|
||||
// getLocalIP returns local IP, be it either IPv4 or IPv6, skips loopback addresses
|
||||
func getLocalIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, address := range addrs {
|
||||
if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
|
||||
if ipnet.IP.To4() != nil || ipnet.IP != nil {
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
132
packages/relay/internal/common/latency.go
Normal file
132
packages/relay/internal/common/latency.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
gen "relay/internal/proto"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TimestampEntry struct {
|
||||
Stage string `json:"stage"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// LatencyTracker provides a generic structure for measuring time taken at various stages in message processing.
|
||||
// It can be embedded in message structs for tracking the flow of data and calculating round-trip latency.
|
||||
type LatencyTracker struct {
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Timestamps []TimestampEntry `json:"timestamps"`
|
||||
}
|
||||
|
||||
// NewLatencyTracker initializes a new LatencyTracker with the given sequence ID
|
||||
func NewLatencyTracker(sequenceID string) *LatencyTracker {
|
||||
return &LatencyTracker{
|
||||
SequenceID: sequenceID,
|
||||
Timestamps: make([]TimestampEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTimestamp adds a new timestamp for a specific stage
|
||||
func (lt *LatencyTracker) AddTimestamp(stage string) {
|
||||
lt.Timestamps = append(lt.Timestamps, TimestampEntry{
|
||||
Stage: stage,
|
||||
// Ensure extremely precise UTC RFC3339 timestamps (down to nanoseconds)
|
||||
Time: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// TotalLatency calculates the total latency from the earliest to the latest timestamp
|
||||
func (lt *LatencyTracker) TotalLatency() (int64, error) {
|
||||
if len(lt.Timestamps) < 2 {
|
||||
return 0, nil // Not enough timestamps to calculate latency
|
||||
}
|
||||
|
||||
var earliest, latest time.Time
|
||||
for _, ts := range lt.Timestamps {
|
||||
if earliest.IsZero() || ts.Time.Before(earliest) {
|
||||
earliest = ts.Time
|
||||
}
|
||||
if latest.IsZero() || ts.Time.After(latest) {
|
||||
latest = ts.Time
|
||||
}
|
||||
}
|
||||
|
||||
return latest.Sub(earliest).Milliseconds(), nil
|
||||
}
|
||||
|
||||
// PainPoints returns a list of stages where the duration exceeds the given threshold.
|
||||
func (lt *LatencyTracker) PainPoints(threshold time.Duration) []string {
|
||||
var painPoints []string
|
||||
var lastStage string
|
||||
var lastTime time.Time
|
||||
|
||||
for _, ts := range lt.Timestamps {
|
||||
stage := ts.Stage
|
||||
if lastStage == "" {
|
||||
lastStage = stage
|
||||
lastTime = ts.Time
|
||||
continue
|
||||
}
|
||||
|
||||
currentTime := ts.Time
|
||||
if currentTime.Sub(lastTime) > threshold {
|
||||
painPoints = append(painPoints, fmt.Sprintf("%s -> %s", lastStage, stage))
|
||||
}
|
||||
|
||||
lastStage = stage
|
||||
lastTime = currentTime
|
||||
}
|
||||
return painPoints
|
||||
}
|
||||
|
||||
// StageLatency calculates the time taken between two specific stages.
|
||||
func (lt *LatencyTracker) StageLatency(startStage, endStage string) (time.Duration, error) {
|
||||
var startTime, endTime time.Time
|
||||
for _, ts := range lt.Timestamps {
|
||||
if ts.Stage == startStage {
|
||||
startTime = ts.Time
|
||||
}
|
||||
if ts.Stage == endStage {
|
||||
endTime = ts.Time
|
||||
}
|
||||
}
|
||||
|
||||
/*if startTime == "" || endTime == "" {
|
||||
return 0, fmt.Errorf("missing timestamps for stages: %s -> %s", startStage, endStage)
|
||||
}*/
|
||||
|
||||
return endTime.Sub(startTime), nil
|
||||
}
|
||||
|
||||
func LatencyTrackerFromProto(protolt *gen.ProtoLatencyTracker) *LatencyTracker {
|
||||
ret := &LatencyTracker{
|
||||
SequenceID: protolt.GetSequenceId(),
|
||||
Timestamps: make([]TimestampEntry, 0),
|
||||
}
|
||||
|
||||
for _, ts := range protolt.GetTimestamps() {
|
||||
ret.Timestamps = append(ret.Timestamps, TimestampEntry{
|
||||
Stage: ts.GetStage(),
|
||||
Time: ts.GetTime().AsTime(),
|
||||
})
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (lt *LatencyTracker) ToProto() *gen.ProtoLatencyTracker {
|
||||
ret := &gen.ProtoLatencyTracker{
|
||||
SequenceId: lt.SequenceID,
|
||||
Timestamps: make([]*gen.ProtoTimestampEntry, len(lt.Timestamps)),
|
||||
}
|
||||
|
||||
for i, timestamp := range lt.Timestamps {
|
||||
ret.Timestamps[i] = &gen.ProtoTimestampEntry{
|
||||
Stage: timestamp.Stage,
|
||||
Time: timestamppb.New(timestamp.Time),
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
48
packages/relay/internal/common/loghandler.go
Normal file
48
packages/relay/internal/common/loghandler.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CustomHandler struct {
|
||||
Handler slog.Handler
|
||||
}
|
||||
|
||||
func (h *CustomHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
return h.Handler.Enabled(nil, level)
|
||||
}
|
||||
|
||||
func (h *CustomHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
// Format the timestamp as "2006/01/02 15:04:05"
|
||||
timestamp := r.Time.Format("2006/01/02 15:04:05")
|
||||
// Convert level to uppercase string (e.g., "INFO")
|
||||
level := strings.ToUpper(r.Level.String())
|
||||
// Build the message
|
||||
msg := fmt.Sprintf("%s %s %s", timestamp, level, r.Message)
|
||||
|
||||
// Handle additional attributes if they exist
|
||||
var attrs []string
|
||||
r.Attrs(func(a slog.Attr) bool {
|
||||
attrs = append(attrs, fmt.Sprintf("%s=%v", a.Key, a.Value))
|
||||
return true
|
||||
})
|
||||
if len(attrs) > 0 {
|
||||
msg += " " + strings.Join(attrs, " ")
|
||||
}
|
||||
|
||||
// Write the formatted message to stdout
|
||||
_, err := fmt.Fprintln(os.Stdout, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *CustomHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return &CustomHandler{Handler: h.Handler.WithAttrs(attrs)}
|
||||
}
|
||||
|
||||
func (h *CustomHandler) WithGroup(name string) slog.Handler {
|
||||
return &CustomHandler{Handler: h.Handler.WithGroup(name)}
|
||||
}
|
||||
101
packages/relay/internal/common/map.go
Normal file
101
packages/relay/internal/common/map.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
ErrValueNotPointer = errors.New("value is not a pointer")
|
||||
ErrFieldNotFound = errors.New("field not found")
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
)
|
||||
|
||||
// SafeMap is a generic thread-safe map with its own mutex
|
||||
type SafeMap[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
m map[K]V
|
||||
}
|
||||
|
||||
// NewSafeMap creates a new SafeMap instance
|
||||
func NewSafeMap[K comparable, V any]() *SafeMap[K, V] {
|
||||
return &SafeMap[K, V]{
|
||||
m: make(map[K]V),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the map
|
||||
func (sm *SafeMap[K, V]) Get(key K) (V, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
v, ok := sm.m[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Set adds or updates a value in the map
|
||||
func (sm *SafeMap[K, V]) Set(key K, value V) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.m[key] = value
|
||||
}
|
||||
|
||||
// Delete removes a key from the map
|
||||
func (sm *SafeMap[K, V]) Delete(key K) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
delete(sm.m, key)
|
||||
}
|
||||
|
||||
// Len returns the number of items in the map
|
||||
func (sm *SafeMap[K, V]) Len() int {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return len(sm.m)
|
||||
}
|
||||
|
||||
// Copy creates a shallow copy of the map and returns it
|
||||
func (sm *SafeMap[K, V]) Copy() map[K]V {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
copied := make(map[K]V, len(sm.m))
|
||||
for k, v := range sm.m {
|
||||
copied[k] = v
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
// Update updates a specific field in the value data
|
||||
func (sm *SafeMap[K, V]) Update(key K, fieldName string, newValue any) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
v, ok := sm.m[key]
|
||||
if !ok {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
// Use reflect to update the field
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return ErrValueNotPointer
|
||||
}
|
||||
|
||||
rv = rv.Elem()
|
||||
// Check if the field exists
|
||||
field := rv.FieldByName(fieldName)
|
||||
if !field.IsValid() || !field.CanSet() {
|
||||
return ErrFieldNotFound
|
||||
}
|
||||
|
||||
newRV := reflect.ValueOf(newValue)
|
||||
if newRV.Type() != field.Type() {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
|
||||
field.Set(newRV)
|
||||
sm.m[key] = v
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user