mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 16:55:37 +02:00
⭐ feat: Migrate from WebSocket to libp2p for peer-to-peer connectivity (#286)
## Description Whew, some stuff is still not re-implemented, but it's working! Rabbit's gonna explode with the amount of changes I reckon 😅 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a peer-to-peer relay system using libp2p with enhanced stream forwarding, room state synchronization, and mDNS peer discovery. - Added decentralized room and participant management, metrics publishing, and safe, size-limited, concurrent message streaming with robust framing and callback dispatching. - Implemented asynchronous, callback-driven message handling over custom libp2p streams replacing WebSocket signaling. - **Improvements** - Migrated signaling and stream protocols from WebSocket to libp2p, improving reliability and scalability. - Simplified configuration and environment variables, removing deprecated flags and adding persistent data support. - Enhanced logging, error handling, and connection management for better observability and robustness. - Refined RTP header extension registration and NAT IP handling for improved WebRTC performance. - **Bug Fixes** - Improved ICE candidate buffering and SDP negotiation in WebRTC connections. - Fixed NAT IP and UDP port range configuration issues. - **Refactor** - Modularized codebase, reorganized relay and server logic, and removed deprecated WebSocket-based components. - Streamlined message structures, removed obsolete enums and message types, and simplified SafeMap concurrency. - Replaced WebSocket signaling with libp2p stream protocols in server and relay components. - **Chores** - Updated and cleaned dependencies across Go, Rust, and JavaScript packages. - Added `.gitignore` for persistent data directory in relay package. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com> Co-authored-by: Philipp Neumann <3daquawolf@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e67a8d2b32
commit
6e82eff9e2
@@ -2,12 +2,13 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
|
||||
"github.com/libp2p/go-reuseport"
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var globalWebRTCAPI *webrtc.API
|
||||
@@ -24,17 +25,9 @@ func InitWebRTCAPI() error {
|
||||
// Media engine
|
||||
mediaEngine := &webrtc.MediaEngine{}
|
||||
|
||||
// Register additional header extensions to reduce latency
|
||||
// Playout Delay
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return err
|
||||
// Register our extensions
|
||||
if err := RegisterExtensions(mediaEngine); err != nil {
|
||||
return fmt.Errorf("failed to register extensions: %w", err)
|
||||
}
|
||||
|
||||
// Default codecs cover most of our needs
|
||||
@@ -75,9 +68,10 @@ func InitWebRTCAPI() error {
|
||||
// New in v4, reduces CPU usage and latency when enabled
|
||||
settingEngine.EnableSCTPZeroChecksum(true)
|
||||
|
||||
nat11IPs := GetFlags().NAT11IPs
|
||||
if len(nat11IPs) > 0 {
|
||||
settingEngine.SetNAT1To1IPs(nat11IPs, webrtc.ICECandidateTypeHost)
|
||||
nat11IP := GetFlags().NAT11IP
|
||||
if len(nat11IP) > 0 {
|
||||
settingEngine.SetNAT1To1IPs([]string{nat11IP}, webrtc.ICECandidateTypeSrflx)
|
||||
slog.Info("Using NAT 1:1 IP for WebRTC", "nat11_ip", nat11IP)
|
||||
}
|
||||
|
||||
muxPort := GetFlags().UDPMuxPort
|
||||
@@ -85,7 +79,7 @@ func InitWebRTCAPI() error {
|
||||
// Use reuseport to allow multiple listeners on the same port
|
||||
pktListener, err := reuseport.ListenPacket("udp", ":"+strconv.Itoa(muxPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %w", err)
|
||||
return fmt.Errorf("failed to create WebRTC muxed UDP listener: %w", err)
|
||||
}
|
||||
|
||||
mux := ice.NewMultiUDPMuxDefault(ice.NewUDPMuxDefault(ice.UDPMuxParams{
|
||||
@@ -95,10 +89,13 @@ func InitWebRTCAPI() error {
|
||||
settingEngine.SetICEUDPMux(mux)
|
||||
}
|
||||
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
if flags.WebRTCUDPStart > 0 && flags.WebRTCUDPEnd > 0 && flags.WebRTCUDPStart < flags.WebRTCUDPEnd {
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("Using WebRTC UDP Port Range", "start", flags.WebRTCUDPStart, "end", flags.WebRTCUDPEnd)
|
||||
}
|
||||
|
||||
settingEngine.SetIncludeLoopbackCandidate(true) // Just in case
|
||||
@@ -109,11 +106,6 @@ func InitWebRTCAPI() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWebRTCAPI returns the global WebRTC API
|
||||
func GetWebRTCAPI() *webrtc.API {
|
||||
return globalWebRTCAPI
|
||||
}
|
||||
|
||||
// CreatePeerConnection sets up a new peer connection
|
||||
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
|
||||
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
|
||||
|
||||
@@ -1,19 +1,51 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
func NewULID() (ulid.ULID, error) {
|
||||
return ulid.New(ulid.Timestamp(time.Now()), ulid.Monotonic(rand.Reader, 0))
|
||||
}
|
||||
|
||||
// Helper function to generate PSK from token
|
||||
func GeneratePSKFromToken(token string) ([]byte, error) {
|
||||
// Simple hash-based PSK generation (32 bytes for libp2p)
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hash[:], nil
|
||||
// GenerateED25519Key generates a new ED25519 key
|
||||
func GenerateED25519Key() (ed25519.PrivateKey, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ED25519 key pair: %w", err)
|
||||
}
|
||||
return priv, nil
|
||||
}
|
||||
|
||||
// SaveED25519Key saves an ED25519 private key to a path as a binary file
|
||||
func SaveED25519Key(privateKey ed25519.PrivateKey, filePath string) error {
|
||||
if privateKey == nil {
|
||||
return errors.New("private key cannot be nil")
|
||||
}
|
||||
if len(privateKey) != ed25519.PrivateKeySize {
|
||||
return errors.New("private key must be exactly 64 bytes for ED25519")
|
||||
}
|
||||
if err := os.WriteFile(filePath, privateKey, 0600); err != nil {
|
||||
return fmt.Errorf("failed to save ED25519 key to %s: %w", filePath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadED25519Key loads an ED25519 private key binary file from a path
|
||||
func LoadED25519Key(filePath string) (ed25519.PrivateKey, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read ED25519 key from %s: %w", filePath, err)
|
||||
}
|
||||
if len(data) != ed25519.PrivateKeySize {
|
||||
return nil, fmt.Errorf("ED25519 key must be exactly %d bytes, got %d", ed25519.PrivateKeySize, len(data))
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,45 @@
|
||||
package common
|
||||
|
||||
import "github.com/pion/webrtc/v4"
|
||||
|
||||
const (
|
||||
ExtensionPlayoutDelay string = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay"
|
||||
)
|
||||
|
||||
// ExtensionMap maps URIs to their IDs based on registration order
|
||||
// IMPORTANT: This must match the order in which extensions are registered in common.go!
|
||||
var ExtensionMap = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
// ExtensionMap maps audio/video extension URIs to their IDs based on registration order
|
||||
var ExtensionMap = map[webrtc.RTPCodecType]map[string]uint8{}
|
||||
|
||||
func RegisterExtensions(mediaEngine *webrtc.MediaEngine) error {
|
||||
// Register additional header extensions to reduce latency
|
||||
// Playout Delay (Video)
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
// Playout Delay (Audio)
|
||||
if err := mediaEngine.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{
|
||||
URI: ExtensionPlayoutDelay,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register the extension IDs for both audio and video
|
||||
ExtensionMap[webrtc.RTPCodecTypeAudio] = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
}
|
||||
ExtensionMap[webrtc.RTPCodecTypeVideo] = map[string]uint8{
|
||||
ExtensionPlayoutDelay: 1,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetExtension(codecType webrtc.RTPCodecType, extURI string) (uint8, bool) {
|
||||
cType, ok := ExtensionMap[codecType]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
extID, ok := cType[extURI]
|
||||
return extID, ok
|
||||
}
|
||||
|
||||
@@ -2,47 +2,43 @@ package common
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
var globalFlags *Flags
|
||||
|
||||
type Flags struct {
|
||||
Verbose bool // Log everything to console
|
||||
Debug bool // Enable debug mode, implies Verbose
|
||||
EndpointPort int // Port for HTTP/S and WS/S endpoint (TCP)
|
||||
MeshPort int // Port for Mesh connections (TCP)
|
||||
WebRTCUDPStart int // WebRTC UDP port range start - ignored if UDPMuxPort is set
|
||||
WebRTCUDPEnd int // WebRTC UDP port range end - ignored if UDPMuxPort is set
|
||||
STUNServer string // WebRTC STUN server
|
||||
UDPMuxPort int // WebRTC UDP mux port - if set, overrides UDP port range
|
||||
AutoAddLocalIP bool // Automatically add local IP to NAT 1 to 1 IPs
|
||||
NAT11IPs []string // WebRTC NAT 1 to 1 IP(s) - allows specifying host IP(s) if behind NAT
|
||||
TLSCert string // Path to TLS certificate
|
||||
TLSKey string // Path to TLS key
|
||||
ControlSecret string // Shared secret for this relay's control endpoint
|
||||
RegenIdentity bool // Remove old identity on startup and regenerate it
|
||||
Verbose bool // Log everything to console
|
||||
Debug bool // Enable debug mode, implies Verbose
|
||||
EndpointPort int // Port for HTTP/S and WS/S endpoint (TCP)
|
||||
WebRTCUDPStart int // WebRTC UDP port range start - ignored if UDPMuxPort is set
|
||||
WebRTCUDPEnd int // WebRTC UDP port range end - ignored if UDPMuxPort is set
|
||||
STUNServer string // WebRTC STUN server
|
||||
UDPMuxPort int // WebRTC UDP mux port - if set, overrides UDP port range
|
||||
AutoAddLocalIP bool // Automatically add local IP to NAT 1 to 1 IPs
|
||||
NAT11IP string // WebRTC NAT 1 to 1 IP - allows specifying IP of relay if behind NAT
|
||||
PersistDir string // Directory to save persistent data to
|
||||
}
|
||||
|
||||
func (flags *Flags) DebugLog() {
|
||||
slog.Info("Relay flags",
|
||||
slog.Debug("Relay flags",
|
||||
"regenIdentity", flags.RegenIdentity,
|
||||
"verbose", flags.Verbose,
|
||||
"debug", flags.Debug,
|
||||
"endpointPort", flags.EndpointPort,
|
||||
"meshPort", flags.MeshPort,
|
||||
"webrtcUDPStart", flags.WebRTCUDPStart,
|
||||
"webrtcUDPEnd", flags.WebRTCUDPEnd,
|
||||
"stunServer", flags.STUNServer,
|
||||
"webrtcUDPMux", flags.UDPMuxPort,
|
||||
"autoAddLocalIP", flags.AutoAddLocalIP,
|
||||
"webrtcNAT11IPs", strings.Join(flags.NAT11IPs, ","),
|
||||
"tlsCert", flags.TLSCert,
|
||||
"tlsKey", flags.TLSKey,
|
||||
"controlSecret", flags.ControlSecret,
|
||||
"webrtcNAT11IPs", flags.NAT11IP,
|
||||
"persistDir", flags.PersistDir,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -76,29 +72,25 @@ func InitFlags() {
|
||||
// Create Flags struct
|
||||
globalFlags = &Flags{}
|
||||
// Get flags
|
||||
flag.BoolVar(&globalFlags.RegenIdentity, "regenIdentity", getEnvAsBool("REGEN_IDENTITY", false), "Regenerate identity on startup")
|
||||
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
|
||||
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
|
||||
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
|
||||
flag.IntVar(&globalFlags.MeshPort, "meshPort", getEnvAsInt("MESH_PORT", 8089), "Mesh connections TCP port")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 0), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 0), "WebRTC UDP port range end")
|
||||
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
|
||||
flag.IntVar(&globalFlags.UDPMuxPort, "webrtcUDPMux", getEnvAsInt("WEBRTC_UDP_MUX", 8088), "WebRTC UDP mux port")
|
||||
flag.BoolVar(&globalFlags.AutoAddLocalIP, "autoAddLocalIP", getEnvAsBool("AUTO_ADD_LOCAL_IP", true), "Automatically add local IP to NAT 1 to 1 IPs")
|
||||
// String with comma separated IPs
|
||||
nat11IPs := ""
|
||||
flag.StringVar(&nat11IPs, "webrtcNAT11IPs", getEnvAsString("WEBRTC_NAT_IPS", ""), "WebRTC NAT 1 to 1 IP(s), comma delimited")
|
||||
flag.StringVar(&globalFlags.TLSCert, "tlsCert", getEnvAsString("TLS_CERT", ""), "Path to TLS certificate")
|
||||
flag.StringVar(&globalFlags.TLSKey, "tlsKey", getEnvAsString("TLS_KEY", ""), "Path to TLS key")
|
||||
flag.StringVar(&globalFlags.ControlSecret, "controlSecret", getEnvAsString("CONTROL_SECRET", ""), "Shared secret for control endpoint")
|
||||
nat11IP := ""
|
||||
flag.StringVar(&nat11IP, "webrtcNAT11IP", getEnvAsString("WEBRTC_NAT_IP", ""), "WebRTC NAT 1 to 1 IP")
|
||||
flag.StringVar(&globalFlags.PersistDir, "persistDir", getEnvAsString("PERSIST_DIR", "./persist-data"), "Directory to save persistent data to")
|
||||
// Parse flags
|
||||
flag.Parse()
|
||||
|
||||
// If debug is enabled, verbose is also enabled
|
||||
if globalFlags.Debug {
|
||||
globalFlags.Verbose = true
|
||||
// If Debug is enabled, set ControlSecret to 1234
|
||||
globalFlags.ControlSecret = "1234"
|
||||
}
|
||||
|
||||
// ICE STUN servers
|
||||
@@ -108,24 +100,11 @@ func InitFlags() {
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize NAT 1 to 1 IPs
|
||||
globalFlags.NAT11IPs = []string{}
|
||||
|
||||
// Get local IP
|
||||
if globalFlags.AutoAddLocalIP {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, getLocalIP())
|
||||
}
|
||||
|
||||
// Parse NAT 1 to 1 IPs from string
|
||||
if len(nat11IPs) > 0 {
|
||||
split := strings.Split(nat11IPs, ",")
|
||||
if len(split) > 0 {
|
||||
for _, ip := range split {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, ip)
|
||||
}
|
||||
} else {
|
||||
globalFlags.NAT11IPs = append(globalFlags.NAT11IPs, nat11IPs)
|
||||
}
|
||||
if len(nat11IP) > 0 {
|
||||
globalFlags.NAT11IP = nat11IP
|
||||
} else if globalFlags.AutoAddLocalIP {
|
||||
globalFlags.NAT11IP = getLocalIP()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
gen "relay/internal/proto"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type TimestampEntry struct {
|
||||
|
||||
175
packages/relay/internal/common/safebufio.go
Normal file
175
packages/relay/internal/common/safebufio.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// MaxSize is the maximum allowed data size (1MB)
|
||||
const MaxSize = 1024 * 1024
|
||||
|
||||
// SafeBufioRW wraps a bufio.ReadWriter for sending and receiving JSON and protobufs safely
|
||||
type SafeBufioRW struct {
|
||||
brw *bufio.ReadWriter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSafeBufioRW(brw *bufio.ReadWriter) *SafeBufioRW {
|
||||
return &SafeBufioRW{brw: brw}
|
||||
}
|
||||
|
||||
// SendJSON serializes the given data as JSON and sends it with a 4-byte length prefix
|
||||
func (bu *SafeBufioRW) SendJSON(data interface{}) error {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(jsonData) > MaxSize {
|
||||
return errors.New("JSON data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Write the 4-byte length prefix
|
||||
if err = binary.Write(bu.brw, binary.BigEndian, uint32(len(jsonData))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the JSON data
|
||||
if _, err = bu.brw.Write(jsonData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
return bu.brw.Flush()
|
||||
}
|
||||
|
||||
// ReceiveJSON reads a 4-byte length prefix, then reads and unmarshals the JSON
|
||||
func (bu *SafeBufioRW) ReceiveJSON(dest interface{}) error {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return errors.New("received JSON data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the JSON data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(data, dest)
|
||||
}
|
||||
|
||||
// Receive reads a 4-byte length prefix, then reads the raw data
|
||||
func (bu *SafeBufioRW) Receive() ([]byte, error) {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return nil, errors.New("received data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the raw data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// SendProto serializes the given protobuf message and sends it with a 4-byte length prefix
|
||||
func (bu *SafeBufioRW) SendProto(msg proto.Message) error {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
protoData, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(protoData) > MaxSize {
|
||||
return errors.New("protobuf data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Write the 4-byte length prefix
|
||||
if err = binary.Write(bu.brw, binary.BigEndian, uint32(len(protoData))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the Protobuf data
|
||||
if _, err := bu.brw.Write(protoData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
return bu.brw.Flush()
|
||||
}
|
||||
|
||||
// ReceiveProto reads a 4-byte length prefix, then reads and unmarshals the protobuf
|
||||
func (bu *SafeBufioRW) ReceiveProto(msg proto.Message) error {
|
||||
bu.mutex.RLock()
|
||||
defer bu.mutex.RUnlock()
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
var length uint32
|
||||
if err := binary.Read(bu.brw, binary.BigEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if length > MaxSize {
|
||||
return errors.New("received Protobuf data exceeds maximum size")
|
||||
}
|
||||
|
||||
// Read the Protobuf data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(bu.brw, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return proto.Unmarshal(data, msg)
|
||||
}
|
||||
|
||||
// Write writes raw data to the underlying buffer
|
||||
func (bu *SafeBufioRW) Write(data []byte) (int, error) {
|
||||
bu.mutex.Lock()
|
||||
defer bu.mutex.Unlock()
|
||||
|
||||
if len(data) > MaxSize {
|
||||
return 0, errors.New("data exceeds maximum size")
|
||||
}
|
||||
|
||||
n, err := bu.brw.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush the writer to ensure data is sent
|
||||
if err = bu.brw.Flush(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,18 +1,11 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
ErrValueNotPointer = errors.New("value is not a pointer")
|
||||
ErrFieldNotFound = errors.New("field not found")
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
)
|
||||
|
||||
// SafeMap is a generic thread-safe map with its own mutex
|
||||
type SafeMap[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
@@ -34,6 +27,14 @@ func (sm *SafeMap[K, V]) Get(key K) (V, bool) {
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Has checks if a key exists in the map
|
||||
func (sm *SafeMap[K, V]) Has(key K) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
_, ok := sm.m[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Set adds or updates a value in the map
|
||||
func (sm *SafeMap[K, V]) Set(key K, value V) {
|
||||
sm.mu.Lock()
|
||||
@@ -66,36 +67,31 @@ func (sm *SafeMap[K, V]) Copy() map[K]V {
|
||||
return copied
|
||||
}
|
||||
|
||||
// Update updates a specific field in the value data
|
||||
func (sm *SafeMap[K, V]) Update(key K, fieldName string, newValue any) error {
|
||||
// Range iterates over the map and applies a function to each key-value pair
|
||||
func (sm *SafeMap[K, V]) Range(f func(K, V) bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
for k, v := range sm.m {
|
||||
if !f(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return json.Marshal(sm.m)
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) UnmarshalJSON(data []byte) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
v, ok := sm.m[key]
|
||||
if !ok {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
// Use reflect to update the field
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return ErrValueNotPointer
|
||||
}
|
||||
|
||||
rv = rv.Elem()
|
||||
// Check if the field exists
|
||||
field := rv.FieldByName(fieldName)
|
||||
if !field.IsValid() || !field.CanSet() {
|
||||
return ErrFieldNotFound
|
||||
}
|
||||
|
||||
newRV := reflect.ValueOf(newValue)
|
||||
if newRV.Type() != field.Type() {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
|
||||
field.Set(newRV)
|
||||
sm.m[key] = v
|
||||
|
||||
return nil
|
||||
return json.Unmarshal(data, &sm.m)
|
||||
}
|
||||
|
||||
func (sm *SafeMap[K, V]) String() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return fmt.Sprintf("%+v", sm.m)
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"log/slog"
|
||||
gen "relay/internal/proto"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type OnMessageCallback func(data []byte)
|
||||
|
||||
// NestriDataChannel is a custom data channel with callbacks
|
||||
type NestriDataChannel struct {
|
||||
*webrtc.DataChannel
|
||||
@@ -37,7 +40,7 @@ func NewNestriDataChannel(dc *webrtc.DataChannel) *NestriDataChannel {
|
||||
// Handle message type callback
|
||||
if callback, ok := ndc.callbacks["input"]; ok {
|
||||
go callback(msg.Data)
|
||||
} // TODO: Log unknown message type?
|
||||
} // We don't care about unhandled messages
|
||||
})
|
||||
|
||||
return ndc
|
||||
|
||||
@@ -1,18 +1,32 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"encoding/json"
|
||||
"relay/internal/common"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// MessageBase is the base type for WS/DC messages.
|
||||
// MessageBase is the base type for any JSON message
|
||||
type MessageBase struct {
|
||||
PayloadType string `json:"payload_type"`
|
||||
Latency *common.LatencyTracker `json:"latency,omitempty"`
|
||||
Type string `json:"payload_type"`
|
||||
Latency *common.LatencyTracker `json:"latency,omitempty"`
|
||||
}
|
||||
|
||||
type MessageRaw struct {
|
||||
MessageBase
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func NewMessageRaw(t string, data json.RawMessage) *MessageRaw {
|
||||
return &MessageRaw{
|
||||
MessageBase: MessageBase{
|
||||
Type: t,
|
||||
},
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// MessageLog represents a log message.
|
||||
type MessageLog struct {
|
||||
MessageBase
|
||||
Level string `json:"level"`
|
||||
@@ -20,7 +34,17 @@ type MessageLog struct {
|
||||
Time string `json:"time"`
|
||||
}
|
||||
|
||||
// MessageMetrics represents a metrics/heartbeat message.
|
||||
func NewMessageLog(t string, level, message, time string) *MessageLog {
|
||||
return &MessageLog{
|
||||
MessageBase: MessageBase{
|
||||
Type: t,
|
||||
},
|
||||
Level: level,
|
||||
Message: message,
|
||||
Time: time,
|
||||
}
|
||||
}
|
||||
|
||||
type MessageMetrics struct {
|
||||
MessageBase
|
||||
UsageCPU float64 `json:"usage_cpu"`
|
||||
@@ -29,104 +53,42 @@ type MessageMetrics struct {
|
||||
PipelineLatency float64 `json:"pipeline_latency"`
|
||||
}
|
||||
|
||||
// MessageICECandidate represents an ICE candidate message.
|
||||
type MessageICECandidate struct {
|
||||
MessageBase
|
||||
Candidate webrtc.ICECandidateInit `json:"candidate"`
|
||||
}
|
||||
|
||||
// MessageSDP represents an SDP message.
|
||||
type MessageSDP struct {
|
||||
MessageBase
|
||||
SDP webrtc.SessionDescription `json:"sdp"`
|
||||
}
|
||||
|
||||
// JoinerType is an enum for the type of incoming room joiner
|
||||
type JoinerType int
|
||||
|
||||
const (
|
||||
JoinerNode JoinerType = iota
|
||||
JoinerClient
|
||||
)
|
||||
|
||||
func (jt *JoinerType) String() string {
|
||||
switch *jt {
|
||||
case JoinerNode:
|
||||
return "node"
|
||||
case JoinerClient:
|
||||
return "client"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// MessageJoin is used to tell us that either participant or ingest wants to join the room
|
||||
type MessageJoin struct {
|
||||
MessageBase
|
||||
JoinerType JoinerType `json:"joiner_type"`
|
||||
}
|
||||
|
||||
// AnswerType is an enum for the type of answer, signaling Room state for a joiner
|
||||
type AnswerType int
|
||||
|
||||
const (
|
||||
AnswerOffline AnswerType = iota // For participant/client, when the room is offline without stream
|
||||
AnswerInUse // For ingest/node joiner, when the room is already in use by another ingest/node
|
||||
AnswerOK // For both, when the join request is handled successfully
|
||||
)
|
||||
|
||||
// MessageAnswer is used to send the answer to a join request
|
||||
type MessageAnswer struct {
|
||||
MessageBase
|
||||
AnswerType AnswerType `json:"answer_type"`
|
||||
}
|
||||
|
||||
// SendLogMessageWS sends a log message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendLogMessageWS(level, message string) error {
|
||||
msg := MessageLog{
|
||||
MessageBase: MessageBase{PayloadType: "log"},
|
||||
Level: level,
|
||||
Message: message,
|
||||
Time: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
return ws.SendJSON(msg)
|
||||
}
|
||||
|
||||
// SendMetricsMessageWS sends a metrics message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendMetricsMessageWS(usageCPU, usageMemory float64, uptime uint64, pipelineLatency float64) error {
|
||||
msg := MessageMetrics{
|
||||
MessageBase: MessageBase{PayloadType: "metrics"},
|
||||
func NewMessageMetrics(t string, usageCPU, usageMemory float64, uptime uint64, pipelineLatency float64) *MessageMetrics {
|
||||
return &MessageMetrics{
|
||||
MessageBase: MessageBase{
|
||||
Type: t,
|
||||
},
|
||||
UsageCPU: usageCPU,
|
||||
UsageMemory: usageMemory,
|
||||
Uptime: uptime,
|
||||
PipelineLatency: pipelineLatency,
|
||||
}
|
||||
return ws.SendJSON(msg)
|
||||
}
|
||||
|
||||
// SendICECandidateMessageWS sends an ICE candidate message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendICECandidateMessageWS(candidate webrtc.ICECandidateInit) error {
|
||||
msg := MessageICECandidate{
|
||||
MessageBase: MessageBase{PayloadType: "ice"},
|
||||
Candidate: candidate,
|
||||
}
|
||||
return ws.SendJSON(msg)
|
||||
type MessageICE struct {
|
||||
MessageBase
|
||||
Candidate webrtc.ICECandidateInit `json:"candidate"`
|
||||
}
|
||||
|
||||
// SendSDPMessageWS sends an SDP message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendSDPMessageWS(sdp webrtc.SessionDescription) error {
|
||||
msg := MessageSDP{
|
||||
MessageBase: MessageBase{PayloadType: "sdp"},
|
||||
SDP: sdp,
|
||||
func NewMessageICE(t string, candidate webrtc.ICECandidateInit) *MessageICE {
|
||||
return &MessageICE{
|
||||
MessageBase: MessageBase{
|
||||
Type: t,
|
||||
},
|
||||
Candidate: candidate,
|
||||
}
|
||||
return ws.SendJSON(msg)
|
||||
}
|
||||
|
||||
// SendAnswerMessageWS sends an answer message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendAnswerMessageWS(answer AnswerType) error {
|
||||
msg := MessageAnswer{
|
||||
MessageBase: MessageBase{PayloadType: "answer"},
|
||||
AnswerType: answer,
|
||||
}
|
||||
return ws.SendJSON(msg)
|
||||
type MessageSDP struct {
|
||||
MessageBase
|
||||
SDP webrtc.SessionDescription `json:"sdp"`
|
||||
}
|
||||
|
||||
func NewMessageSDP(t string, sdp webrtc.SessionDescription) *MessageSDP {
|
||||
return &MessageSDP{
|
||||
MessageBase: MessageBase{
|
||||
Type: t,
|
||||
},
|
||||
SDP: sdp,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
gen "relay/internal/proto"
|
||||
)
|
||||
|
||||
// SendMeshHandshake sends a handshake message to another relay.
|
||||
func (ws *SafeWebSocket) SendMeshHandshake(relayID, publicKey string) error {
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_Handshake{
|
||||
Handshake: &gen.Handshake{
|
||||
RelayId: relayID,
|
||||
DhPublicKey: publicKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
|
||||
// SendMeshHandshakeResponse sends a handshake response to a relay.
|
||||
func (ws *SafeWebSocket) SendMeshHandshakeResponse(relayID, dhPublicKey string, approvals map[string]string) error {
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_HandshakeResponse{
|
||||
HandshakeResponse: &gen.HandshakeResponse{
|
||||
RelayId: relayID,
|
||||
DhPublicKey: dhPublicKey,
|
||||
Approvals: approvals,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
|
||||
// SendMeshForwardSDP sends a forwarded SDP message to another relay
|
||||
func (ws *SafeWebSocket) SendMeshForwardSDP(roomName, participantID string, sdp webrtc.SessionDescription) error {
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_ForwardSdp{
|
||||
ForwardSdp: &gen.ForwardSDP{
|
||||
RoomName: roomName,
|
||||
ParticipantId: participantID,
|
||||
Sdp: sdp.SDP,
|
||||
Type: sdp.Type.String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
|
||||
// SendMeshForwardICE sends a forwarded ICE candidate to another relay
|
||||
func (ws *SafeWebSocket) SendMeshForwardICE(roomName, participantID string, candidate webrtc.ICECandidateInit) error {
|
||||
var sdpMLineIndex uint32
|
||||
if candidate.SDPMLineIndex != nil {
|
||||
sdpMLineIndex = uint32(*candidate.SDPMLineIndex)
|
||||
}
|
||||
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_ForwardIce{
|
||||
ForwardIce: &gen.ForwardICE{
|
||||
RoomName: roomName,
|
||||
ParticipantId: participantID,
|
||||
Candidate: &gen.ICECandidateInit{
|
||||
Candidate: candidate.Candidate,
|
||||
SdpMid: candidate.SDPMid,
|
||||
SdpMLineIndex: &sdpMLineIndex,
|
||||
UsernameFragment: candidate.UsernameFragment,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
|
||||
func (ws *SafeWebSocket) SendMeshForwardIngest(roomName string) error {
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_ForwardIngest{
|
||||
ForwardIngest: &gen.ForwardIngest{
|
||||
RoomName: roomName,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
|
||||
func (ws *SafeWebSocket) SendMeshStreamRequest(roomName string) error {
|
||||
msg := &gen.MeshMessage{
|
||||
Type: &gen.MeshMessage_StreamRequest{
|
||||
StreamRequest: &gen.StreamRequest{
|
||||
RoomName: roomName,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.SendBinary(data)
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gorilla/websocket"
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// OnMessageCallback is a callback for messages of given type
|
||||
type OnMessageCallback func(data []byte)
|
||||
|
||||
// SafeWebSocket is a websocket with a mutex
|
||||
// SafeWebSocket wraps a *websocket.Conn with a mutex that serializes writes,
// plus per-payload-type dispatch of incoming messages to registered callbacks.
type SafeWebSocket struct {
	*websocket.Conn
	sync.Mutex
	closed         bool                         // set by the reader goroutine once the connection terminates
	closeCallback  func()                       // Callback to call on close
	closeChan      chan struct{}                // Channel to signal closure
	callbacks      map[string]OnMessageCallback // MessageBase type -> callback
	binaryCallback OnMessageCallback            // Binary message callback
	sharedSecret   []byte                       // optional shared secret; NOTE(review): not guarded by the mutex
}
|
||||
|
||||
// NewSafeWebSocket creates a new SafeWebSocket from *websocket.Conn
|
||||
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
|
||||
ws := &SafeWebSocket{
|
||||
Conn: conn,
|
||||
closed: false,
|
||||
closeCallback: nil,
|
||||
closeChan: make(chan struct{}),
|
||||
callbacks: make(map[string]OnMessageCallback),
|
||||
binaryCallback: nil,
|
||||
sharedSecret: nil,
|
||||
}
|
||||
|
||||
// Launch a goroutine to handle messages
|
||||
go func() {
|
||||
for {
|
||||
// Read message
|
||||
kind, data, err := ws.Conn.ReadMessage()
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
|
||||
// If unexpected close error, break
|
||||
slog.Debug("WebSocket closed unexpectedly", "err", err)
|
||||
break
|
||||
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
|
||||
break
|
||||
} else if err != nil {
|
||||
slog.Error("Failed reading WebSocket message", "err", err)
|
||||
break
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case websocket.TextMessage:
|
||||
// Decode message
|
||||
var msg MessageBase
|
||||
if err = json.Unmarshal(data, &msg); err != nil {
|
||||
slog.Error("Failed decoding WebSocket message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle message type callback
|
||||
if callback, ok := ws.callbacks[msg.PayloadType]; ok {
|
||||
callback(data)
|
||||
} // TODO: Log unknown message payload type?
|
||||
break
|
||||
case websocket.BinaryMessage:
|
||||
// Handle binary message callback
|
||||
if ws.binaryCallback != nil {
|
||||
ws.binaryCallback(data)
|
||||
}
|
||||
break
|
||||
default:
|
||||
slog.Warn("Unknown WebSocket message type", "type", kind)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Signal closure to callback first
|
||||
if ws.closeCallback != nil {
|
||||
ws.closeCallback()
|
||||
}
|
||||
close(ws.closeChan)
|
||||
ws.closed = true
|
||||
}()
|
||||
|
||||
return ws
|
||||
}
|
||||
|
||||
// SetSharedSecret sets the shared secret for the websocket
|
||||
// SetSharedSecret sets the shared secret for the websocket.
// NOTE(review): not guarded by the mutex — confirm this is only called before
// the connection is used concurrently.
func (ws *SafeWebSocket) SetSharedSecret(secret []byte) {
	ws.sharedSecret = secret
}
|
||||
|
||||
// GetSharedSecret returns the shared secret for the websocket
|
||||
// GetSharedSecret returns the shared secret for the websocket (nil until
// SetSharedSecret has been called).
func (ws *SafeWebSocket) GetSharedSecret() []byte {
	return ws.sharedSecret
}
|
||||
|
||||
// SendJSON writes JSON to a websocket with a mutex
|
||||
// SendJSON marshals v and writes it as a text frame, holding the mutex so
// concurrent senders cannot interleave writes on the underlying connection.
func (ws *SafeWebSocket) SendJSON(v interface{}) error {
	ws.Lock()
	defer ws.Unlock()
	return ws.Conn.WriteJSON(v)
}
|
||||
|
||||
// SendBinary writes binary to a websocket with a mutex
|
||||
// SendBinary writes data as a binary frame, holding the mutex so concurrent
// senders cannot interleave writes on the underlying connection.
func (ws *SafeWebSocket) SendBinary(data []byte) error {
	ws.Lock()
	defer ws.Unlock()
	return ws.Conn.WriteMessage(websocket.BinaryMessage, data)
}
|
||||
|
||||
// RegisterMessageCallback sets the callback for binary message of given type
|
||||
func (ws *SafeWebSocket) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
|
||||
if ws.callbacks == nil {
|
||||
ws.callbacks = make(map[string]OnMessageCallback)
|
||||
}
|
||||
ws.callbacks[msgType] = callback
|
||||
}
|
||||
|
||||
// RegisterBinaryMessageCallback sets the callback for all binary messages
|
||||
// RegisterBinaryMessageCallback sets the single callback invoked for every
// binary frame received on this connection.
func (ws *SafeWebSocket) RegisterBinaryMessageCallback(callback OnMessageCallback) {
	ws.binaryCallback = callback
}
|
||||
|
||||
// UnregisterMessageCallback removes the callback for binary message of given type
|
||||
func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
|
||||
if ws.callbacks != nil {
|
||||
delete(ws.callbacks, msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterBinaryMessageCallback removes the callback for all binary messages
|
||||
// UnregisterBinaryMessageCallback removes the callback for binary frames;
// subsequent binary messages are silently dropped.
func (ws *SafeWebSocket) UnregisterBinaryMessageCallback() {
	ws.binaryCallback = nil
}
|
||||
|
||||
// RegisterOnClose sets the callback for websocket closing
|
||||
// RegisterOnClose sets the callback invoked when the reader goroutine exits.
// The supplied function is wrapped so that all message callbacks are cleared
// before it runs — no further dispatches can occur once close is observed.
func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
	ws.closeCallback = func() {
		// Clear our callbacks
		ws.callbacks = nil
		ws.binaryCallback = nil
		// Call the callback
		callback()
	}
}
|
||||
|
||||
// Closed returns a channel that closes when the WebSocket connection is terminated
|
||||
// Closed returns a channel that is closed when the WebSocket connection is
// terminated; callers can select on it to observe shutdown.
func (ws *SafeWebSocket) Closed() <-chan struct{} {
	return ws.closeChan
}
|
||||
|
||||
// IsClosed returns true if the WebSocket connection is closed
|
||||
// IsClosed returns true if the WebSocket connection is closed.
// NOTE(review): reads the flag without synchronization against the reader
// goroutine that writes it — confirm callers tolerate a slightly stale value.
func (ws *SafeWebSocket) IsClosed() bool {
	return ws.closed
}
|
||||
13
packages/relay/internal/core/consts.go
Normal file
13
packages/relay/internal/core/consts.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package core
|
||||
|
||||
import "time"
|
||||
|
||||
// --- Constants ---
|
||||
const (
	// PubSub topic names; every relay in the mesh must join these exact
	// topics for state/metrics gossip to interoperate.
	roomStateTopicName    = "room-states"
	relayMetricsTopicName = "relay-metrics"

	// Timers and Intervals
	metricsPublishInterval = 15 * time.Second // How often to publish own metrics
)
|
||||
214
packages/relay/internal/core/core.go
Normal file
214
packages/relay/internal/core/core.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"relay/internal/common"
|
||||
"relay/internal/shared"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p"
|
||||
pubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
|
||||
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// -- Variables --
|
||||
|
||||
var globalRelay *Relay
|
||||
|
||||
// -- Structs --
|
||||
|
||||
// RelayInfo contains light information of Relay, in mesh-friendly format
|
||||
// RelayInfo contains light information of a Relay in mesh-friendly format;
// this is the value JSON-marshaled and gossiped on the relay-metrics topic.
type RelayInfo struct {
	ID        peer.ID  // libp2p peer ID of the relay
	MeshAddrs []string // Addresses of this relay
	// Rooms hosted by this relay (presumably keyed by room name — verify
	// against the room-state handlers).
	MeshRooms *common.SafeMap[string, shared.RoomInfo]
	// Latencies to other peers from this relay, keyed by peer ID string
	// (populated by measureLatencyToPeer via peerID.String()).
	MeshLatencies *common.SafeMap[string, time.Duration]
}
|
||||
|
||||
// Relay structure enhanced with metrics and state
|
||||
// Relay is the per-process relay node: its gossiped mesh identity plus the
// libp2p host, pubsub handles, local room bookkeeping and protocol handlers.
type Relay struct {
	RelayInfo

	Host        host.Host         // libp2p host for peer-to-peer networking
	PubSub      *pubsub.PubSub    // PubSub for state synchronization
	PingService *ping.PingService // used by the metrics loop to measure peer latency

	// Local
	LocalRooms           *common.SafeMap[ulid.ULID, *shared.Room]         // room ID -> local Room struct (hosted by this relay)
	LocalMeshPeers       *common.SafeMap[peer.ID, *RelayInfo]             // peer ID -> mesh peer relay info (connected to this relay)
	LocalMeshConnections *common.SafeMap[peer.ID, *webrtc.PeerConnection] // peer ID -> PeerConnection (connected to this relay)

	// Protocols
	ProtocolRegistry

	// PubSub Topics
	pubTopicState        *pubsub.Topic // topic for room states
	pubTopicRelayMetrics *pubsub.Topic // topic for relay metrics/status
}
|
||||
|
||||
func NewRelay(ctx context.Context, port int, identityKey crypto.PrivKey) (*Relay, error) {
|
||||
listenAddrs := []string{
|
||||
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", port), // IPv4 - Raw TCP
|
||||
fmt.Sprintf("/ip6/::/tcp/%d", port), // IPv6 - Raw TCP
|
||||
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d/ws", port), // IPv4 - TCP WebSocket
|
||||
fmt.Sprintf("/ip6/::/tcp/%d/ws", port), // IPv6 - TCP WebSocket
|
||||
}
|
||||
|
||||
var muAddrs []multiaddr.Multiaddr
|
||||
for _, addr := range listenAddrs {
|
||||
multiAddr, err := multiaddr.NewMultiaddr(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse multiaddr '%s': %w", addr, err)
|
||||
}
|
||||
muAddrs = append(muAddrs, multiAddr)
|
||||
}
|
||||
|
||||
// Initialize libp2p host
|
||||
p2pHost, err := libp2p.New(
|
||||
// TODO: Currently static identity
|
||||
libp2p.Identity(identityKey),
|
||||
// Enable required transports
|
||||
libp2p.Transport(tcp.NewTCPTransport),
|
||||
libp2p.Transport(ws.New),
|
||||
// Other options
|
||||
libp2p.ListenAddrs(muAddrs...),
|
||||
libp2p.Security(noise.ID, noise.New),
|
||||
libp2p.EnableRelay(),
|
||||
libp2p.EnableHolePunching(),
|
||||
libp2p.EnableNATService(),
|
||||
libp2p.EnableAutoNATv2(),
|
||||
libp2p.ShareTCPListener(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create libp2p host for relay: %w", err)
|
||||
}
|
||||
|
||||
// Set up pubsub
|
||||
p2pPubsub, err := pubsub.NewGossipSub(ctx, p2pHost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create pubsub: %w, addrs: %v", err, p2pHost.Addrs())
|
||||
}
|
||||
|
||||
// Initialize Ping Service
|
||||
pingSvc := ping.NewPingService(p2pHost)
|
||||
|
||||
var addresses []string
|
||||
for _, addr := range p2pHost.Addrs() {
|
||||
addresses = append(addresses, addr.String())
|
||||
}
|
||||
|
||||
r := &Relay{
|
||||
RelayInfo: RelayInfo{
|
||||
ID: p2pHost.ID(),
|
||||
MeshAddrs: addresses,
|
||||
MeshRooms: common.NewSafeMap[string, shared.RoomInfo](),
|
||||
MeshLatencies: common.NewSafeMap[string, time.Duration](),
|
||||
},
|
||||
Host: p2pHost,
|
||||
PubSub: p2pPubsub,
|
||||
PingService: pingSvc,
|
||||
LocalRooms: common.NewSafeMap[ulid.ULID, *shared.Room](),
|
||||
LocalMeshPeers: common.NewSafeMap[peer.ID, *RelayInfo](),
|
||||
}
|
||||
|
||||
// Add network notifier after relay is initialized
|
||||
p2pHost.Network().Notify(&networkNotifier{relay: r})
|
||||
|
||||
// Set up PubSub topics and handlers
|
||||
if err = r.setupPubSub(ctx); err != nil {
|
||||
err = p2pHost.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close host after PubSub setup failure", "err", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to setup PubSub: %w", err)
|
||||
}
|
||||
|
||||
// Initialize Protocol Registry
|
||||
r.ProtocolRegistry = NewProtocolRegistry(r)
|
||||
|
||||
// Start discovery features
|
||||
if err = startMDNSDiscovery(r); err != nil {
|
||||
slog.Warn("Failed to initialize mDNS discovery, continuing without..", "error", err)
|
||||
}
|
||||
|
||||
// Start background tasks
|
||||
go r.periodicMetricsPublisher(ctx)
|
||||
|
||||
printConnectInstructions(p2pHost)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func InitRelay(ctx context.Context, ctxCancel context.CancelFunc) error {
|
||||
var err error
|
||||
persistentDir := common.GetFlags().PersistDir
|
||||
|
||||
// Load or generate identity key
|
||||
var identityKey crypto.PrivKey
|
||||
var privKey ed25519.PrivateKey
|
||||
// First check if we need to generate identity
|
||||
hasIdentity := len(persistentDir) > 0 && common.GetFlags().RegenIdentity == false
|
||||
if hasIdentity {
|
||||
_, err = os.Stat(persistentDir + "/identity.key")
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to check identity key file: %w", err)
|
||||
} else if os.IsNotExist(err) {
|
||||
hasIdentity = false
|
||||
}
|
||||
}
|
||||
if !hasIdentity {
|
||||
// Make sure the persistent directory exists
|
||||
if err = os.MkdirAll(persistentDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create persistent data directory: %w", err)
|
||||
}
|
||||
// Generate
|
||||
slog.Info("Generating new identity for relay")
|
||||
privKey, err = common.GenerateED25519Key()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate new identity: %w", err)
|
||||
}
|
||||
// Save the key
|
||||
if err = common.SaveED25519Key(privKey, persistentDir+"/identity.key"); err != nil {
|
||||
return fmt.Errorf("failed to save identity key: %w", err)
|
||||
}
|
||||
slog.Info("New identity generated and saved", "path", persistentDir+"/identity.key")
|
||||
} else {
|
||||
slog.Info("Loading existing identity for relay", "path", persistentDir+"/identity.key")
|
||||
// Load the key
|
||||
privKey, err = common.LoadED25519Key(persistentDir + "/identity.key")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load identity key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to libp2p crypto.PrivKey
|
||||
identityKey, err = crypto.UnmarshalEd25519PrivateKey(privKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal ED25519 private key: %w", err)
|
||||
}
|
||||
|
||||
globalRelay, err = NewRelay(ctx, common.GetFlags().EndpointPort, identityKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create relay: %w", err)
|
||||
}
|
||||
|
||||
if err = common.InitWebRTCAPI(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("Relay initialized", "id", globalRelay.ID)
|
||||
return nil
|
||||
}
|
||||
38
packages/relay/internal/core/mdns.go
Normal file
38
packages/relay/internal/core/mdns.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/p2p/discovery/mdns"
|
||||
)
|
||||
|
||||
const (
	// mdnsDiscoveryRendezvous is the shared rendezvous string for mDNS
	// discovery; all relays on the same LAN must advertise this exact value.
	mdnsDiscoveryRendezvous = "/nestri-relay/mdns-discovery/1.0.0"
)
|
||||
|
||||
// discoveryNotifee receives mDNS peer-found events and dials each discovered
// peer on behalf of the owning relay.
type discoveryNotifee struct {
	relay *Relay
}
|
||||
|
||||
// HandlePeerFound is invoked by the mDNS service for each discovered peer; it
// dials the peer as a mesh relay and logs (but does not propagate) failures.
// NOTE(review): uses context.Background(), so discovery dials are not bound
// to the relay's lifecycle — confirm this is intended.
func (d *discoveryNotifee) HandlePeerFound(pi peer.AddrInfo) {
	if d.relay != nil {
		if err := d.relay.connectToRelay(context.Background(), &pi); err != nil {
			slog.Error("failed to connect to discovered relay", "peer", pi.ID, "error", err)
		}
	}
}
|
||||
|
||||
func startMDNSDiscovery(relay *Relay) error {
|
||||
d := &discoveryNotifee{
|
||||
relay: relay,
|
||||
}
|
||||
|
||||
service := mdns.NewMdnsService(relay.Host, mdnsDiscoveryRendezvous, d)
|
||||
if err := service.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start mDNS discovery: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
128
packages/relay/internal/core/metrics.go
Normal file
128
packages/relay/internal/core/metrics.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
// --- Metrics Collection and Publishing ---
|
||||
|
||||
// periodicMetricsPublisher periodically gathers local metrics and publishes them.
|
||||
func (r *Relay) periodicMetricsPublisher(ctx context.Context) {
|
||||
ticker := time.NewTicker(metricsPublishInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Publish immediately on start
|
||||
if err := r.publishRelayMetrics(ctx); err != nil {
|
||||
slog.Error("Failed to publish initial relay metrics", "err", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("Stopping metrics publisher")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.publishRelayMetrics(ctx); err != nil {
|
||||
slog.Error("Failed to publish relay metrics", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// publishRelayMetrics sends the current relay status to the mesh.
|
||||
func (r *Relay) publishRelayMetrics(ctx context.Context) error {
|
||||
if r.pubTopicRelayMetrics == nil {
|
||||
slog.Warn("Cannot publish relay metrics: topic is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check all peer latencies
|
||||
r.checkAllPeerLatencies(ctx)
|
||||
|
||||
data, err := json.Marshal(r.RelayInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal relay status: %w", err)
|
||||
}
|
||||
|
||||
if pubErr := r.pubTopicRelayMetrics.Publish(ctx, data); pubErr != nil {
|
||||
// Don't return error on publish failure, just log
|
||||
slog.Error("Failed to publish relay metrics message", "err", pubErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAllPeerLatencies measures latency to all currently connected peers.
|
||||
func (r *Relay) checkAllPeerLatencies(ctx context.Context) {
|
||||
var wg sync.WaitGroup
|
||||
for _, p := range r.Host.Network().Peers() {
|
||||
if p == r.ID {
|
||||
continue // Skip self
|
||||
}
|
||||
wg.Add(1)
|
||||
// Run checks concurrently
|
||||
go func(peerID peer.ID) {
|
||||
defer wg.Done()
|
||||
go r.measureLatencyToPeer(ctx, peerID)
|
||||
}(p)
|
||||
}
|
||||
wg.Wait() // Wait for all latency checks to complete
|
||||
}
|
||||
|
||||
// measureLatencyToPeer pings a specific peer and updates the local latency map.
|
||||
// measureLatencyToPeer pings a specific peer once and updates the local
// latency map on success. On a ping error the peer is evicted from
// LocalMeshPeers so metrics stop advertising it.
func (r *Relay) measureLatencyToPeer(ctx context.Context, peerID peer.ID) {
	// Check peer status first
	if !r.hasConnectedPeer(peerID) {
		return
	}

	// Create a context for the ping operation; canceling it (via defer)
	// tears down the ping stream once we have consumed a single result.
	pingCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Use the PingService instance stored in the Relay struct
	if r.PingService == nil {
		slog.Error("PingService is nil, cannot measure latency", "peer", peerID)
		return
	}
	resultsCh := r.PingService.Ping(pingCtx, peerID)

	// Wait for the result (or cancellation of the parent context).
	// NOTE(review): there is no explicit timeout here — if the service never
	// delivers a result this blocks until ctx is canceled; confirm intended.
	select {
	case <-pingCtx.Done():
		// Context canceled before a result arrived.
		slog.Warn("Latency check canceled", "peer", peerID, "err", pingCtx.Err())
	case result, ok := <-resultsCh:
		if !ok {
			// Channel closed unexpectedly
			slog.Warn("Ping service channel closed unexpectedly", "peer", peerID)
			return
		}

		// Received ping result
		if result.Error != nil {
			slog.Warn("Latency check failed, removing peer from local peers map", "peer", peerID, "err", result.Error)
			// Remove from MeshPeers if ping failed
			if r.LocalMeshPeers.Has(peerID) {
				r.LocalMeshPeers.Delete(peerID)
			}
			return
		}

		// Ping successful, update latency
		latency := result.RTT
		// Ensure latency is not zero if successful, assign a minimal value if so.
		// Sometimes RTT can be reported as 0 for very fast local connections.
		if latency <= 0 {
			latency = 1 * time.Microsecond
		}

		r.RelayInfo.MeshLatencies.Set(peerID.String(), latency)
	}
}
|
||||
128
packages/relay/internal/core/p2p.go
Normal file
128
packages/relay/internal/core/p2p.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
// --- Structs ---
|
||||
|
||||
// networkNotifier logs connection events and updates relay state
|
||||
// networkNotifier forwards libp2p connect/disconnect events to the owning
// relay so mesh peer state can be kept up to date.
type networkNotifier struct {
	relay *Relay
}
||||
|
||||
// Connected is called when a connection is established
|
||||
func (n *networkNotifier) Connected(net network.Network, conn network.Conn) {
|
||||
if n.relay == nil {
|
||||
n.relay.onPeerConnected(conn.RemotePeer())
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnected is called when a connection is terminated
|
||||
func (n *networkNotifier) Disconnected(net network.Network, conn network.Conn) {
|
||||
// Update the status of the disconnected peer
|
||||
if n.relay != nil {
|
||||
n.relay.onPeerDisconnected(conn.RemotePeer())
|
||||
}
|
||||
}
|
||||
|
||||
// Listen is called when the node starts listening on an address; intentionally a no-op.
func (n *networkNotifier) Listen(net network.Network, addr multiaddr.Multiaddr) {}

// ListenClose is called when the node stops listening on an address; intentionally a no-op.
func (n *networkNotifier) ListenClose(net network.Network, addr multiaddr.Multiaddr) {}
|
||||
|
||||
// --- PubSub Setup ---
|
||||
|
||||
// setupPubSub initializes PubSub topics and subscriptions.
|
||||
func (r *Relay) setupPubSub(ctx context.Context) error {
|
||||
var err error
|
||||
|
||||
// Room State Topic
|
||||
r.pubTopicState, err = r.PubSub.Join(roomStateTopicName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to join room state topic '%s': %w", roomStateTopicName, err)
|
||||
}
|
||||
stateSub, err := r.pubTopicState.Subscribe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to room state topic '%s': %w", roomStateTopicName, err)
|
||||
}
|
||||
go r.handleRoomStateMessages(ctx, stateSub) // Handler in relay_state.go
|
||||
|
||||
// Relay Metrics Topic
|
||||
r.pubTopicRelayMetrics, err = r.PubSub.Join(relayMetricsTopicName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to join relay metrics topic '%s': %w", relayMetricsTopicName, err)
|
||||
}
|
||||
metricsSub, err := r.pubTopicRelayMetrics.Subscribe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to relay metrics topic '%s': %w", relayMetricsTopicName, err)
|
||||
}
|
||||
go r.handleRelayMetricsMessages(ctx, metricsSub) // Handler in relay_state.go
|
||||
|
||||
slog.Info("PubSub topics joined and subscriptions started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Connection Management ---
|
||||
|
||||
// connectToRelay is internal method to connect to a relay peer using multiaddresses
|
||||
func (r *Relay) connectToRelay(ctx context.Context, peerInfo *peer.AddrInfo) error {
|
||||
if peerInfo.ID == r.ID {
|
||||
return errors.New("cannot connect to self")
|
||||
}
|
||||
|
||||
// Use a timeout for the connection attempt
|
||||
connectCtx, cancel := context.WithTimeout(ctx, 15*time.Second) // 15s timeout
|
||||
defer cancel()
|
||||
|
||||
slog.Info("Attempting to connect to peer", "peer", peerInfo.ID, "addrs", peerInfo.Addrs)
|
||||
if err := r.Host.Connect(connectCtx, *peerInfo); err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %w", peerInfo.ID, err)
|
||||
}
|
||||
|
||||
slog.Info("Successfully connected to peer", "peer", peerInfo.ID, "addrs", peerInfo.Addrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectToRelay connects to another relay by its multiaddress.
|
||||
func (r *Relay) ConnectToRelay(ctx context.Context, addr string) error {
|
||||
ma, err := multiaddr.NewMultiaddr(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid multiaddress: %w", err)
|
||||
}
|
||||
|
||||
peerInfo, err := peer.AddrInfoFromP2pAddr(ma)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract peer info: %w", err)
|
||||
}
|
||||
|
||||
return r.connectToRelay(ctx, peerInfo)
|
||||
}
|
||||
|
||||
// printConnectInstructions logs the multiaddresses for connecting to this relay.
|
||||
func printConnectInstructions(p2pHost host.Host) {
|
||||
peerInfo := peer.AddrInfo{
|
||||
ID: p2pHost.ID(),
|
||||
Addrs: p2pHost.Addrs(),
|
||||
}
|
||||
addrs, err := peer.AddrInfoToP2pAddrs(&peerInfo)
|
||||
if err != nil {
|
||||
slog.Error("Failed to convert peer info to addresses", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Mesh connection addresses:")
|
||||
for _, addr := range addrs {
|
||||
slog.Info(fmt.Sprintf("> %s", addr.String()))
|
||||
}
|
||||
}
|
||||
694
packages/relay/internal/core/protocol_stream.go
Normal file
694
packages/relay/internal/core/protocol_stream.go
Normal file
@@ -0,0 +1,694 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
"relay/internal/shared"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// TODO:s
|
||||
// TODO: When disconnecting with stream open, causes crash on requester
|
||||
// TODO: Need to trigger stream request if remote room is online and there are participants in local waiting
|
||||
// TODO: Cleanup local room state when stream is closed upstream
|
||||
|
||||
// --- Protocol IDs ---
|
||||
const (
	// Versioned libp2p protocol IDs; both sides of a stream must register the
	// same exact strings for SetStreamHandler dispatch to match.
	protocolStreamRequest = "/nestri-relay/stream-request/1.0.0" // For requesting a stream from relay
	protocolStreamPush    = "/nestri-relay/stream-push/1.0.0"    // For pushing a stream to relay
)
|
||||
|
||||
// --- Protocol Types ---
|
||||
|
||||
// StreamConnection is a connection between two relays for stream protocol
|
||||
// StreamConnection is a connection between two relays for the stream
// protocol: the WebRTC peer connection plus its "relay-data" data channel.
type StreamConnection struct {
	pc  *webrtc.PeerConnection
	ndc *connections.NestriDataChannel
}
|
||||
|
||||
// StreamProtocol deals with meshed stream forwarding
|
||||
// StreamProtocol deals with meshed stream forwarding between relays, tracking
// connections in three directions: streams we serve to peers, streams pushed
// to us, and streams we have requested from other relays.
type StreamProtocol struct {
	relay          *Relay
	servedConns    *common.SafeMap[peer.ID, *StreamConnection] // peer ID -> StreamConnection (for served streams)
	incomingConns  *common.SafeMap[string, *StreamConnection]  // room name -> StreamConnection (for incoming pushed streams)
	requestedConns *common.SafeMap[string, *StreamConnection]  // room name -> StreamConnection (for requested streams from other relays)
}
|
||||
|
||||
func NewStreamProtocol(relay *Relay) *StreamProtocol {
|
||||
protocol := &StreamProtocol{
|
||||
relay: relay,
|
||||
servedConns: common.NewSafeMap[peer.ID, *StreamConnection](),
|
||||
incomingConns: common.NewSafeMap[string, *StreamConnection](),
|
||||
requestedConns: common.NewSafeMap[string, *StreamConnection](),
|
||||
}
|
||||
|
||||
protocol.relay.Host.SetStreamHandler(protocolStreamRequest, protocol.handleStreamRequest)
|
||||
protocol.relay.Host.SetStreamHandler(protocolStreamPush, protocol.handleStreamPush)
|
||||
|
||||
return protocol
|
||||
}
|
||||
|
||||
// --- Protocol Stream Handlers ---
|
||||
|
||||
// handleStreamRequest manages a request from another relay for a stream hosted locally
|
||||
func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
|
||||
safeBRW := common.NewSafeBufioRW(brw)
|
||||
|
||||
iceHolder := make([]webrtc.ICECandidateInit, 0)
|
||||
for {
|
||||
data, err := safeBRW.Receive()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) {
|
||||
slog.Debug("Stream request connection closed by peer", "peer", stream.Conn().RemotePeer())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Error("Failed to receive data", "err", err)
|
||||
_ = stream.Reset()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var baseMsg connections.MessageBase
|
||||
if err = json.Unmarshal(data, &baseMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal base message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch baseMsg.Type {
|
||||
case "request-stream-room":
|
||||
var rawMsg connections.MessageRaw
|
||||
if err = json.Unmarshal(data, &rawMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal raw message for room stream request", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var roomName string
|
||||
if err = json.Unmarshal(rawMsg.Data, &roomName); err != nil {
|
||||
slog.Error("Failed to unmarshal room name from raw message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Received stream request for room", "room", roomName)
|
||||
room := sp.relay.GetRoomByName(roomName)
|
||||
if room == nil || !room.IsOnline() || room.OwnerID != sp.relay.ID {
|
||||
// TODO: Allow forward requests to other relays from here?
|
||||
slog.Debug("Cannot provide stream for nil, offline or non-owned room", "room", roomName, "is_online", room != nil && room.IsOnline(), "is_owner", room != nil && room.OwnerID == sp.relay.ID)
|
||||
// Respond with "request-stream-offline" message with room name
|
||||
// TODO: Store the peer and send "online" message when the room comes online
|
||||
roomNameData, err := json.Marshal(roomName)
|
||||
if err != nil {
|
||||
slog.Error("Failed to marshal room name for request stream offline", "room", roomName, "err", err)
|
||||
continue
|
||||
} else {
|
||||
if err = safeBRW.SendJSON(connections.NewMessageRaw(
|
||||
"request-stream-offline",
|
||||
roomNameData,
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send request stream offline message", "room", roomName, "err", err)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pc, err := common.CreatePeerConnection(func() {
|
||||
slog.Info("PeerConnection closed for requested stream", "room", roomName)
|
||||
// Cleanup the stream connection
|
||||
if ok := sp.servedConns.Has(stream.Conn().RemotePeer()); ok {
|
||||
sp.servedConns.Delete(stream.Conn().RemotePeer())
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to create PeerConnection for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add tracks
|
||||
if room.AudioTrack != nil {
|
||||
if _, err = pc.AddTrack(room.AudioTrack); err != nil {
|
||||
slog.Error("Failed to add audio track for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if room.VideoTrack != nil {
|
||||
if _, err = pc.AddTrack(room.VideoTrack); err != nil {
|
||||
slog.Error("Failed to add video track for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// DataChannel setup
|
||||
settingOrdered := true
|
||||
settingMaxRetransmits := uint16(2)
|
||||
dc, err := pc.CreateDataChannel("relay-data", &webrtc.DataChannelInit{
|
||||
Ordered: &settingOrdered,
|
||||
MaxRetransmits: &settingMaxRetransmits,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to create DataChannel for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
ndc := connections.NewNestriDataChannel(dc)
|
||||
|
||||
ndc.RegisterOnOpen(func() {
|
||||
slog.Debug("Relay DataChannel opened for requested stream", "room", roomName)
|
||||
})
|
||||
ndc.RegisterOnClose(func() {
|
||||
slog.Debug("Relay DataChannel closed for requested stream", "room", roomName)
|
||||
})
|
||||
ndc.RegisterMessageCallback("input", func(data []byte) {
|
||||
if room.DataChannel != nil {
|
||||
if err = room.DataChannel.SendBinary(data); err != nil {
|
||||
slog.Error("Failed to forward input message from mesh to upstream room", "room", roomName, "err", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// ICE Candidate handling
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = safeBRW.SendJSON(connections.NewMessageICE("ice-candidate", candidate.ToJSON())); err != nil {
|
||||
slog.Error("Failed to send ICE candidate message for requested stream", "room", roomName, "err", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// Create offer
|
||||
offer, err := pc.CreateOffer(nil)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create offer for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
if err = pc.SetLocalDescription(offer); err != nil {
|
||||
slog.Error("Failed to set local description for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
if err = safeBRW.SendJSON(connections.NewMessageSDP("offer", offer)); err != nil {
|
||||
slog.Error("Failed to send offer for requested stream", "room", roomName, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store the connection
|
||||
sp.servedConns.Set(stream.Conn().RemotePeer(), &StreamConnection{
|
||||
pc: pc,
|
||||
ndc: ndc,
|
||||
})
|
||||
|
||||
slog.Debug("Sent offer for requested stream")
|
||||
case "ice-candidate":
|
||||
var iceMsg connections.MessageICE
|
||||
if err := json.Unmarshal(data, &iceMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal ICE message", "err", err)
|
||||
continue
|
||||
}
|
||||
if conn, ok := sp.servedConns.Get(stream.Conn().RemotePeer()); ok && conn.pc.RemoteDescription() != nil {
|
||||
if err := conn.pc.AddICECandidate(iceMsg.Candidate); err != nil {
|
||||
slog.Error("Failed to add ICE candidate", "err", err)
|
||||
}
|
||||
for _, heldIce := range iceHolder {
|
||||
if err := conn.pc.AddICECandidate(heldIce); err != nil {
|
||||
slog.Error("Failed to add held ICE candidate", "err", err)
|
||||
}
|
||||
}
|
||||
// Clear the held candidates
|
||||
iceHolder = make([]webrtc.ICECandidateInit, 0)
|
||||
} else {
|
||||
// Hold the candidate until remote description is set
|
||||
iceHolder = append(iceHolder, iceMsg.Candidate)
|
||||
}
|
||||
case "answer":
|
||||
var answerMsg connections.MessageSDP
|
||||
if err := json.Unmarshal(data, &answerMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal answer from signaling message", "err", err)
|
||||
continue
|
||||
}
|
||||
if conn, ok := sp.servedConns.Get(stream.Conn().RemotePeer()); ok {
|
||||
if err := conn.pc.SetRemoteDescription(answerMsg.SDP); err != nil {
|
||||
slog.Error("Failed to set remote description for answer", "err", err)
|
||||
continue
|
||||
}
|
||||
slog.Debug("Set remote description for answer")
|
||||
} else {
|
||||
slog.Warn("Received answer without active PeerConnection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestStream manages the internals of the stream request
|
||||
func (sp *StreamProtocol) requestStream(stream network.Stream, room *shared.Room) error {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
|
||||
safeBRW := common.NewSafeBufioRW(brw)
|
||||
|
||||
slog.Debug("Requesting room stream from peer", "room", room.Name, "peer", stream.Conn().RemotePeer())
|
||||
|
||||
// Send room name to the remote peer
|
||||
roomData, err := json.Marshal(room.Name)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return fmt.Errorf("failed to marshal room name: %w", err)
|
||||
}
|
||||
if err = safeBRW.SendJSON(connections.NewMessageRaw(
|
||||
"request-stream-room",
|
||||
roomData,
|
||||
)); err != nil {
|
||||
_ = stream.Close()
|
||||
return fmt.Errorf("failed to send room request: %w", err)
|
||||
}
|
||||
|
||||
pc, err := common.CreatePeerConnection(func() {
|
||||
slog.Info("Relay PeerConnection closed for requested stream", "room", room.Name)
|
||||
_ = stream.Close() // ignore error as may be closed already
|
||||
// Cleanup the stream connection
|
||||
if ok := sp.requestedConns.Has(room.Name); ok {
|
||||
sp.requestedConns.Delete(room.Name)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return fmt.Errorf("failed to create PeerConnection: %w", err)
|
||||
}
|
||||
|
||||
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
localTrack, _ := webrtc.NewTrackLocalStaticRTP(track.Codec().RTPCodecCapability, track.ID(), "relay-"+room.Name+"-"+track.Kind().String())
|
||||
slog.Debug("Received track for requested stream", "room", room.Name, "track_kind", track.Kind().String())
|
||||
|
||||
room.SetTrack(track.Kind(), localTrack)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
rtpPacket, _, err := track.ReadRTP()
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
slog.Error("Failed to read RTP packet for requested stream room", "room", room.Name, "err", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
err = localTrack.WriteRTP(rtpPacket)
|
||||
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
slog.Error("Failed to write RTP to local track for requested stream room", "room", room.Name, "err", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
ndc := connections.NewNestriDataChannel(dc)
|
||||
ndc.RegisterOnOpen(func() {
|
||||
slog.Debug("Relay DataChannel opened for requested stream", "room", room.Name)
|
||||
})
|
||||
ndc.RegisterOnClose(func() {
|
||||
slog.Debug("Relay DataChannel closed for requested stream", "room", room.Name)
|
||||
})
|
||||
|
||||
// Set the DataChannel in the requestedConns map
|
||||
if conn, ok := sp.requestedConns.Get(room.Name); ok {
|
||||
conn.ndc = ndc
|
||||
} else {
|
||||
sp.requestedConns.Set(room.Name, &StreamConnection{
|
||||
pc: pc,
|
||||
ndc: ndc,
|
||||
})
|
||||
}
|
||||
|
||||
// We do not handle any messages from upstream here
|
||||
})
|
||||
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = safeBRW.SendJSON(connections.NewMessageICE(
|
||||
"ice-candidate",
|
||||
candidate.ToJSON(),
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send ICE candidate message for requested stream", "room", room.Name, "err", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// Handle incoming messages (offer and candidates)
|
||||
go func() {
|
||||
iceHolder := make([]webrtc.ICECandidateInit, 0)
|
||||
|
||||
for {
|
||||
data, err := safeBRW.Receive()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) {
|
||||
slog.Debug("Connection for requested stream closed by peer", "room", room.Name)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Error("Failed to receive data for requested stream", "room", room.Name, "err", err)
|
||||
_ = stream.Reset()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var baseMsg connections.MessageBase
|
||||
if err = json.Unmarshal(data, &baseMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal base message for requested stream", "room", room.Name, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch baseMsg.Type {
|
||||
case "ice-candidate":
|
||||
var iceMsg connections.MessageICE
|
||||
if err = json.Unmarshal(data, &iceMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal ICE candidate for requested stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
if conn, ok := sp.requestedConns.Get(room.Name); ok && conn.pc.RemoteDescription() != nil {
|
||||
if err = conn.pc.AddICECandidate(iceMsg.Candidate); err != nil {
|
||||
slog.Error("Failed to add ICE candidate for requested stream", "room", room.Name, "err", err)
|
||||
}
|
||||
// Add held candidates
|
||||
for _, heldCandidate := range iceHolder {
|
||||
if err = conn.pc.AddICECandidate(heldCandidate); err != nil {
|
||||
slog.Error("Failed to add held ICE candidate for requested stream", "room", room.Name, "err", err)
|
||||
}
|
||||
}
|
||||
// Clear the held candidates
|
||||
iceHolder = make([]webrtc.ICECandidateInit, 0)
|
||||
} else {
|
||||
// Hold the candidate until remote description is set
|
||||
iceHolder = append(iceHolder, iceMsg.Candidate)
|
||||
}
|
||||
case "offer":
|
||||
var offerMsg connections.MessageSDP
|
||||
if err = json.Unmarshal(data, &offerMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal offer for requested stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
if err = pc.SetRemoteDescription(offerMsg.SDP); err != nil {
|
||||
slog.Error("Failed to set remote description for requested stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
answer, err := pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create answer for requested stream", "room", room.Name, "err", err)
|
||||
if err = stream.Reset(); err != nil {
|
||||
slog.Error("Failed to reset stream for requested stream", "err", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err = pc.SetLocalDescription(answer); err != nil {
|
||||
slog.Error("Failed to set local description for requested stream", "room", room.Name, "err", err)
|
||||
if err = stream.Reset(); err != nil {
|
||||
slog.Error("Failed to reset stream for requested stream", "err", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err = safeBRW.SendJSON(connections.NewMessageSDP(
|
||||
"answer",
|
||||
answer,
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send answer for requested stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store the connection
|
||||
sp.requestedConns.Set(room.Name, &StreamConnection{
|
||||
pc: pc,
|
||||
ndc: nil,
|
||||
})
|
||||
|
||||
slog.Debug("Sent answer for requested stream", "room", room.Name)
|
||||
default:
|
||||
slog.Warn("Unknown signaling message type", "room", room.Name, "type", baseMsg.Type)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStreamPush manages a stream push from a node (nestri-server)
|
||||
func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
|
||||
safeBRW := common.NewSafeBufioRW(brw)
|
||||
|
||||
var room *shared.Room
|
||||
iceHolder := make([]webrtc.ICECandidateInit, 0)
|
||||
for {
|
||||
data, err := safeBRW.Receive()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) {
|
||||
slog.Debug("Stream push connection closed by peer", "peer", stream.Conn().RemotePeer())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Error("Failed to receive data for stream push", "err", err)
|
||||
_ = stream.Reset()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var baseMsg connections.MessageBase
|
||||
if err = json.Unmarshal(data, &baseMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal base message from base message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch baseMsg.Type {
|
||||
case "push-stream-room":
|
||||
var rawMsg connections.MessageRaw
|
||||
if err = json.Unmarshal(data, &rawMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal room name from data", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var roomName string
|
||||
if err = json.Unmarshal(rawMsg.Data, &roomName); err != nil {
|
||||
slog.Error("Failed to unmarshal room name from raw message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Received stream push request for room", "room", roomName)
|
||||
|
||||
room = sp.relay.GetRoomByName(roomName)
|
||||
if room != nil {
|
||||
if room.OwnerID != sp.relay.ID {
|
||||
slog.Error("Cannot push a stream to non-owned room", "room", room.Name, "owner_id", room.OwnerID)
|
||||
continue
|
||||
}
|
||||
if room.IsOnline() {
|
||||
slog.Error("Cannot push a stream to already online room", "room", room.Name)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Create a new room if it doesn't exist
|
||||
room = sp.relay.CreateRoom(roomName)
|
||||
}
|
||||
|
||||
// Respond with an OK with the room name
|
||||
roomData, err := json.Marshal(room.Name)
|
||||
if err != nil {
|
||||
slog.Error("Failed to marshal room name for push stream response", "err", err)
|
||||
continue
|
||||
}
|
||||
if err = safeBRW.SendJSON(connections.NewMessageRaw(
|
||||
"push-stream-ok",
|
||||
roomData,
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send push stream OK response", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
case "ice-candidate":
|
||||
var iceMsg connections.MessageICE
|
||||
if err = json.Unmarshal(data, &iceMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal ICE candidate from data", "err", err)
|
||||
continue
|
||||
}
|
||||
if conn, ok := sp.incomingConns.Get(room.Name); ok && conn.pc.RemoteDescription() != nil {
|
||||
if err = conn.pc.AddICECandidate(iceMsg.Candidate); err != nil {
|
||||
slog.Error("Failed to add ICE candidate for pushed stream", "err", err)
|
||||
}
|
||||
for _, heldIce := range iceHolder {
|
||||
if err := conn.pc.AddICECandidate(heldIce); err != nil {
|
||||
slog.Error("Failed to add held ICE candidate for pushed stream", "err", err)
|
||||
}
|
||||
}
|
||||
// Clear the held candidates
|
||||
iceHolder = make([]webrtc.ICECandidateInit, 0)
|
||||
} else {
|
||||
// Hold the candidate until remote description is set
|
||||
iceHolder = append(iceHolder, iceMsg.Candidate)
|
||||
}
|
||||
case "offer":
|
||||
// Make sure we have room set to push to (set by "push-stream-room")
|
||||
if room == nil {
|
||||
slog.Error("Received offer without room set for stream push")
|
||||
continue
|
||||
}
|
||||
|
||||
var offerMsg connections.MessageSDP
|
||||
if err = json.Unmarshal(data, &offerMsg); err != nil {
|
||||
slog.Error("Failed to unmarshal offer from data", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create PeerConnection for the incoming stream
|
||||
pc, err := common.CreatePeerConnection(func() {
|
||||
slog.Info("PeerConnection closed for pushed stream", "room", room.Name)
|
||||
// Cleanup the stream connection
|
||||
if ok := sp.incomingConns.Has(room.Name); ok {
|
||||
sp.incomingConns.Delete(room.Name)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to create PeerConnection for pushed stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
// TODO: Is this the best way to handle DataChannel? Should we just use the map directly?
|
||||
room.DataChannel = connections.NewNestriDataChannel(dc)
|
||||
room.DataChannel.RegisterOnOpen(func() {
|
||||
slog.Debug("DataChannel opened for pushed stream", "room", room.Name)
|
||||
})
|
||||
room.DataChannel.RegisterOnClose(func() {
|
||||
slog.Debug("DataChannel closed for pushed stream", "room", room.Name)
|
||||
})
|
||||
|
||||
// Set the DataChannel in the incomingConns map
|
||||
if conn, ok := sp.incomingConns.Get(room.Name); ok {
|
||||
conn.ndc = room.DataChannel
|
||||
} else {
|
||||
sp.incomingConns.Set(room.Name, &StreamConnection{
|
||||
pc: pc,
|
||||
ndc: room.DataChannel,
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = safeBRW.SendJSON(connections.NewMessageICE(
|
||||
"ice-candidate",
|
||||
candidate.ToJSON(),
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send ICE candidate message for pushed stream", "room", room.Name, "err", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
pc.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
localTrack, err := webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, remoteTrack.Kind().String(), fmt.Sprintf("nestri-%s-%s", room.Name, remoteTrack.Kind().String()))
|
||||
if err != nil {
|
||||
slog.Error("Failed to create local track for pushed stream", "room", room.Name, "track_kind", remoteTrack.Kind().String(), "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("Received track for pushed stream", "room", room.Name, "track_kind", remoteTrack.Kind().String())
|
||||
|
||||
// Set track for Room
|
||||
room.SetTrack(remoteTrack.Kind(), localTrack)
|
||||
|
||||
// Prepare PlayoutDelayExtension so we don't need to recreate it for each packet
|
||||
playoutExt := &rtp.PlayoutDelayExtension{
|
||||
MinDelay: 0,
|
||||
MaxDelay: 0,
|
||||
}
|
||||
playoutPayload, err := playoutExt.Marshal()
|
||||
if err != nil {
|
||||
slog.Error("Failed to marshal PlayoutDelayExtension for room", "room", room.Name, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
rtpPacket, _, err := remoteTrack.ReadRTP()
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
slog.Error("Failed to read RTP from remote track for room", "room", room.Name, "err", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Use PlayoutDelayExtension for low latency, if set for this track kind
|
||||
if extID, ok := common.GetExtension(remoteTrack.Kind(), common.ExtensionPlayoutDelay); ok {
|
||||
if err := rtpPacket.SetExtension(extID, playoutPayload); err != nil {
|
||||
slog.Error("Failed to set PlayoutDelayExtension for room", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
err = localTrack.WriteRTP(rtpPacket)
|
||||
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
slog.Error("Failed to write RTP to local track for room", "room", room.Name, "err", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("Track closed for room", "room", room.Name, "track_kind", remoteTrack.Kind().String())
|
||||
|
||||
// Cleanup the track from the room
|
||||
room.SetTrack(remoteTrack.Kind(), nil)
|
||||
})
|
||||
|
||||
// Set the remote description
|
||||
if err = pc.SetRemoteDescription(offerMsg.SDP); err != nil {
|
||||
slog.Error("Failed to set remote description for pushed stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
slog.Debug("Set remote description for pushed stream", "room", room.Name)
|
||||
|
||||
// Create an answer
|
||||
answer, err := pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create answer for pushed stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
if err = pc.SetLocalDescription(answer); err != nil {
|
||||
slog.Error("Failed to set local description for pushed stream", "room", room.Name, "err", err)
|
||||
continue
|
||||
}
|
||||
if err = safeBRW.SendJSON(connections.NewMessageSDP(
|
||||
"answer",
|
||||
answer,
|
||||
)); err != nil {
|
||||
slog.Error("Failed to send answer for pushed stream", "room", room.Name, "err", err)
|
||||
}
|
||||
|
||||
// Store the connection
|
||||
sp.incomingConns.Set(room.Name, &StreamConnection{
|
||||
pc: pc,
|
||||
ndc: room.DataChannel, // if it exists, if not it will be set later
|
||||
})
|
||||
slog.Debug("Sent answer for pushed stream", "room", room.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Public Usable Methods ---
|
||||
|
||||
// RequestStream sends a request to get room stream from another relay
|
||||
func (sp *StreamProtocol) RequestStream(ctx context.Context, room *shared.Room, peerID peer.ID) error {
|
||||
stream, err := sp.relay.Host.NewStream(ctx, peerID, protocolStreamRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stream request: %w", err)
|
||||
}
|
||||
|
||||
return sp.requestStream(stream, room)
|
||||
}
|
||||
13
packages/relay/internal/core/protocols.go
Normal file
13
packages/relay/internal/core/protocols.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package core
|
||||
|
||||
// ProtocolRegistry is a type holding all protocols to split away the bloat
// from the main Relay type; new protocol handlers should be added here.
type ProtocolRegistry struct {
	// StreamProtocol handles stream request/push signaling between relays and nodes.
	StreamProtocol *StreamProtocol
}
|
||||
|
||||
// NewProtocolRegistry initializes and returns a new protocol registry
|
||||
func NewProtocolRegistry(relay *Relay) ProtocolRegistry {
|
||||
return ProtocolRegistry{
|
||||
StreamProtocol: NewStreamProtocol(relay),
|
||||
}
|
||||
}
|
||||
108
packages/relay/internal/core/room.go
Normal file
108
packages/relay/internal/core/room.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"relay/internal/shared"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
// --- Room Management ---
|
||||
|
||||
// GetRoomByID retrieves a local Room struct by its ULID
|
||||
func (r *Relay) GetRoomByID(id ulid.ULID) *shared.Room {
|
||||
if room, ok := r.LocalRooms.Get(id); ok {
|
||||
return room
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRoomByName retrieves a local Room struct by its name
|
||||
func (r *Relay) GetRoomByName(name string) *shared.Room {
|
||||
for _, room := range r.LocalRooms.Copy() {
|
||||
if room.Name == name {
|
||||
return room
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateRoom creates a new local Room struct with the given name
|
||||
func (r *Relay) CreateRoom(name string) *shared.Room {
|
||||
roomID := ulid.Make()
|
||||
room := shared.NewRoom(name, roomID, r.ID)
|
||||
r.LocalRooms.Set(room.ID, room)
|
||||
slog.Debug("Created new local room", "room", name, "id", room.ID)
|
||||
return room
|
||||
}
|
||||
|
||||
// DeleteRoomIfEmpty checks if a local room struct is inactive and can be removed
|
||||
func (r *Relay) DeleteRoomIfEmpty(room *shared.Room) {
|
||||
if room == nil {
|
||||
return
|
||||
}
|
||||
if room.Participants.Len() == 0 && r.LocalRooms.Has(room.ID) {
|
||||
slog.Debug("Deleting empty room without participants", "room", room.Name)
|
||||
r.LocalRooms.Delete(room.ID)
|
||||
err := room.PeerConnection.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close Room PeerConnection", "room", room.Name, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRemoteRoomByName returns room from mesh by name
|
||||
func (r *Relay) GetRemoteRoomByName(roomName string) *shared.RoomInfo {
|
||||
for _, room := range r.MeshRooms.Copy() {
|
||||
if room.Name == roomName && room.OwnerID != r.ID {
|
||||
// Make sure connection is alive
|
||||
if r.Host.Network().Connectedness(room.OwnerID) == network.Connected {
|
||||
return &room
|
||||
} else {
|
||||
slog.Debug("Removing stale peer, owns a room without connection", "room", roomName, "peer", room.OwnerID)
|
||||
r.onPeerDisconnected(room.OwnerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- State Publishing ---
|
||||
|
||||
// publishRoomStates publishes the state of all rooms currently owned by *this* relay
|
||||
func (r *Relay) publishRoomStates(ctx context.Context) error {
|
||||
if r.pubTopicState == nil {
|
||||
slog.Warn("Cannot publish room states: topic is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
var statesToPublish []shared.RoomInfo
|
||||
r.LocalRooms.Range(func(id ulid.ULID, room *shared.Room) bool {
|
||||
// Only publish state for rooms owned by this relay
|
||||
if room.OwnerID == r.ID {
|
||||
statesToPublish = append(statesToPublish, shared.RoomInfo{
|
||||
ID: room.ID,
|
||||
Name: room.Name,
|
||||
OwnerID: r.ID,
|
||||
})
|
||||
}
|
||||
return true // Continue iteration
|
||||
})
|
||||
|
||||
if len(statesToPublish) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(statesToPublish)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal local room states: %w", err)
|
||||
}
|
||||
if pubErr := r.pubTopicState.Publish(ctx, data); pubErr != nil {
|
||||
slog.Error("Failed to publish room states message", "err", pubErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
173
packages/relay/internal/core/state.go
Normal file
173
packages/relay/internal/core/state.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"relay/internal/shared"
|
||||
"time"
|
||||
|
||||
pubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
// --- PubSub Message Handlers ---
|
||||
|
||||
// handleRoomStateMessages processes incoming room state updates from peers.
|
||||
func (r *Relay) handleRoomStateMessages(ctx context.Context, sub *pubsub.Subscription) {
|
||||
slog.Debug("Starting room state message handler...")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("Stopping room state message handler")
|
||||
return
|
||||
default:
|
||||
msg, err := sub.Next(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, pubsub.ErrSubscriptionCancelled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
slog.Info("Room state subscription ended", "err", err)
|
||||
return
|
||||
}
|
||||
slog.Error("Error receiving room state message", "err", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
if msg.GetFrom() == r.Host.ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
var states []shared.RoomInfo
|
||||
if err := json.Unmarshal(msg.Data, &states); err != nil {
|
||||
slog.Error("Failed to unmarshal room states", "from", msg.GetFrom(), "data_len", len(msg.Data), "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
r.updateMeshRoomStates(msg.GetFrom(), states)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRelayMetricsMessages processes incoming status updates from peers.
|
||||
func (r *Relay) handleRelayMetricsMessages(ctx context.Context, sub *pubsub.Subscription) {
|
||||
slog.Debug("Starting relay metrics message handler...")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("Stopping relay metrics message handler")
|
||||
return
|
||||
default:
|
||||
msg, err := sub.Next(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, pubsub.ErrSubscriptionCancelled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
slog.Info("Relay metrics subscription ended", "err", err)
|
||||
return
|
||||
}
|
||||
slog.Error("Error receiving relay metrics message", "err", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
if msg.GetFrom() == r.Host.ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
var info RelayInfo
|
||||
if err := json.Unmarshal(msg.Data, &info); err != nil {
|
||||
slog.Error("Failed to unmarshal relay status", "from", msg.GetFrom(), "data_len", len(msg.Data), "err", err)
|
||||
continue
|
||||
}
|
||||
if info.ID != msg.GetFrom() {
|
||||
slog.Error("Peer ID mismatch in relay status", "expected", info.ID, "actual", msg.GetFrom())
|
||||
continue
|
||||
}
|
||||
r.onPeerStatus(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- State Check Functions ---
|
||||
// hasConnectedPeer checks if peer is in map and has a valid connection
|
||||
func (r *Relay) hasConnectedPeer(peerID peer.ID) bool {
|
||||
if _, ok := r.LocalMeshPeers.Get(peerID); !ok {
|
||||
return false
|
||||
}
|
||||
if r.Host.Network().Connectedness(peerID) != network.Connected {
|
||||
slog.Debug("Peer not connected", "peer", peerID)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// --- State Change Functions ---
|
||||
|
||||
// onPeerStatus updates the status of a peer based on received metrics, adding local perspective
|
||||
func (r *Relay) onPeerStatus(recvInfo RelayInfo) {
|
||||
r.LocalMeshPeers.Set(recvInfo.ID, &recvInfo)
|
||||
}
|
||||
|
||||
// onPeerConnected is called when a new peer connects to the relay
|
||||
func (r *Relay) onPeerConnected(peerID peer.ID) {
|
||||
// Add to local peer map
|
||||
r.LocalMeshPeers.Set(peerID, &RelayInfo{
|
||||
ID: peerID,
|
||||
})
|
||||
|
||||
slog.Info("Peer connected", "peer", peerID)
|
||||
|
||||
// Trigger immediate state exchange
|
||||
go func() {
|
||||
if err := r.publishRelayMetrics(context.Background()); err != nil {
|
||||
slog.Error("Failed to publish relay metrics on connect", "err", err)
|
||||
} else {
|
||||
if err = r.publishRoomStates(context.Background()); err != nil {
|
||||
slog.Error("Failed to publish room states on connect", "err", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// onPeerDisconnected marks a peer as disconnected in our status view and removes latency info
|
||||
func (r *Relay) onPeerDisconnected(peerID peer.ID) {
|
||||
slog.Info("Mesh peer disconnected, deleting from local peer map", "peer", peerID)
|
||||
// Remove peer from local mesh peers
|
||||
if r.LocalMeshPeers.Has(peerID) {
|
||||
r.LocalMeshPeers.Delete(peerID)
|
||||
}
|
||||
// Remove any rooms associated with this peer
|
||||
if r.MeshRooms.Has(peerID.String()) {
|
||||
r.MeshRooms.Delete(peerID.String())
|
||||
}
|
||||
// Remove any latencies associated with this peer
|
||||
if r.LocalMeshPeers.Has(peerID) {
|
||||
r.LocalMeshPeers.Delete(peerID)
|
||||
}
|
||||
|
||||
// TODO: If any rooms were routed through this peer, handle that case
|
||||
}
|
||||
|
||||
// updateMeshRoomStates merges received room states into the MeshRooms map
// TODO: Wrap in another type with timestamp or another mechanism to avoid conflicts
//
// States owned by this relay are skipped; everything else is upserted keyed by
// the room's ULID string. When a remote room appears for the first time and a
// local room exists under the same ID with participants (assumes room IDs are
// shared across the mesh — TODO confirm), a stream is requested from the
// owning peer right away.
func (r *Relay) updateMeshRoomStates(peerID peer.ID, states []shared.RoomInfo) {
	for _, state := range states {
		// Never let remote state overwrite rooms we own ourselves
		if state.OwnerID == r.ID {
			continue
		}

		// If previously did not exist, but does now, request a connection if participants exist for our room
		existed := r.MeshRooms.Has(state.ID.String())
		if !existed {
			// Request connection to this peer if we have participants in our local room
			if room, ok := r.LocalRooms.Get(state.ID); ok {
				if room.Participants.Len() > 0 {
					slog.Debug("Got new remote room state, we locally have participants for, requesting stream", "room_name", room.Name, "peer", peerID)
					if err := r.StreamProtocol.RequestStream(context.Background(), room, peerID); err != nil {
						slog.Error("Failed to request stream for new remote room state", "room_name", room.Name, "peer", peerID, "err", err)
					}
				}
			}
		}

		// Upsert the (possibly updated) remote room state
		r.MeshRooms.Set(state.ID.String(), state)
	}
}
|
||||
@@ -1,187 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
gen "relay/internal/proto"
|
||||
)
|
||||
|
||||
// ParticipantHandler wires up a participant's WebRTC PeerConnection, data
// channel, and WebSocket signaling callbacks, then starts the offer/answer
// exchange. Runs once per joining participant (launched as a goroutine from
// wsHandler).
func ParticipantHandler(participant *Participant, room *Room, relay *Relay) {
	// When the PeerConnection dies, drop the participant from the room.
	onPCClose := func() {
		slog.Debug("Participant PeerConnection closed", "participant", participant.ID, "room", room.Name)
		room.removeParticipantByID(participant.ID)
	}

	// NOTE(review): this single err is captured and reassigned by several of
	// the callbacks registered below — confirm callback dispatch is serialized,
	// otherwise these writes race; locally scoped errors would be safer.
	var err error
	participant.PeerConnection, err = common.CreatePeerConnection(onPCClose)
	if err != nil {
		slog.Error("Failed to create participant PeerConnection", "participant", participant.ID, "room", room.Name, "err", err)
		return
	}

	// Data channel settings
	// Ordered with zero retransmits: stale input is dropped rather than resent.
	settingOrdered := true
	settingMaxRetransmits := uint16(0)
	dc, err := participant.PeerConnection.CreateDataChannel("data", &webrtc.DataChannelInit{
		Ordered:        &settingOrdered,
		MaxRetransmits: &settingMaxRetransmits,
	})
	if err != nil {
		slog.Error("Failed to create data channel for participant", "participant", participant.ID, "room", room.Name, "err", err)
		return
	}
	participant.DataChannel = connections.NewNestriDataChannel(dc)

	// Register channel opening handling
	participant.DataChannel.RegisterOnOpen(func() {
		slog.Debug("DataChannel opened for participant", "participant", participant.ID, "room", room.Name)
	})

	// Register channel closing handling
	participant.DataChannel.RegisterOnClose(func() {
		slog.Debug("DataChannel closed for participant", "participant", participant.ID, "room", room.Name)
	})

	// Register text message handling: forward "input" messages toward the room's ingest.
	participant.DataChannel.RegisterMessageCallback("input", func(data []byte) {
		ForwardParticipantDataChannelMessage(participant, room, data)
	})

	// Trickle our ICE candidates to the participant as they are gathered.
	participant.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			return // nil signals end of gathering
		}
		if err := participant.WebSocket.SendICECandidateMessageWS(candidate.ToJSON()); err != nil {
			slog.Error("Failed to send ICE candidate to participant", "participant", participant.ID, "room", room.Name, "err", err)
		}
	})

	// Buffer for remote candidates that arrive before the remote description is set.
	iceHolder := make([]webrtc.ICECandidateInit, 0)

	// ICE callback
	participant.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
		var iceMsg connections.MessageICECandidate
		if err = json.Unmarshal(data, &iceMsg); err != nil {
			slog.Error("Failed to decode ICE candidate message from participant", "participant", participant.ID, "room", room.Name, "err", err)
			return
		}
		if participant.PeerConnection.RemoteDescription() != nil {
			if err = participant.PeerConnection.AddICECandidate(iceMsg.Candidate); err != nil {
				slog.Error("Failed to add ICE candidate for participant", "participant", participant.ID, "room", room.Name, "err", err)
			}
			// Add held ICE candidates
			for _, heldCandidate := range iceHolder {
				if err = participant.PeerConnection.AddICECandidate(heldCandidate); err != nil {
					slog.Error("Failed to add held ICE candidate for participant", "participant", participant.ID, "room", room.Name, "err", err)
				}
			}
			iceHolder = nil
		} else {
			// No remote description yet — hold candidate until the SDP answer lands.
			iceHolder = append(iceHolder, iceMsg.Candidate)
		}
	})

	// SDP answer callback
	participant.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
		var sdpMsg connections.MessageSDP
		if err = json.Unmarshal(data, &sdpMsg); err != nil {
			slog.Error("Failed to decode SDP message from participant", "participant", participant.ID, "room", room.Name, "err", err)
			return
		}
		handleParticipantSDP(participant, sdpMsg)
	})

	// Log callback
	participant.WebSocket.RegisterMessageCallback("log", func(data []byte) {
		var logMsg connections.MessageLog
		if err = json.Unmarshal(data, &logMsg); err != nil {
			slog.Error("Failed to decode log message from participant", "participant", participant.ID, "room", room.Name, "err", err)
			return
		}
		// TODO: Handle log message sending to metrics server
	})

	// Metrics callback
	participant.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
		// Ignore for now
	})

	// Clean up room membership when the signaling socket drops.
	participant.WebSocket.RegisterOnClose(func() {
		slog.Debug("WebSocket closed for participant", "participant", participant.ID, "room", room.Name)
		// Remove from Room
		room.removeParticipantByID(participant.ID)
	})

	slog.Info("Participant ready, sending OK answer", "participant", participant.ID, "room", room.Name)
	if err := participant.WebSocket.SendAnswerMessageWS(connections.AnswerOK); err != nil {
		slog.Error("Failed to send OK answer", "participant", participant.ID, "room", room.Name, "err", err)
	}

	// If room is online, also send offer
	if room.Online {
		if err = room.signalParticipantWithTracks(participant); err != nil {
			slog.Error("Failed to signal participant with tracks", "participant", participant.ID, "room", room.Name, "err", err)
		}
	} else {
		// Room not streaming locally — check whether another relay provides it
		// and, if so, pull the stream from that provider.
		active, provider := relay.IsRoomActive(room.ID)
		if active {
			slog.Debug("Room active remotely, requesting stream", "room", room.Name, "provider", provider)
			if _, err := relay.requestStream(context.Background(), room.Name, room.ID, provider); err != nil {
				slog.Error("Failed to request stream", "room", room.Name, "err", err)
			} else {
				slog.Debug("Stream requested successfully", "room", room.Name, "provider", provider)
			}
		}
	}
}
|
||||
|
||||
// SDP answer handler for participants
|
||||
func handleParticipantSDP(participant *Participant, answerMsg connections.MessageSDP) {
|
||||
// Get SDP offer
|
||||
sdpAnswer := answerMsg.SDP.SDP
|
||||
|
||||
// Set remote description
|
||||
err := participant.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeAnswer,
|
||||
SDP: sdpAnswer,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to set remote SDP answer for participant", "participant", participant.ID, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ForwardParticipantDataChannelMessage(participant *Participant, room *Room, data []byte) {
|
||||
// Debug mode: Add latency timestamp
|
||||
if common.GetFlags().Debug {
|
||||
var inputMsg gen.ProtoMessageInput
|
||||
if err := proto.Unmarshal(data, &inputMsg); err != nil {
|
||||
slog.Error("Failed to decode input message from participant", "participant", participant.ID, "room", room.Name, "err", err)
|
||||
return
|
||||
}
|
||||
protoLat := inputMsg.GetMessageBase().GetLatency()
|
||||
if protoLat != nil {
|
||||
lat := common.LatencyTrackerFromProto(protoLat)
|
||||
lat.AddTimestamp("relay_to_node")
|
||||
protoLat = lat.ToProto()
|
||||
}
|
||||
if newData, err := proto.Marshal(&inputMsg); err != nil {
|
||||
slog.Error("Failed to marshal input message from participant", "participant", participant.ID, "room", room.Name, "err", err)
|
||||
return
|
||||
} else {
|
||||
// Update data with the modified message
|
||||
data = newData
|
||||
}
|
||||
}
|
||||
|
||||
// Forward to local room DataChannel if it exists (e.g., local ingest)
|
||||
if room.DataChannel != nil {
|
||||
if err := room.DataChannel.SendBinary(data); err != nil {
|
||||
slog.Error("Failed to send input message to room", "participant", participant.ID, "room", room.Name, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,202 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strconv"

	"github.com/gorilla/websocket"
	"github.com/libp2p/go-reuseport"

	"relay/internal/common"
	"relay/internal/connections"
)
|
||||
|
||||
var httpMux *http.ServeMux
|
||||
|
||||
func InitHTTPEndpoint(_ context.Context, ctxCancel context.CancelFunc) error {
|
||||
// Create HTTP mux which serves our WS endpoint
|
||||
httpMux = http.NewServeMux()
|
||||
|
||||
// Endpoints themselves
|
||||
httpMux.Handle("/", http.NotFoundHandler())
|
||||
// If control endpoint secret is set, enable the control endpoint
|
||||
if len(common.GetFlags().ControlSecret) > 0 {
|
||||
httpMux.HandleFunc("/api/control", corsAnyHandler(controlHandler))
|
||||
}
|
||||
// WS endpoint
|
||||
httpMux.HandleFunc("/api/ws/{roomName}", corsAnyHandler(wsHandler))
|
||||
|
||||
// Get our serving port
|
||||
port := common.GetFlags().EndpointPort
|
||||
tlsCert := common.GetFlags().TLSCert
|
||||
tlsKey := common.GetFlags().TLSKey
|
||||
|
||||
// Create re-usable listener port
|
||||
httpListener, err := reuseport.Listen("tcp", ":"+strconv.Itoa(port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// Log and start the endpoint server
|
||||
if len(tlsCert) <= 0 && len(tlsKey) <= 0 {
|
||||
slog.Info("Starting HTTP endpoint server", "port", port)
|
||||
go func() {
|
||||
if err := http.Serve(httpListener, httpMux); err != nil {
|
||||
slog.Error("Failed to start HTTP server", "err", err)
|
||||
ctxCancel()
|
||||
}
|
||||
}()
|
||||
} else if len(tlsCert) > 0 && len(tlsKey) > 0 {
|
||||
slog.Info("Starting HTTPS endpoint server", "port", port)
|
||||
go func() {
|
||||
if err := http.ServeTLS(httpListener, httpMux, tlsCert, tlsKey); err != nil {
|
||||
slog.Error("Failed to start HTTPS server", "err", err)
|
||||
ctxCancel()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
return errors.New("no TLS certificate or TLS key provided")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// logHTTPError logs (if verbose) and sends an error code to requester
|
||||
func logHTTPError(w http.ResponseWriter, err string, code int) {
|
||||
if common.GetFlags().Verbose {
|
||||
slog.Error("HTTP error", "code", code, "message", err)
|
||||
}
|
||||
http.Error(w, err, code)
|
||||
}
|
||||
|
||||
// corsAnyHandler allows any origin to access the endpoint
|
||||
func corsAnyHandler(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(res http.ResponseWriter, req *http.Request) {
|
||||
// Allow all origins
|
||||
res.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
res.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
res.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
|
||||
if req.Method != http.MethodOptions {
|
||||
next(res, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wsHandler is the handler for the /api/ws/{roomName} endpoint
// It upgrades the connection to a WebSocket, then waits for a single "join"
// message that routes the connection to either the ingest (node) or
// participant (client) flow.
func wsHandler(w http.ResponseWriter, r *http.Request) {
	// Get given room name now
	roomName := r.PathValue("roomName")
	if len(roomName) <= 0 {
		logHTTPError(w, "no room name given", http.StatusBadRequest)
		return
	}

	rel := GetRelay()
	// Get or create room in any case
	room := rel.GetOrCreateRoom(roomName)

	// Upgrade to WebSocket
	upgrader := websocket.Upgrader{
		// Any origin is accepted deliberately (CORS is wide open elsewhere too).
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}
	wsConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logHTTPError(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Create SafeWebSocket
	ws := connections.NewSafeWebSocket(wsConn)
	// Assign message handler for join request
	// NOTE(review): the callback reuses the outer err variable; confirm SafeWebSocket
	// dispatches callbacks serially, otherwise this write races with the handler above.
	ws.RegisterMessageCallback("join", func(data []byte) {
		var joinMsg connections.MessageJoin
		if err = json.Unmarshal(data, &joinMsg); err != nil {
			slog.Error("Failed to unmarshal join message", "err", err)
			return
		}

		slog.Debug("Join message", "room", room.Name, "joinerType", joinMsg.JoinerType)

		// Handle join request, depending if it's from ingest/node or participant/client
		switch joinMsg.JoinerType {
		case connections.JoinerNode:
			// If room already online, send InUse answer
			if room.Online {
				if err = ws.SendAnswerMessageWS(connections.AnswerInUse); err != nil {
					slog.Error("Failed to send InUse answer to node", "room", room.Name, "err", err)
				}
				return
			}
			room.AssignWebSocket(ws)
			go IngestHandler(room)
		case connections.JoinerClient:
			// Create participant and add to room regardless of online status
			participant := NewParticipant(ws)
			room.AddParticipant(participant)
			// If room not online, send Offline answer
			if !room.Online {
				if err = ws.SendAnswerMessageWS(connections.AnswerOffline); err != nil {
					slog.Error("Failed to send offline answer to participant", "room", room.Name, "err", err)
				}
			}
			go ParticipantHandler(participant, room, rel)
		default:
			slog.Error("Unknown joiner type", "joinerType", joinMsg.JoinerType)
		}

		// Unregister ourselves, if something happens on the other side they should just reconnect?
		ws.UnregisterMessageCallback("join")
	})
}
|
||||
|
||||
// controlMessage is the JSON struct for the control messages
type controlMessage struct {
	Type string `json:"type"` // command name, e.g. "join_mesh"
	Value string `json:"value"` // command argument (a relay address for "join_mesh")
}
|
||||
|
||||
// controlHandler is the handler for the /api/control endpoint, for controlling this relay
|
||||
func controlHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for control secret in Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if len(authHeader) <= 0 || authHeader != common.GetFlags().ControlSecret {
|
||||
logHTTPError(w, "missing or invalid Authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle CORS preflight request
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode the control message
|
||||
var msg controlMessage
|
||||
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
||||
logHTTPError(w, "failed to decode control message", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//relay := GetRelay()
|
||||
switch msg.Type {
|
||||
case "join_mesh":
|
||||
// Join the mesh network, get relay address from msg.Value
|
||||
if len(msg.Value) <= 0 {
|
||||
logHTTPError(w, "missing relay address", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
if err := GetRelay().ConnectToRelay(ctx, msg.Value); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to connect: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write([]byte("Successfully connected to relay"))
|
||||
default:
|
||||
logHTTPError(w, "unknown control message type", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"io"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IngestHandler wires up the ingest (node) side of a room: it receives media
// tracks and a data channel over WebRTC, mirrors RTP into local tracks that
// participants subscribe to, and services WebSocket signaling
// (ICE/SDP/log/metrics). Runs once per ingest connection.
func IngestHandler(room *Room) {
	relay := GetRelay()

	// Callback for closing PeerConnection
	onPCClose := func() {
		slog.Debug("ingest PeerConnection closed", "room", room.Name)
		room.Online = false
		room.signalParticipantsOffline()
		relay.DeleteRoomIfEmpty(room)
	}

	var err error
	room.PeerConnection, err = common.CreatePeerConnection(onPCClose)
	if err != nil {
		slog.Error("Failed to create ingest PeerConnection", "room", room.Name, "err", err)
		return
	}

	// Each incoming remote track is mirrored into a local static-RTP track
	// that participant PeerConnections attach to.
	room.PeerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
		localTrack, err := webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, remoteTrack.Kind().String(), fmt.Sprintf("nestri-%s-%s", room.Name, remoteTrack.Kind().String()))
		if err != nil {
			slog.Error("Failed to create local track for room", "room", room.Name, "kind", remoteTrack.Kind(), "err", err)
			return
		}
		slog.Debug("Received track for room", "room", room.Name, "kind", remoteTrack.Kind())

		// Set track and let Room handle state
		room.SetTrack(remoteTrack.Kind(), localTrack)

		// Prepare PlayoutDelayExtension so we don't need to recreate it for each packet
		playoutExt := &rtp.PlayoutDelayExtension{
			MinDelay: 0,
			MaxDelay: 0,
		}
		playoutPayload, err := playoutExt.Marshal()
		if err != nil {
			slog.Error("Failed to marshal PlayoutDelayExtension for room", "room", room.Name, "err", err)
			return
		}

		// Pump RTP from the remote track into the local track until it ends.
		for {
			rtpPacket, _, err := remoteTrack.ReadRTP()
			if err != nil {
				if !errors.Is(err, io.EOF) {
					slog.Error("Failed to read RTP from remote track for room", "room", room.Name, "err", err)
				}
				break
			}

			// Use PlayoutDelayExtension for low latency, only for video tracks
			// NOTE(review): the extension is applied to every track here, not
			// only video as the comment claims — confirm intended behavior.
			if err := rtpPacket.SetExtension(common.ExtensionMap[common.ExtensionPlayoutDelay], playoutPayload); err != nil {
				slog.Error("Failed to set PlayoutDelayExtension for room", "room", room.Name, "err", err)
				continue
			}

			err = localTrack.WriteRTP(rtpPacket)
			if err != nil && !errors.Is(err, io.ErrClosedPipe) {
				slog.Error("Failed to write RTP to local track for room", "room", room.Name, "err", err)
				break
			}
		}

		slog.Debug("Track closed for room", "room", room.Name, "kind", remoteTrack.Kind())

		// Clear track when done
		room.SetTrack(remoteTrack.Kind(), nil)
	})

	room.PeerConnection.OnDataChannel(func(dc *webrtc.DataChannel) {
		room.DataChannel = connections.NewNestriDataChannel(dc)
		slog.Debug("Ingest received DataChannel for room", "room", room.Name)

		room.DataChannel.RegisterOnOpen(func() {
			slog.Debug("ingest DataChannel opened for room", "room", room.Name)
		})

		room.DataChannel.OnClose(func() {
			slog.Debug("ingest DataChannel closed for room", "room", room.Name)
		})

		// We do not handle any messages from ingest via DataChannel yet
	})

	// Trickle local ICE candidates to the ingest over the signaling WebSocket.
	room.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			return
		}
		slog.Debug("ingest received ICECandidate for room", "room", room.Name)
		err = room.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
		if err != nil {
			slog.Error("Failed to send ICE candidate message to ingest for room", "room", room.Name, "err", err)
		}
	})

	// Buffer for remote candidates arriving before the remote description is set.
	iceHolder := make([]webrtc.ICECandidateInit, 0)

	// ICE callback
	room.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
		var iceMsg connections.MessageICECandidate
		if err = json.Unmarshal(data, &iceMsg); err != nil {
			slog.Error("Failed to decode ICE candidate message from ingest for room", "room", room.Name, "err", err)
			return
		}
		if room.PeerConnection != nil {
			if room.PeerConnection.RemoteDescription() != nil {
				if err = room.PeerConnection.AddICECandidate(iceMsg.Candidate); err != nil {
					slog.Error("Failed to add ICE candidate for room", "room", room.Name, "err", err)
				}
				// Flush candidates held while the remote description was missing.
				for _, heldCandidate := range iceHolder {
					if err = room.PeerConnection.AddICECandidate(heldCandidate); err != nil {
						slog.Error("Failed to add held ICE candidate for room", "room", room.Name, "err", err)
					}
				}
				iceHolder = make([]webrtc.ICECandidateInit, 0)
			} else {
				iceHolder = append(iceHolder, iceMsg.Candidate)
			}
		} else {
			slog.Error("ICE candidate received but PeerConnection is nil for room", "room", room.Name)
		}
	})

	// SDP offer callback
	room.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
		var sdpMsg connections.MessageSDP
		if err = json.Unmarshal(data, &sdpMsg); err != nil {
			slog.Error("Failed to decode SDP message from ingest for room", "room", room.Name, "err", err)
			return
		}
		answer := handleIngestSDP(room, sdpMsg)
		if answer != nil {
			if err = room.WebSocket.SendSDPMessageWS(*answer); err != nil {
				slog.Error("Failed to send SDP answer message to ingest for room", "room", room.Name, "err", err)
			}
		} else {
			slog.Error("Failed to handle ingest SDP message for room", "room", room.Name)
		}
	})

	// Log callback
	room.WebSocket.RegisterMessageCallback("log", func(data []byte) {
		var logMsg connections.MessageLog
		if err = json.Unmarshal(data, &logMsg); err != nil {
			slog.Error("Failed to decode log message from ingest for room", "room", room.Name, "err", err)
			return
		}
		// TODO: Handle log message sending to metrics server
	})

	// Metrics callback
	room.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
		var metricsMsg connections.MessageMetrics
		if err = json.Unmarshal(data, &metricsMsg); err != nil {
			slog.Error("Failed to decode metrics message from ingest for room", "room", room.Name, "err", err)
			return
		}
		// TODO: Handle metrics message sending to metrics server
	})

	// Mark the room offline (and possibly delete it) when the ingest socket drops.
	room.WebSocket.RegisterOnClose(func() {
		slog.Debug("ingest WebSocket closed for room", "room", room.Name)
		room.Online = false
		room.signalParticipantsOffline()
		relay.DeleteRoomIfEmpty(room)
	})

	slog.Info("Room is ready, sending OK answer to ingest", "room", room.Name)
	if err = room.WebSocket.SendAnswerMessageWS(connections.AnswerOK); err != nil {
		slog.Error("Failed to send OK answer message to ingest for room", "room", room.Name, "err", err)
	}
}
|
||||
|
||||
// SDP offer handler, returns SDP answer
|
||||
func handleIngestSDP(room *Room, offerMsg connections.MessageSDP) *webrtc.SessionDescription {
|
||||
var err error
|
||||
|
||||
sdpOffer := offerMsg.SDP.SDP
|
||||
sdpOffer = strings.Replace(sdpOffer, ";sprop-maxcapturerate=24000", "", -1)
|
||||
|
||||
err = room.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
SDP: sdpOffer,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to set remote description for room", "room", room.Name, "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
answer, err := room.PeerConnection.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create SDP answer for room", "room", room.Name, "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = room.PeerConnection.SetLocalDescription(answer)
|
||||
if err != nil {
|
||||
slog.Error("Failed to set local description for room", "room", room.Name, "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &answer
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
)
|
||||
|
||||
// Participant is a single viewer/client connected to a room, holding its
// signaling socket and (once negotiated) its WebRTC connection and data channel.
type Participant struct {
	ID ulid.ULID //< Internal IDs are useful to keeping unique internal track and not have conflicts later
	Name string // random human-readable display name (see createRandomName)
	WebSocket *connections.SafeWebSocket // signaling channel to the client
	PeerConnection *webrtc.PeerConnection // media connection; set later by ParticipantHandler
	DataChannel *connections.NestriDataChannel // input channel; set later by ParticipantHandler
}
|
||||
|
||||
func NewParticipant(ws *connections.SafeWebSocket) *Participant {
|
||||
id, err := common.NewULID()
|
||||
if err != nil {
|
||||
slog.Error("Failed to create ULID for Participant", "err", err)
|
||||
return nil
|
||||
}
|
||||
return &Participant{
|
||||
ID: id,
|
||||
Name: createRandomName(),
|
||||
WebSocket: ws,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Participant) addTrack(trackLocal *webrtc.TrackLocalStaticRTP) error {
|
||||
rtpSender, err := p.PeerConnection.AddTrack(trackLocal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
rtcpBuffer := make([]byte, 1400)
|
||||
for {
|
||||
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Participant) signalOffer() error {
|
||||
if p.PeerConnection == nil {
|
||||
return fmt.Errorf("peer connection is nil for participant: '%s' - cannot signal offer", p.ID)
|
||||
}
|
||||
|
||||
offer, err := p.PeerConnection.CreateOffer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.PeerConnection.SetLocalDescription(offer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.WebSocket.SendSDPMessageWS(offer)
|
||||
}
|
||||
|
||||
var namesFirst = []string{"Happy", "Sad", "Angry", "Calm", "Excited", "Bored", "Confused", "Confident", "Curious", "Depressed", "Disappointed", "Embarrassed", "Energetic", "Fearful", "Frustrated", "Glad", "Guilty", "Hopeful", "Impatient", "Jealous", "Lonely", "Motivated", "Nervous", "Optimistic", "Pessimistic", "Proud", "Relaxed", "Shy", "Stressed", "Surprised", "Tired", "Worried"}
|
||||
var namesSecond = []string{"Dragon", "Unicorn", "Troll", "Goblin", "Elf", "Dwarf", "Ogre", "Gnome", "Mermaid", "Siren", "Vampire", "Ghoul", "Werewolf", "Minotaur", "Centaur", "Griffin", "Phoenix", "Wyvern", "Hydra", "Kraken"}
|
||||
|
||||
func createRandomName() string {
|
||||
randomFirst := namesFirst[rand.Intn(len(namesFirst))]
|
||||
randomSecond := namesSecond[rand.Intn(len(namesSecond))]
|
||||
return randomFirst + " " + randomSecond
|
||||
}
|
||||
@@ -1,702 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/libp2p/go-libp2p"
|
||||
"github.com/libp2p/go-libp2p-pubsub"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/pnet"
|
||||
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"io"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
)
|
||||
|
||||
var globalRelay *Relay
|
||||
|
||||
// networkNotifier logs connection events
type networkNotifier struct{}

// Connected logs every newly established libp2p connection.
func (n *networkNotifier) Connected(net network.Network, conn network.Conn) {
	slog.Info("Peer connected", "local", conn.LocalPeer(), "remote", conn.RemotePeer())
}

// Disconnected logs every closed libp2p connection.
func (n *networkNotifier) Disconnected(net network.Network, conn network.Conn) {
	slog.Info("Peer disconnected", "local", conn.LocalPeer(), "remote", conn.RemotePeer())
}

// Listen and ListenClose are required by the notifiee interface; intentionally no-ops.
func (n *networkNotifier) Listen(net network.Network, addr multiaddr.Multiaddr) {}
func (n *networkNotifier) ListenClose(net network.Network, addr multiaddr.Multiaddr) {}
|
||||
|
||||
// ICEMessage carries a WebRTC ICE candidate between relays over pubsub.
type ICEMessage struct {
	PeerID string // originating relay peer ID (presumably a libp2p peer.ID string — confirm at call sites)
	TargetID string // destination relay peer ID
	RoomID ulid.ULID // room this candidate belongs to
	Candidate []byte // serialized candidate (encoding not visible here — likely JSON; confirm)
}
|
||||
|
||||
// Relay is one node in the relay mesh: it owns the locally hosted rooms, the
// libp2p host used for peer-to-peer networking, and the pubsub topics that
// synchronize room state and ICE candidates across relays.
type Relay struct {
	ID peer.ID // this relay's libp2p identity
	Rooms *common.SafeMap[ulid.ULID, *Room] // rooms known locally, keyed by room ID
	Host host.Host // libp2p host for peer-to-peer networking
	PubSub *pubsub.PubSub // PubSub for state synchronization
	MeshState *common.SafeMap[ulid.ULID, RoomInfo] // room ID -> state
	RelayPCs *common.SafeMap[ulid.ULID, *webrtc.PeerConnection] // room ID -> relay PeerConnection
	pubTopicState *pubsub.Topic // topic for room states
	pubTopicICECandidate *pubsub.Topic // topic for ICE candidates aimed to this relay
}
|
||||
|
||||
// NewRelay constructs a relay node: a libp2p host (noise security, circuit
// relay + hole punching enabled, private-network PSK) listening on the given
// TCP port for IPv4 and IPv6, a GossipSub instance for state sync, and the
// in-memory room/state maps. It also starts state-sync and stream handlers.
func NewRelay(ctx context.Context, port int) (*Relay, error) {
	listenAddrs := []string{
		fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", port), // IPv4
		fmt.Sprintf("/ip6/::/tcp/%d", port), // IPv6
	}

	// Use "testToken" as the pre-shared token for authentication
	// TODO: Give via flags, before PR commit
	// NOTE(review): a hard-coded PSK means every deployment shares one private
	// network key — must be made configurable before production use.
	token := "testToken"
	// Generate 32-byte PSK from the token using SHA-256
	shaToken := sha256.Sum256([]byte(token))
	tokenPSK := pnet.PSK(shaToken[:])

	// Initialize libp2p host
	p2pHost, err := libp2p.New(
		libp2p.ListenAddrStrings(listenAddrs...),
		libp2p.Security(noise.ID, noise.New),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.PrivateNetwork(tokenPSK),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create libp2p host for relay: %w", err)
	}

	// Set up pubsub
	p2pPubsub, err := pubsub.NewGossipSub(ctx, p2pHost)
	if err != nil {
		return nil, fmt.Errorf("failed to create pubsub: %w", err)
	}

	// Add network notifier to log connections
	p2pHost.Network().Notify(&networkNotifier{})

	r := &Relay{
		ID: p2pHost.ID(),
		Host: p2pHost,
		PubSub: p2pPubsub,
		Rooms: common.NewSafeMap[ulid.ULID, *Room](),
		MeshState: common.NewSafeMap[ulid.ULID, RoomInfo](),
		RelayPCs: common.NewSafeMap[ulid.ULID, *webrtc.PeerConnection](),
	}

	// Set up state synchronization and stream handling
	r.setupStateSync(ctx)
	r.setupStreamHandler()

	slog.Info("Relay initialized", "id", r.ID, "addrs", p2pHost.Addrs())

	// Print dialable multiaddrs so operators can point other relays at us.
	peerInfo := peer.AddrInfo{
		ID: p2pHost.ID(),
		Addrs: p2pHost.Addrs(),
	}
	addrs, err := peer.AddrInfoToP2pAddrs(&peerInfo)
	if err != nil {
		return nil, fmt.Errorf("failed to convert peer info to addresses: %w", err)
	}

	slog.Debug("Connect with one of the following URLs below:")
	for _, addr := range addrs {
		slog.Debug(fmt.Sprintf("- %s", addr.String()))
	}
	return r, nil
}
|
||||
|
||||
func InitRelay(ctx context.Context, ctxCancel context.CancelFunc, port int) error {
|
||||
var err error
|
||||
globalRelay, err = NewRelay(ctx, port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create relay: %w", err)
|
||||
}
|
||||
|
||||
if err := common.InitWebRTCAPI(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := InitHTTPEndpoint(ctx, ctxCancel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("Relay initialized", "id", globalRelay.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRelay returns the process-wide Relay instance (nil until InitRelay succeeds).
func GetRelay() *Relay {
	return globalRelay
}
|
||||
|
||||
func (r *Relay) GetRoomByID(id ulid.ULID) *Room {
|
||||
if room, ok := r.Rooms.Get(id); ok {
|
||||
return room
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Relay) GetOrCreateRoom(name string) *Room {
|
||||
if room := r.GetRoomByName(name); room != nil {
|
||||
return room
|
||||
}
|
||||
|
||||
id, err := common.NewULID()
|
||||
if err != nil {
|
||||
slog.Error("Failed to generate new ULID for room", "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
room := NewRoom(name, id, r.ID)
|
||||
room.Relay = r
|
||||
r.Rooms.Set(room.ID, room)
|
||||
|
||||
slog.Debug("Created new room", "name", name, "id", room.ID)
|
||||
return room
|
||||
}
|
||||
|
||||
func (r *Relay) DeleteRoomIfEmpty(room *Room) {
|
||||
participantCount := room.Participants.Len()
|
||||
if participantCount > 0 {
|
||||
slog.Debug("Room not empty, not deleting", "name", room.Name, "id", room.ID, "participants", participantCount)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a "tombstone" state for the room, this allows propagation of the room deletion
|
||||
tombstoneState := RoomInfo{
|
||||
ID: room.ID,
|
||||
Name: room.Name,
|
||||
Online: false,
|
||||
OwnerID: room.OwnerID,
|
||||
}
|
||||
|
||||
// Publish updated state to mesh
|
||||
if err := r.publishRoomState(context.Background(), tombstoneState); err != nil {
|
||||
slog.Error("Failed to publish room states on change", "room", room.Name, "err", err)
|
||||
}
|
||||
|
||||
slog.Info("Deleting room since empty and offline", "name", room.Name, "id", room.ID)
|
||||
r.Rooms.Delete(room.ID)
|
||||
}
|
||||
|
||||
// setupStateSync joins the two mesh pubsub topics ("room-states" and
// "ice-candidates") and starts one consumer goroutine per topic. Both
// goroutines run until ctx is cancelled or the subscription errors, which
// terminates the loop. On any join/subscribe failure the function logs and
// returns early, leaving later topics unjoined.
func (r *Relay) setupStateSync(ctx context.Context) {
	var err error
	// Topic carrying []RoomInfo snapshots from every relay in the mesh.
	r.pubTopicState, err = r.PubSub.Join("room-states")
	if err != nil {
		slog.Error("Failed to join pubsub topic", "err", err)
		return
	}

	sub, err := r.pubTopicState.Subscribe()
	if err != nil {
		slog.Error("Failed to subscribe to topic", "err", err)
		return
	}

	// Topic carrying ICEMessage payloads addressed to individual relays.
	r.pubTopicICECandidate, err = r.PubSub.Join("ice-candidates")
	if err != nil {
		slog.Error("Failed to join ICE candidates topic", "err", err)
		return
	}

	iceCandidateSub, err := r.pubTopicICECandidate.Subscribe()
	if err != nil {
		slog.Error("Failed to subscribe to ICE candidates topic", "err", err)
		return
	}

	// Handle room-state updates from mesh peers.
	// NOTE(review): the original comment claimed "only from authenticated
	// peers", but no authentication check is visible here beyond skipping
	// our own messages — confirm whether auth is enforced elsewhere.
	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Includes ctx cancellation; the goroutine ends here.
				slog.Error("Error receiving pubsub message", "err", err)
				return
			}
			if msg.GetFrom() == r.Host.ID() {
				continue // Ignore own messages
			}
			var states []RoomInfo
			if err := json.Unmarshal(msg.Data, &states); err != nil {
				slog.Error("Failed to unmarshal room states", "err", err)
				continue
			}
			r.updateMeshState(states)
		}
	}()

	// Handle incoming ICE candidates for given room
	go func() {
		// Map of ICE candidate slices per room ID. Candidates that arrive
		// before the room's PeerConnection has a remote description are
		// buffered here and flushed once one is set.
		iceHolder := make(map[ulid.ULID][]webrtc.ICECandidateInit)

		for {
			msg, err := iceCandidateSub.Next(ctx)
			if err != nil {
				slog.Error("Error receiving ICE candidate message", "err", err)
				return
			}
			if msg.GetFrom() == r.Host.ID() {
				continue // Ignore own messages
			}

			var iceMsg ICEMessage
			if err := json.Unmarshal(msg.Data, &iceMsg); err != nil {
				slog.Error("Failed to unmarshal ICE candidate message", "err", err)
				continue
			}
			if iceMsg.TargetID != r.ID.String() {
				continue // Ignore messages not meant for this relay
			}

			// Ensure a buffer slot exists for this room.
			if iceHolder[iceMsg.RoomID] == nil {
				iceHolder[iceMsg.RoomID] = make([]webrtc.ICECandidateInit, 0)
			}

			if pc, ok := r.RelayPCs.Get(iceMsg.RoomID); ok {
				// Unmarshal ice candidate
				var candidate webrtc.ICECandidateInit
				if err := json.Unmarshal(iceMsg.Candidate, &candidate); err != nil {
					slog.Error("Failed to unmarshal ICE candidate", "err", err)
					continue
				}
				if pc.RemoteDescription() != nil {
					if err := pc.AddICECandidate(candidate); err != nil {
						slog.Error("Failed to add ICE candidate", "err", err)
					}
					// Add any held candidates
					for _, heldCandidate := range iceHolder[iceMsg.RoomID] {
						if err := pc.AddICECandidate(heldCandidate); err != nil {
							slog.Error("Failed to add held ICE candidate", "err", err)
						}
					}
					// Reset the buffer now that it has been drained.
					iceHolder[iceMsg.RoomID] = make([]webrtc.ICECandidateInit, 0)
				} else {
					// No remote description yet: hold the candidate for later.
					iceHolder[iceMsg.RoomID] = append(iceHolder[iceMsg.RoomID], candidate)
				}
			} else {
				slog.Error("PeerConnection for room not found when adding ICE candidate", "roomID", iceMsg.RoomID)
			}
		}
	}()
}
|
||||
|
||||
func (r *Relay) publishRoomState(ctx context.Context, state RoomInfo) error {
|
||||
data, err := json.Marshal([]RoomInfo{state})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.pubTopicState.Publish(ctx, data)
|
||||
}
|
||||
|
||||
func (r *Relay) publishRoomStates(ctx context.Context) error {
|
||||
var states []RoomInfo
|
||||
for _, room := range r.Rooms.Copy() {
|
||||
states = append(states, RoomInfo{
|
||||
ID: room.ID,
|
||||
Name: room.Name,
|
||||
Online: room.Online,
|
||||
OwnerID: r.ID,
|
||||
})
|
||||
}
|
||||
data, err := json.Marshal(states)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.pubTopicState.Publish(ctx, data)
|
||||
}
|
||||
|
||||
// updateMeshState merges room states received from other relays into
// MeshState and reacts to online/offline transitions: rooms that came
// online remotely get their stream requested in the background; rooms that
// went offline (or arrived as unknown tombstones) have their relay
// PeerConnection closed and removed.
func (r *Relay) updateMeshState(states []RoomInfo) {
	for _, state := range states {
		if state.OwnerID == r.ID {
			continue // Skip own state
		}
		existing, exists := r.MeshState.Get(state.ID)
		r.MeshState.Set(state.ID, state)
		slog.Debug("Updated mesh state", "room", state.Name, "online", state.Online, "owner", state.OwnerID)

		// React to state changes: only first sighting or an online-flag flip.
		if !exists || existing.Online != state.Online {
			room := r.GetRoomByName(state.Name)
			if state.Online {
				// Remote room is live and we have no live local mirror:
				// pull the stream from the owning relay asynchronously.
				// NOTE(review): the closure captures the loop variable
				// `state`; with a pre-1.22 go.mod this may observe a later
				// iteration's value — confirm the module's Go version or
				// copy state into a local first.
				if room == nil || !room.Online {
					slog.Info("Room became active remotely, requesting stream", "room", state.Name, "owner", state.OwnerID)
					go func() {
						if _, err := r.requestStream(context.Background(), state.Name, state.ID, state.OwnerID); err != nil {
							slog.Error("Failed to request stream", "room", state.Name, "err", err)
						} else {
							slog.Info("Successfully requested stream", "room", state.Name, "owner", state.OwnerID)
						}
					}()
				}
			} else if room != nil && room.Online {
				// Remote room went offline: close our relay PC, mark the
				// local mirror offline and notify its participants.
				slog.Info("Room became inactive remotely, stopping local stream", "room", state.Name)
				if pc, ok := r.RelayPCs.Get(state.ID); ok {
					_ = pc.Close()
					r.RelayPCs.Delete(state.ID)
				}
				room.Online = false
				room.signalParticipantsOffline()
			} else if room == nil && !exists {
				// Tombstone for a room we never tracked: just drop any stray
				// PeerConnection that might reference it.
				slog.Info("Received tombstone state for room", "name", state.Name, "id", state.ID)
				if pc, ok := r.RelayPCs.Get(state.ID); ok {
					_ = pc.Close()
					r.RelayPCs.Delete(state.ID)
				}
			}
		}
	}
}
|
||||
|
||||
func (r *Relay) IsRoomActive(roomID ulid.ULID) (bool, peer.ID) {
|
||||
if state, exists := r.MeshState.Get(roomID); exists && state.Online {
|
||||
return true, state.OwnerID
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (r *Relay) GetRoomByName(name string) *Room {
|
||||
for _, room := range r.Rooms.Copy() {
|
||||
if room.Name == name {
|
||||
return room
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeMessage(stream network.Stream, data []byte) error {
|
||||
length := uint32(len(data))
|
||||
if err := binary.Write(stream, binary.BigEndian, length); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := stream.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func readMessage(stream network.Stream) ([]byte, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(stream, binary.BigEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]byte, length)
|
||||
_, err := io.ReadFull(stream, data)
|
||||
return data, err
|
||||
}
|
||||
|
||||
// setupStreamHandler registers the "/nestri-relay/stream/1.0.0" libp2p
// protocol handler. Protocol: the remote relay sends the room name it
// wants (frame 1); this side replies with an SDP offer carrying the room's
// tracks and a "relay-data" DataChannel (frame 2); the remote answers with
// its SDP (frame 3). ICE candidates travel out-of-band via the
// "ice-candidates" pubsub topic.
func (r *Relay) setupStreamHandler() {
	r.Host.SetStreamHandler("/nestri-relay/stream/1.0.0", func(stream network.Stream) {
		// Close the negotiation stream when the handler returns; media then
		// flows over the WebRTC PeerConnection, not this stream.
		defer func(stream network.Stream) {
			err := stream.Close()
			if err != nil {
				slog.Error("Failed to close stream", "err", err)
			}
		}(stream)
		remotePeer := stream.Conn().RemotePeer()

		// Frame 1: requested room name.
		// NOTE(review): on io.EOF this proceeds with whatever was read —
		// roomNameData may be nil or truncated; confirm EOF is intended as
		// a success path here.
		roomNameData, err := readMessage(stream)
		if err != nil && err != io.EOF {
			slog.Error("Failed to read room name", "peer", remotePeer, "err", err)
			return
		}
		roomName := string(roomNameData)

		slog.Info("Stream request from peer", "peer", remotePeer, "room", roomName)

		// Only serve rooms this relay actually has live (Online) tracks for.
		room := r.GetRoomByName(roomName)
		if room == nil || !room.Online {
			slog.Error("Cannot provide stream for inactive room", "room", roomName)
			return
		}

		// PeerConnection towards the requesting relay; the callback removes
		// it from RelayPCs when the connection tears down.
		pc, err := common.CreatePeerConnection(func() {
			r.RelayPCs.Delete(room.ID)
		})
		if err != nil {
			slog.Error("Failed to create relay PeerConnection", "err", err)
			return
		}

		r.RelayPCs.Set(room.ID, pc)

		// Forward the room's current audio/video tracks to the remote relay.
		if room.AudioTrack != nil {
			_, err := pc.AddTrack(room.AudioTrack)
			if err != nil {
				slog.Error("Failed to add audio track", "err", err)
				return
			}
		}
		if room.VideoTrack != nil {
			_, err := pc.AddTrack(room.VideoTrack)
			if err != nil {
				slog.Error("Failed to add video track", "err", err)
				return
			}
		}

		// Ordered but unreliable (0 retransmits) channel for low-latency
		// input forwarding.
		settingOrdered := true
		settingMaxRetransmits := uint16(0)
		dc, err := pc.CreateDataChannel("relay-data", &webrtc.DataChannelInit{
			Ordered:        &settingOrdered,
			MaxRetransmits: &settingMaxRetransmits,
		})
		if err != nil {
			slog.Error("Failed to create relay DataChannel", "err", err)
			return
		}
		relayDC := connections.NewNestriDataChannel(dc)

		relayDC.RegisterOnOpen(func() {
			slog.Debug("Relay DataChannel opened", "room", roomName)
		})

		relayDC.RegisterOnClose(func() {
			slog.Debug("Relay DataChannel closed", "room", roomName)
		})

		// "input" messages from the downstream relay are forwarded into the
		// room's own DataChannel (towards the stream source).
		relayDC.RegisterMessageCallback("input", func(data []byte) {
			if room.DataChannel != nil {
				// Forward message to the room's data channel
				if err := room.DataChannel.SendBinary(data); err != nil {
					slog.Error("Failed to send DataChannel message", "room", roomName, "err", err)
				}
			}
		})

		// Frame 2: create and send our SDP offer.
		offer, err := pc.CreateOffer(nil)
		if err != nil {
			slog.Error("Failed to create offer", "err", err)
			return
		}
		if err := pc.SetLocalDescription(offer); err != nil {
			slog.Error("Failed to set local description", "err", err)
			return
		}
		offerData, err := json.Marshal(offer)
		if err != nil {
			slog.Error("Failed to marshal offer", "err", err)
			return
		}
		if err := writeMessage(stream, offerData); err != nil {
			slog.Error("Failed to send offer", "peer", remotePeer, "err", err)
			return
		}

		// Handle our generated ICE candidates: publish each to the mesh
		// topic, addressed to the requesting peer.
		// NOTE(review): this handler is registered after SetLocalDescription,
		// which starts ICE gathering — candidates emitted in that window may
		// be missed; confirm the pion version replays them or register the
		// handler before setting the local description.
		pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
			if candidate == nil {
				return // nil signals end of gathering
			}
			candidateData, err := json.Marshal(candidate.ToJSON())
			if err != nil {
				slog.Error("Failed to marshal ICE candidate", "err", err)
				return
			}
			iceMsg := ICEMessage{
				PeerID:    r.Host.ID().String(),
				TargetID:  remotePeer.String(),
				RoomID:    room.ID,
				Candidate: candidateData,
			}
			data, err := json.Marshal(iceMsg)
			if err != nil {
				slog.Error("Failed to marshal ICE message", "err", err)
				return
			}
			if err := r.pubTopicICECandidate.Publish(context.Background(), data); err != nil {
				slog.Error("Failed to publish ICE candidate message", "err", err)
			}
		})

		// Frame 3: the remote relay's SDP answer.
		answerData, err := readMessage(stream)
		if err != nil && err != io.EOF {
			slog.Error("Failed to read answer", "peer", remotePeer, "err", err)
			return
		}
		var answer webrtc.SessionDescription
		if err := json.Unmarshal(answerData, &answer); err != nil {
			slog.Error("Failed to unmarshal answer", "err", err)
			return
		}
		if err := pc.SetRemoteDescription(answer); err != nil {
			slog.Error("Failed to set remote description", "err", err)
			return
		}
	})
}
|
||||
|
||||
// requestStream dials the relay that owns roomName over the
// "/nestri-relay/stream/1.0.0" protocol and performs the mirror of
// setupStreamHandler: send the room name, receive the provider's SDP
// offer, answer it, and wire the incoming tracks/DataChannel into a local
// Room mirror. Returns the established PeerConnection; the negotiation
// stream itself is closed when this function returns.
func (r *Relay) requestStream(ctx context.Context, roomName string, roomID ulid.ULID, providerPeer peer.ID) (*webrtc.PeerConnection, error) {
	stream, err := r.Host.NewStream(ctx, providerPeer, "/nestri-relay/stream/1.0.0")
	if err != nil {
		return nil, fmt.Errorf("failed to create stream: %w", err)
	}
	defer func(stream network.Stream) {
		err := stream.Close()
		if err != nil {
			slog.Error("Failed to close stream", "err", err)
		}
	}(stream)

	// Frame 1: tell the provider which room we want.
	if err := writeMessage(stream, []byte(roomName)); err != nil {
		return nil, fmt.Errorf("failed to send room name: %w", err)
	}

	// Ensure a local Room mirror exists and agrees with the provider's ID.
	room := r.GetRoomByName(roomName)
	if room == nil {
		room = NewRoom(roomName, roomID, providerPeer)
		r.Rooms.Set(roomID, room)
	} else if room.ID != roomID {
		// Mismatch, prefer the one from the provider
		// TODO: When mesh is created, if there are mismatches, we should have relays negotiate common room IDs
		room.ID = roomID
		room.OwnerID = providerPeer
		r.Rooms.Set(roomID, room)
	}

	// PeerConnection towards the providing relay; the callback removes it
	// from RelayPCs when the connection tears down.
	pc, err := common.CreatePeerConnection(func() {
		r.RelayPCs.Delete(roomID)
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create PeerConnection: %w", err)
	}

	r.RelayPCs.Set(roomID, pc)

	// Frame 2: the provider's SDP offer.
	// NOTE(review): on io.EOF this continues with possibly-empty offerData;
	// confirm EOF is an expected success path here.
	offerData, err := readMessage(stream)
	if err != nil && err != io.EOF {
		return nil, fmt.Errorf("failed to read offer: %w", err)
	}
	var offer webrtc.SessionDescription
	if err := json.Unmarshal(offerData, &offer); err != nil {
		return nil, fmt.Errorf("failed to unmarshal offer: %w", err)
	}
	if err := pc.SetRemoteDescription(offer); err != nil {
		return nil, fmt.Errorf("failed to set remote description: %w", err)
	}

	// Re-publish each incoming remote track as a local static-RTP track and
	// pump RTP packets into it until the remote track ends.
	pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
		localTrack, _ := webrtc.NewTrackLocalStaticRTP(track.Codec().RTPCodecCapability, track.ID(), "relay-"+roomName+"-"+track.Kind().String())
		slog.Debug("Received track for mesh relay room", "room", roomName, "kind", track.Kind())

		room.SetTrack(track.Kind(), localTrack)

		go func() {
			for {
				rtpPacket, _, err := track.ReadRTP()
				if err != nil {
					// EOF is the normal end of the remote track.
					if !errors.Is(err, io.EOF) {
						slog.Error("Failed to read RTP packet from remote track for room", "room", roomName, "err", err)
					}
					break
				}

				err = localTrack.WriteRTP(rtpPacket)
				if err != nil && !errors.Is(err, io.ErrClosedPipe) {
					slog.Error("Failed to write RTP to local track for room", "room", room.Name, "err", err)
					break
				}
			}
		}()
	})

	// ICE candidate handling: publish each candidate to the mesh topic,
	// addressed to the providing relay.
	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			return // nil signals end of gathering
		}
		candidateData, err := json.Marshal(candidate.ToJSON())
		if err != nil {
			slog.Error("Failed to marshal ICE candidate", "err", err)
			return
		}
		iceMsg := ICEMessage{
			PeerID:    r.Host.ID().String(),
			TargetID:  providerPeer.String(),
			RoomID:    roomID,
			Candidate: candidateData,
		}
		data, err := json.Marshal(iceMsg)
		if err != nil {
			slog.Error("Failed to marshal ICE message", "err", err)
			return
		}
		if err := r.pubTopicICECandidate.Publish(ctx, data); err != nil {
			slog.Error("Failed to publish ICE candidate message", "err", err)
		}
	})

	pc.OnDataChannel(func(dc *webrtc.DataChannel) {
		relayDC := connections.NewNestriDataChannel(dc)
		slog.Debug("Received DataChannel from peer", "room", roomName)

		relayDC.RegisterOnOpen(func() {
			slog.Debug("Relay DataChannel opened", "room", roomName)
		})

		// NOTE(review): the provider side uses RegisterOnClose for the same
		// purpose; confirm calling OnClose directly here is intentional and
		// not a missed rename.
		relayDC.OnClose(func() {
			slog.Debug("Relay DataChannel closed", "room", roomName)
		})

		// Override room DataChannel with the mesh-relay one to forward messages
		room.DataChannel = relayDC
	})

	// Answer the offer and send it back as frame 3.
	answer, err := pc.CreateAnswer(nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create answer: %w", err)
	}
	if err := pc.SetLocalDescription(answer); err != nil {
		return nil, fmt.Errorf("failed to set local description: %w", err)
	}
	answerData, err := json.Marshal(answer)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal answer: %w", err)
	}
	if err := writeMessage(stream, answerData); err != nil {
		return nil, fmt.Errorf("failed to send answer: %w", err)
	}

	return pc, nil
}
|
||||
|
||||
// ConnectToRelay manually connects to another relay by its multiaddress
|
||||
func (r *Relay) ConnectToRelay(ctx context.Context, addr string) error {
|
||||
// Parse the multiaddress
|
||||
ma, err := multiaddr.NewMultiaddr(addr)
|
||||
if err != nil {
|
||||
slog.Error("Invalid multiaddress", "addr", addr, "err", err)
|
||||
return fmt.Errorf("invalid multiaddress: %w", err)
|
||||
}
|
||||
|
||||
// Extract peer ID from multiaddress
|
||||
peerInfo, err := peer.AddrInfoFromP2pAddr(ma)
|
||||
if err != nil {
|
||||
slog.Error("Failed to extract peer info", "addr", addr, "err", err)
|
||||
return fmt.Errorf("failed to extract peer info: %w", err)
|
||||
}
|
||||
|
||||
// Connect to the peer
|
||||
if err := r.Host.Connect(ctx, *peerInfo); err != nil {
|
||||
slog.Error("Failed to connect to peer", "peer", peerInfo.ID, "addr", addr, "err", err)
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Publish challenge on join
|
||||
//go r.sendAuthChallenge(ctx)
|
||||
|
||||
slog.Info("Successfully connected to peer", "peer", peerInfo.ID, "addr", addr)
|
||||
return nil
|
||||
}
|
||||
44
packages/relay/internal/shared/participant.go
Normal file
44
packages/relay/internal/shared/participant.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// Participant is a single client attached to a room. PeerConnection and
// DataChannel are nil until negotiation assigns them.
type Participant struct {
	ID             ulid.ULID
	PeerConnection *webrtc.PeerConnection
	DataChannel    *connections.NestriDataChannel
}
|
||||
|
||||
func NewParticipant() (*Participant, error) {
|
||||
id, err := common.NewULID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ULID for Participant: %w", err)
|
||||
}
|
||||
return &Participant{
|
||||
ID: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Participant) addTrack(trackLocal *webrtc.TrackLocalStaticRTP) error {
|
||||
rtpSender, err := p.PeerConnection.AddTrack(trackLocal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
rtcpBuffer := make([]byte, 1400)
|
||||
for {
|
||||
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,32 +1,28 @@
|
||||
package internal
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type RoomInfo struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Online bool `json:"online"`
|
||||
OwnerID peer.ID `json:"owner_id"`
|
||||
}
|
||||
|
||||
type Room struct {
|
||||
RoomInfo
|
||||
WebSocket *connections.SafeWebSocket
|
||||
PeerConnection *webrtc.PeerConnection
|
||||
AudioTrack *webrtc.TrackLocalStaticRTP
|
||||
VideoTrack *webrtc.TrackLocalStaticRTP
|
||||
DataChannel *connections.NestriDataChannel
|
||||
Participants *common.SafeMap[ulid.ULID, *Participant]
|
||||
Relay *Relay
|
||||
}
|
||||
|
||||
func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
|
||||
@@ -34,21 +30,12 @@ func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
|
||||
RoomInfo: RoomInfo{
|
||||
ID: roomID,
|
||||
Name: name,
|
||||
Online: false,
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
Participants: common.NewSafeMap[ulid.ULID, *Participant](),
|
||||
}
|
||||
}
|
||||
|
||||
// AssignWebSocket assigns a WebSocket connection to a Room
|
||||
func (r *Room) AssignWebSocket(ws *connections.SafeWebSocket) {
|
||||
if r.WebSocket != nil {
|
||||
slog.Warn("WebSocket already assigned to room", "room", r.Name)
|
||||
}
|
||||
r.WebSocket = ws
|
||||
}
|
||||
|
||||
// AddParticipant adds a Participant to a Room
|
||||
func (r *Room) AddParticipant(participant *Participant) {
|
||||
slog.Debug("Adding participant to room", "participant", participant.ID, "room", r.Name)
|
||||
@@ -62,21 +49,8 @@ func (r *Room) removeParticipantByID(pID ulid.ULID) {
|
||||
}
|
||||
}
|
||||
|
||||
// Removes a Participant from a Room by participant's name
|
||||
func (r *Room) removeParticipantByName(pName string) {
|
||||
for id, participant := range r.Participants.Copy() {
|
||||
if participant.Name == pName {
|
||||
if err := r.signalParticipantOffline(participant); err != nil {
|
||||
slog.Error("Failed to signal participant offline", "participant", participant.ID, "room", r.Name, "err", err)
|
||||
}
|
||||
r.Participants.Delete(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Removes all participants from a Room
|
||||
func (r *Room) removeAllParticipants() {
|
||||
/*func (r *Room) removeAllParticipants() {
|
||||
for id, participant := range r.Participants.Copy() {
|
||||
if err := r.signalParticipantOffline(participant); err != nil {
|
||||
slog.Error("Failed to signal participant offline", "participant", participant.ID, "room", r.Name, "err", err)
|
||||
@@ -84,24 +58,28 @@ func (r *Room) removeAllParticipants() {
|
||||
r.Participants.Delete(id)
|
||||
slog.Debug("Removed participant from room", "participant", id, "room", r.Name)
|
||||
}
|
||||
}*/
|
||||
|
||||
// IsOnline checks if the room is online (has both audio and video tracks)
|
||||
func (r *Room) IsOnline() bool {
|
||||
return r.AudioTrack != nil && r.VideoTrack != nil
|
||||
}
|
||||
|
||||
func (r *Room) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalStaticRTP) {
|
||||
//oldOnline := r.IsOnline()
|
||||
|
||||
switch trackType {
|
||||
case webrtc.RTPCodecTypeAudio:
|
||||
r.AudioTrack = track
|
||||
slog.Debug("Audio track set", "room", r.Name, "track", track != nil)
|
||||
case webrtc.RTPCodecTypeVideo:
|
||||
r.VideoTrack = track
|
||||
slog.Debug("Video track set", "room", r.Name, "track", track != nil)
|
||||
default:
|
||||
slog.Warn("Unknown track type", "room", r.Name, "trackType", trackType)
|
||||
}
|
||||
|
||||
newOnline := r.AudioTrack != nil && r.VideoTrack != nil
|
||||
if r.Online != newOnline {
|
||||
r.Online = newOnline
|
||||
if r.Online {
|
||||
/*newOnline := r.IsOnline()
|
||||
if oldOnline != newOnline {
|
||||
if newOnline {
|
||||
slog.Debug("Room online, participants will be signaled", "room", r.Name)
|
||||
r.signalParticipantsWithTracks()
|
||||
} else {
|
||||
@@ -109,15 +87,16 @@ func (r *Room) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalS
|
||||
r.signalParticipantsOffline()
|
||||
}
|
||||
|
||||
// Publish updated state to mesh
|
||||
// TODO: Publish updated state to mesh
|
||||
go func() {
|
||||
if err := r.Relay.publishRoomStates(context.Background()); err != nil {
|
||||
slog.Error("Failed to publish room states on change", "room", r.Name, "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
||||
/* TODO: libp2p'ify
|
||||
func (r *Room) signalParticipantsWithTracks() {
|
||||
for _, participant := range r.Participants.Copy() {
|
||||
if err := r.signalParticipantWithTracks(participant); err != nil {
|
||||
@@ -162,3 +141,4 @@ func (r *Room) signalParticipantOffline(participant *Participant) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
Reference in New Issue
Block a user