mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
✨ feat: Add streaming support (#125)
This adds: - [x] Keyboard and mouse handling on the frontend - [x] Video and audio streaming from the backend to the frontend - [x] Input server that works with Websockets Update - 17/11 - [ ] Master docker container to run this - [ ] Steam runtime - [ ] Entrypoint.sh --------- Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com> Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
100
packages/relay/internal/common.go
Normal file
100
packages/relay/internal/common.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log"
|
||||
)
|
||||
|
||||
var globalWebRTCAPI *webrtc.API
|
||||
var globalWebRTCConfig = webrtc.Configuration{
|
||||
ICETransportPolicy: webrtc.ICETransportPolicyAll,
|
||||
BundlePolicy: webrtc.BundlePolicyBalanced,
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
|
||||
}
|
||||
|
||||
func InitWebRTCAPI() error {
|
||||
var err error
|
||||
flags := GetFlags()
|
||||
|
||||
// Media engine
|
||||
mediaEngine := &webrtc.MediaEngine{}
|
||||
|
||||
// Default codecs cover most of our needs
|
||||
err = mediaEngine.RegisterDefaultCodecs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add H.265 for special cases
|
||||
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
|
||||
for _, codec := range []webrtc.RTPCodecParameters{
|
||||
{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH265, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback},
|
||||
PayloadType: 48,
|
||||
},
|
||||
{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeRTX, ClockRate: 90000, SDPFmtpLine: "apt=48"},
|
||||
PayloadType: 49,
|
||||
},
|
||||
} {
|
||||
if err := mediaEngine.RegisterCodec(codec, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Interceptor registry
|
||||
interceptorRegistry := &interceptor.Registry{}
|
||||
|
||||
// Use default set
|
||||
err = webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setting engine
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
|
||||
// New in v4, reduces CPU usage and latency when enabled
|
||||
settingEngine.EnableSCTPZeroChecksum(true)
|
||||
|
||||
// Set the UDP port range used by WebRTC
|
||||
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a new API object with our customized settings
|
||||
globalWebRTCAPI = webrtc.NewAPI(webrtc.WithMediaEngine(mediaEngine), webrtc.WithSettingEngine(settingEngine), webrtc.WithInterceptorRegistry(interceptorRegistry))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWebRTCAPI returns the global WebRTC API
|
||||
func GetWebRTCAPI() *webrtc.API {
|
||||
return globalWebRTCAPI
|
||||
}
|
||||
|
||||
// CreatePeerConnection sets up a new peer connection
|
||||
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
|
||||
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log connection state changes and handle failed/disconnected connections
|
||||
pc.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
|
||||
// Close PeerConnection in cases
|
||||
if connectionState == webrtc.PeerConnectionStateFailed ||
|
||||
connectionState == webrtc.PeerConnectionStateDisconnected ||
|
||||
connectionState == webrtc.PeerConnectionStateClosed {
|
||||
err := pc.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error closing PeerConnection: %s\n", err.Error())
|
||||
}
|
||||
onClose()
|
||||
}
|
||||
})
|
||||
|
||||
return pc, nil
|
||||
}
|
||||
72
packages/relay/internal/datachannel.go
Normal file
72
packages/relay/internal/datachannel.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log"
|
||||
)
|
||||
|
||||
// NestriDataChannel is a custom data channel with callbacks
|
||||
type NestriDataChannel struct {
|
||||
*webrtc.DataChannel
|
||||
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
|
||||
}
|
||||
|
||||
// NewNestriDataChannel creates a new NestriDataChannel from *webrtc.DataChannel
|
||||
func NewNestriDataChannel(dc *webrtc.DataChannel) *NestriDataChannel {
|
||||
ndc := &NestriDataChannel{
|
||||
DataChannel: dc,
|
||||
binaryCallbacks: make(map[string]OnMessageCallback),
|
||||
}
|
||||
|
||||
// Handler for incoming messages
|
||||
ndc.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||
// If string type message, ignore
|
||||
if msg.IsString {
|
||||
return
|
||||
}
|
||||
|
||||
// Decode message
|
||||
var base MessageBase
|
||||
if err := DecodeMessage(msg.Data, &base); err != nil {
|
||||
log.Printf("Failed to decode binary DataChannel message, reason: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle message type callback
|
||||
if callback, ok := ndc.binaryCallbacks[base.PayloadType]; ok {
|
||||
go callback(msg.Data)
|
||||
} // TODO: Log unknown message type?
|
||||
})
|
||||
|
||||
return ndc
|
||||
}
|
||||
|
||||
// SendBinary sends a binary message to the data channel
|
||||
func (ndc *NestriDataChannel) SendBinary(data []byte) error {
|
||||
return ndc.Send(data)
|
||||
}
|
||||
|
||||
// RegisterMessageCallback registers a callback for a given binary message type
|
||||
func (ndc *NestriDataChannel) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
|
||||
if ndc.binaryCallbacks == nil {
|
||||
ndc.binaryCallbacks = make(map[string]OnMessageCallback)
|
||||
}
|
||||
ndc.binaryCallbacks[msgType] = callback
|
||||
}
|
||||
|
||||
// UnregisterMessageCallback removes the callback for a given binary message type
|
||||
func (ndc *NestriDataChannel) UnregisterMessageCallback(msgType string) {
|
||||
if ndc.binaryCallbacks != nil {
|
||||
delete(ndc.binaryCallbacks, msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOnOpen registers a callback for the data channel opening
|
||||
func (ndc *NestriDataChannel) RegisterOnOpen(callback func()) {
|
||||
ndc.OnOpen(callback)
|
||||
}
|
||||
|
||||
// RegisterOnClose registers a callback for the data channel closing
|
||||
func (ndc *NestriDataChannel) RegisterOnClose(callback func()) {
|
||||
ndc.OnClose(callback)
|
||||
}
|
||||
189
packages/relay/internal/egress.go
Normal file
189
packages/relay/internal/egress.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log"
|
||||
)
|
||||
|
||||
func participantHandler(participant *Participant, room *Room) {
|
||||
// Callback for closing PeerConnection
|
||||
onPCClose := func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Closed PeerConnection for participant: '%s'\n", participant.ID)
|
||||
}
|
||||
room.removeParticipantByID(participant.ID)
|
||||
}
|
||||
|
||||
var err error
|
||||
participant.PeerConnection, err = CreatePeerConnection(onPCClose)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create PeerConnection for participant: '%s' - reason: %s\n", participant.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Data channel settings
|
||||
settingOrdered := false
|
||||
settingMaxRetransmits := uint16(0)
|
||||
dc, err := participant.PeerConnection.CreateDataChannel("data", &webrtc.DataChannelInit{
|
||||
Ordered: &settingOrdered,
|
||||
MaxRetransmits: &settingMaxRetransmits,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to create data channel for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
participant.DataChannel = NewNestriDataChannel(dc)
|
||||
|
||||
// Register channel opening handling
|
||||
participant.DataChannel.RegisterOnOpen(func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("DataChannel open for participant: %s\n", participant.ID)
|
||||
}
|
||||
})
|
||||
|
||||
// Register channel closing handling
|
||||
participant.DataChannel.RegisterOnClose(func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("DataChannel closed for participant: %s\n", participant.ID)
|
||||
}
|
||||
})
|
||||
|
||||
// Register text message handling
|
||||
participant.DataChannel.RegisterMessageCallback("input", func(data []byte) {
|
||||
// Send to room if it has a DataChannel
|
||||
if room.DataChannel != nil {
|
||||
// If debug mode, decode and add our timestamp, otherwise just send to room
|
||||
if GetFlags().Debug {
|
||||
var inputMsg MessageInput
|
||||
if err = DecodeMessage(data, &inputMsg); err != nil {
|
||||
log.Printf("Failed to decode input message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
inputMsg.LatencyTracker.AddTimestamp("relay_to_node")
|
||||
// Encode and send
|
||||
if data, err = EncodeMessage(inputMsg); err != nil {
|
||||
log.Printf("Failed to encode input message for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
if err = room.DataChannel.SendBinary(data); err != nil {
|
||||
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
} else {
|
||||
if err = room.DataChannel.SendBinary(data); err != nil {
|
||||
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
participant.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("ICE candidate for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
|
||||
}
|
||||
err = participant.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
|
||||
if err != nil {
|
||||
log.Printf("Failed to send ICE candidate for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
})
|
||||
|
||||
iceHolder := make([]webrtc.ICECandidateInit, 0)
|
||||
|
||||
// ICE callback
|
||||
participant.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
|
||||
var iceMsg MessageICECandidate
|
||||
if err = DecodeMessage(data, &iceMsg); err != nil {
|
||||
log.Printf("Failed to decode ICE message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
candidate := webrtc.ICECandidateInit{
|
||||
Candidate: iceMsg.Candidate.Candidate,
|
||||
}
|
||||
if participant.PeerConnection.RemoteDescription() != nil {
|
||||
if err = participant.PeerConnection.AddICECandidate(candidate); err != nil {
|
||||
log.Printf("Failed to add ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
// Add held ICE candidates
|
||||
for _, heldCandidate := range iceHolder {
|
||||
if err = participant.PeerConnection.AddICECandidate(heldCandidate); err != nil {
|
||||
log.Printf("Failed to add held ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
}
|
||||
iceHolder = nil
|
||||
} else {
|
||||
iceHolder = append(iceHolder, candidate)
|
||||
}
|
||||
})
|
||||
|
||||
// SDP answer callback
|
||||
participant.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
|
||||
var sdpMsg MessageSDP
|
||||
if err = DecodeMessage(data, &sdpMsg); err != nil {
|
||||
log.Printf("Failed to decode SDP message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
handleParticipantSDP(participant, sdpMsg)
|
||||
})
|
||||
|
||||
// Log callback
|
||||
participant.WebSocket.RegisterMessageCallback("log", func(data []byte) {
|
||||
var logMsg MessageLog
|
||||
if err = DecodeMessage(data, &logMsg); err != nil {
|
||||
log.Printf("Failed to decode log message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
return
|
||||
}
|
||||
// TODO: Handle log message sending to metrics server
|
||||
})
|
||||
|
||||
// Metrics callback
|
||||
participant.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
|
||||
// Ignore for now
|
||||
})
|
||||
|
||||
participant.WebSocket.RegisterOnClose(func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("WebSocket closed for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
|
||||
}
|
||||
// Remove from Room
|
||||
room.removeParticipantByID(participant.ID)
|
||||
})
|
||||
|
||||
log.Printf("Participant: '%s' in room: '%s' is now ready, sending an OK\n", participant.ID, room.Name)
|
||||
if err = participant.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
|
||||
log.Printf("Failed to send OK answer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
|
||||
// If room is already online, send also offer
|
||||
if room.Online {
|
||||
if room.AudioTrack != nil {
|
||||
if err = participant.addTrack(&room.AudioTrack); err != nil {
|
||||
log.Printf("Failed to add audio track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
}
|
||||
if room.VideoTrack != nil {
|
||||
if err = participant.addTrack(&room.VideoTrack); err != nil {
|
||||
log.Printf("Failed to add video track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
}
|
||||
if err = participant.signalOffer(); err != nil {
|
||||
log.Printf("Failed to signal offer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SDP answer handler for participants
|
||||
func handleParticipantSDP(participant *Participant, answerMsg MessageSDP) {
|
||||
// Get SDP offer
|
||||
sdpAnswer := answerMsg.SDP.SDP
|
||||
|
||||
// Set remote description
|
||||
err := participant.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeAnswer,
|
||||
SDP: sdpAnswer,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to set remote description for participant: '%s' - reason: %s\n", participant.ID, err)
|
||||
}
|
||||
}
|
||||
82
packages/relay/internal/flags.go
Normal file
82
packages/relay/internal/flags.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
var globalFlags *Flags
|
||||
|
||||
type Flags struct {
|
||||
Verbose bool
|
||||
Debug bool
|
||||
EndpointPort int
|
||||
WebRTCUDPStart int
|
||||
WebRTCUDPEnd int
|
||||
STUNServer string
|
||||
}
|
||||
|
||||
func (flags *Flags) DebugLog() {
|
||||
log.Println("Relay Flags:")
|
||||
log.Println("> Verbose: ", flags.Verbose)
|
||||
log.Println("> Debug: ", flags.Debug)
|
||||
log.Println("> Endpoint Port: ", flags.EndpointPort)
|
||||
log.Println("> WebRTC UDP Range Start: ", flags.WebRTCUDPStart)
|
||||
log.Println("> WebRTC UDP Range End: ", flags.WebRTCUDPEnd)
|
||||
log.Println("> WebRTC STUN Server: ", flags.STUNServer)
|
||||
}
|
||||
|
||||
func getEnvAsInt(name string, defaultVal int) int {
|
||||
valueStr := os.Getenv(name)
|
||||
if value, err := strconv.Atoi(valueStr); err != nil {
|
||||
return defaultVal
|
||||
} else {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvAsBool(name string, defaultVal bool) bool {
|
||||
valueStr := os.Getenv(name)
|
||||
val, err := strconv.ParseBool(valueStr)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func getEnvAsString(name string, defaultVal string) string {
|
||||
valueStr := os.Getenv(name)
|
||||
if len(valueStr) == 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return valueStr
|
||||
}
|
||||
|
||||
func InitFlags() {
|
||||
// Create Flags struct
|
||||
globalFlags = &Flags{}
|
||||
// Get flags
|
||||
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
|
||||
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
|
||||
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
|
||||
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
|
||||
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
|
||||
// Parse flags
|
||||
flag.Parse()
|
||||
|
||||
// ICE STUN servers
|
||||
globalWebRTCConfig.ICEServers = []webrtc.ICEServer{
|
||||
{
|
||||
URLs: []string{"stun:" + globalFlags.STUNServer},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetFlags() *Flags {
|
||||
return globalFlags
|
||||
}
|
||||
123
packages/relay/internal/http.go
Normal file
123
packages/relay/internal/http.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var httpMux *http.ServeMux
|
||||
|
||||
func InitHTTPEndpoint() {
|
||||
// Create HTTP mux which serves our WS endpoint
|
||||
httpMux = http.NewServeMux()
|
||||
|
||||
// Endpoints themselves
|
||||
httpMux.Handle("/", http.NotFoundHandler())
|
||||
httpMux.HandleFunc("/api/ws/{roomName}", corsAnyHandler(wsHandler))
|
||||
|
||||
// Get our serving port
|
||||
port := GetFlags().EndpointPort
|
||||
|
||||
// Log and start the endpoint server
|
||||
log.Println("Starting HTTP endpoint server on :", strconv.Itoa(port))
|
||||
go func() {
|
||||
log.Fatal((&http.Server{
|
||||
Handler: httpMux,
|
||||
Addr: ":" + strconv.Itoa(port),
|
||||
}).ListenAndServe())
|
||||
}()
|
||||
}
|
||||
|
||||
// logHTTPError logs (if verbose) and sends an error code to requester
|
||||
func logHTTPError(w http.ResponseWriter, err string, code int) {
|
||||
if GetFlags().Verbose {
|
||||
log.Println(err)
|
||||
}
|
||||
http.Error(w, err, code)
|
||||
}
|
||||
|
||||
// corsAnyHandler allows any origin to access the endpoint
|
||||
func corsAnyHandler(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
|
||||
return func(res http.ResponseWriter, req *http.Request) {
|
||||
// Allow all origins
|
||||
res.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
res.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
res.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
|
||||
if req.Method != http.MethodOptions {
|
||||
next(res, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wsHandler is the handler for the /api/ws/{roomName} endpoint
|
||||
func wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Get given room name now
|
||||
roomName := r.PathValue("roomName")
|
||||
if len(roomName) <= 0 {
|
||||
logHTTPError(w, "no room name given", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get or create room in any case
|
||||
room := GetOrCreateRoom(roomName)
|
||||
|
||||
// Upgrade to WebSocket
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
wsConn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logHTTPError(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Create SafeWebSocket
|
||||
ws := NewSafeWebSocket(wsConn)
|
||||
// Assign message handler for join request
|
||||
ws.RegisterMessageCallback("join", func(data []byte) {
|
||||
var joinMsg MessageJoin
|
||||
if err = DecodeMessage(data, &joinMsg); err != nil {
|
||||
log.Printf("Failed to decode join message: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Join request for room: '%s' from: '%s'\n", room.Name, joinMsg.JoinerType.String())
|
||||
}
|
||||
|
||||
// Handle join request, depending if it's from ingest/node or participant/client
|
||||
switch joinMsg.JoinerType {
|
||||
case JoinerNode:
|
||||
// If room already online, send InUse answer
|
||||
if room.Online {
|
||||
if err = ws.SendAnswerMessageWS(AnswerInUse); err != nil {
|
||||
log.Printf("Failed to send InUse answer for Room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
room.assignWebSocket(ws)
|
||||
go ingestHandler(room)
|
||||
case JoinerClient:
|
||||
// Create participant and add to room regardless of online status
|
||||
participant := NewParticipant(ws)
|
||||
room.addParticipant(participant)
|
||||
// If room not online, send Offline answer
|
||||
if !room.Online {
|
||||
if err = ws.SendAnswerMessageWS(AnswerOffline); err != nil {
|
||||
log.Printf("Failed to send Offline answer for Room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
}
|
||||
go participantHandler(participant, room)
|
||||
default:
|
||||
log.Printf("Unknown joiner type: %d\n", joinMsg.JoinerType)
|
||||
}
|
||||
|
||||
// Unregister ourselves, if something happens on the other side they should just reconnect?
|
||||
ws.UnregisterMessageCallback("join")
|
||||
})
|
||||
}
|
||||
251
packages/relay/internal/ingest.go
Normal file
251
packages/relay/internal/ingest.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ingestHandler(room *Room) {
|
||||
// Callback for closing PeerConnection
|
||||
onPCClose := func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Closed PeerConnection for room: '%s'\n", room.Name)
|
||||
}
|
||||
room.Online = false
|
||||
DeleteRoomIfEmpty(room)
|
||||
}
|
||||
|
||||
var err error
|
||||
room.PeerConnection, err = CreatePeerConnection(onPCClose)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create PeerConnection for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
room.PeerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
var localTrack *webrtc.TrackLocalStaticRTP
|
||||
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Received video track for room: '%s'\n", room.Name)
|
||||
}
|
||||
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "video", fmt.Sprint("nestri-", room.Name))
|
||||
if err != nil {
|
||||
log.Printf("Failed to create local video track for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
room.VideoTrack = localTrack
|
||||
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Received audio track for room: '%s'\n", room.Name)
|
||||
}
|
||||
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "audio", fmt.Sprint("nestri-", room.Name))
|
||||
if err != nil {
|
||||
log.Printf("Failed to create local audio track for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
room.AudioTrack = localTrack
|
||||
}
|
||||
|
||||
// If both audio and video tracks are set, set online state
|
||||
if room.AudioTrack != nil && room.VideoTrack != nil {
|
||||
room.Online = true
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Room online and receiving: '%s' - signaling participants\n", room.Name)
|
||||
}
|
||||
room.signalParticipantsWithTracks()
|
||||
}
|
||||
|
||||
rtpBuffer := make([]byte, 1400)
|
||||
for {
|
||||
read, _, err := remoteTrack.Read(rtpBuffer)
|
||||
if err != nil {
|
||||
// EOF is expected when stopping room
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Printf("RTP read error from room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
_, err = localTrack.Write(rtpBuffer[:read])
|
||||
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
log.Printf("Failed to write RTP to local track for room: '%s' - reason: %s\n", room.Name, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
|
||||
room.VideoTrack = nil
|
||||
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
|
||||
room.AudioTrack = nil
|
||||
}
|
||||
|
||||
if room.VideoTrack == nil && room.AudioTrack == nil {
|
||||
room.Online = false
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Room offline and not receiving: '%s'\n", room.Name)
|
||||
}
|
||||
// Signal participants of room offline
|
||||
room.signalParticipantsOffline()
|
||||
DeleteRoomIfEmpty(room)
|
||||
}
|
||||
})
|
||||
|
||||
room.PeerConnection.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
room.DataChannel = NewNestriDataChannel(dc)
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("New DataChannel for room: '%s' - '%s'\n", room.Name, room.DataChannel.Label())
|
||||
}
|
||||
|
||||
// Register channel opening handling
|
||||
room.DataChannel.RegisterOnOpen(func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("DataChannel for room: '%s' - '%s' open\n", room.Name, room.DataChannel.Label())
|
||||
}
|
||||
})
|
||||
|
||||
room.DataChannel.OnClose(func() {
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("DataChannel for room: '%s' - '%s' closed\n", room.Name, room.DataChannel.Label())
|
||||
}
|
||||
})
|
||||
|
||||
// We do not handle any messages from ingest via DataChannel yet
|
||||
})
|
||||
|
||||
room.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("ICE candidate for room: '%s'\n", room.Name)
|
||||
}
|
||||
err = room.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
|
||||
if err != nil {
|
||||
log.Printf("Failed to send ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
})
|
||||
|
||||
iceHolder := make([]webrtc.ICECandidateInit, 0)
|
||||
|
||||
// ICE callback
|
||||
room.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
|
||||
var iceMsg MessageICECandidate
|
||||
if err = DecodeMessage(data, &iceMsg); err != nil {
|
||||
log.Printf("Failed to decode ICE candidate message from ingest for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
candidate := webrtc.ICECandidateInit{
|
||||
Candidate: iceMsg.Candidate.Candidate,
|
||||
}
|
||||
if room.PeerConnection != nil {
|
||||
// If remote isn't set yet, store ICE candidates
|
||||
if room.PeerConnection.RemoteDescription() != nil {
|
||||
if err = room.PeerConnection.AddICECandidate(candidate); err != nil {
|
||||
log.Printf("Failed to add ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
// Add any held ICE candidates
|
||||
for _, heldCandidate := range iceHolder {
|
||||
if err = room.PeerConnection.AddICECandidate(heldCandidate); err != nil {
|
||||
log.Printf("Failed to add held ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
}
|
||||
iceHolder = nil
|
||||
} else {
|
||||
iceHolder = append(iceHolder, candidate)
|
||||
}
|
||||
} else {
|
||||
log.Printf("ICE candidate received before PeerConnection for room: '%s'\n", room.Name)
|
||||
}
|
||||
})
|
||||
|
||||
// SDP offer callback
|
||||
room.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
|
||||
var sdpMsg MessageSDP
|
||||
if err = DecodeMessage(data, &sdpMsg); err != nil {
|
||||
log.Printf("Failed to decode SDP message from ingest for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
answer := handleIngestSDP(room, sdpMsg)
|
||||
if answer != nil {
|
||||
if err = room.WebSocket.SendSDPMessageWS(*answer); err != nil {
|
||||
log.Printf("Failed to send SDP answer to ingest for room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to handle SDP message from ingest for room: '%s'\n", room.Name)
|
||||
}
|
||||
})
|
||||
|
||||
// Log callback
|
||||
room.WebSocket.RegisterMessageCallback("log", func(data []byte) {
|
||||
var logMsg MessageLog
|
||||
if err = DecodeMessage(data, &logMsg); err != nil {
|
||||
log.Printf("Failed to decode log message from ingest for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
// TODO: Handle log message sending to metrics server
|
||||
})
|
||||
|
||||
// Metrics callback
|
||||
room.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
|
||||
var metricsMsg MessageMetrics
|
||||
if err = DecodeMessage(data, &metricsMsg); err != nil {
|
||||
log.Printf("Failed to decode metrics message from ingest for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return
|
||||
}
|
||||
// TODO: Handle metrics message sending to metrics server
|
||||
})
|
||||
|
||||
room.WebSocket.RegisterOnClose(func() {
|
||||
// If PeerConnection is not open or does not exist, delete room
|
||||
if (room.PeerConnection != nil && room.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateConnected) ||
|
||||
room.PeerConnection == nil {
|
||||
DeleteRoomIfEmpty(room)
|
||||
}
|
||||
})
|
||||
|
||||
log.Printf("Room: '%s' is ready, sending an OK\n", room.Name)
|
||||
if err = room.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
|
||||
log.Printf("Failed to send OK answer for room: '%s' - reason: %s\n", room.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// SDP offer handler, returns SDP answer
|
||||
func handleIngestSDP(room *Room, offerMsg MessageSDP) *webrtc.SessionDescription {
|
||||
var err error
|
||||
|
||||
// Get SDP offer
|
||||
sdpOffer := offerMsg.SDP.SDP
|
||||
|
||||
// Modify SDP offer to remove opus "sprop-maxcapturerate=24000" (fixes opus bad quality issue, present in GStreamer)
|
||||
sdpOffer = strings.Replace(sdpOffer, ";sprop-maxcapturerate=24000", "", -1)
|
||||
|
||||
// Set new remote description
|
||||
err = room.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
SDP: sdpOffer,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to set remote description for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create SDP answer
|
||||
answer, err := room.PeerConnection.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create SDP answer for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set local description
|
||||
err = room.PeerConnection.SetLocalDescription(answer)
|
||||
if err != nil {
|
||||
log.Printf("Failed to set local description for room: '%s' - reason: %s\n", room.Name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &answer
|
||||
}
|
||||
114
packages/relay/internal/latency.go
Normal file
114
packages/relay/internal/latency.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TimestampEntry struct {
|
||||
Stage string `json:"stage"`
|
||||
Time string `json:"time"` // ISO 8601 string
|
||||
}
|
||||
|
||||
// LatencyTracker provides a generic structure for measuring time taken at various stages in message processing.
|
||||
// It can be embedded in message structs for tracking the flow of data and calculating round-trip latency.
|
||||
type LatencyTracker struct {
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Timestamps []TimestampEntry `json:"timestamps"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// NewLatencyTracker initializes a new LatencyTracker with the given sequence ID
|
||||
func NewLatencyTracker(sequenceID string) *LatencyTracker {
|
||||
return &LatencyTracker{
|
||||
SequenceID: sequenceID,
|
||||
Timestamps: make([]TimestampEntry, 0),
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTimestamp adds a new timestamp for a specific stage
|
||||
func (lt *LatencyTracker) AddTimestamp(stage string) {
|
||||
lt.Timestamps = append(lt.Timestamps, TimestampEntry{
|
||||
Stage: stage,
|
||||
// Ensure extremely precise UTC RFC3339 timestamps (down to nanoseconds)
|
||||
Time: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
})
|
||||
}
|
||||
|
||||
// TotalLatency calculates the total latency from the earliest to the latest timestamp
|
||||
func (lt *LatencyTracker) TotalLatency() (int64, error) {
|
||||
if len(lt.Timestamps) < 2 {
|
||||
return 0, nil // Not enough timestamps to calculate latency
|
||||
}
|
||||
|
||||
var earliest, latest time.Time
|
||||
for _, ts := range lt.Timestamps {
|
||||
t, err := time.Parse(time.RFC3339, ts.Time)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if earliest.IsZero() || t.Before(earliest) {
|
||||
earliest = t
|
||||
}
|
||||
if latest.IsZero() || t.After(latest) {
|
||||
latest = t
|
||||
}
|
||||
}
|
||||
|
||||
return latest.Sub(earliest).Milliseconds(), nil
|
||||
}
|
||||
|
||||
// PainPoints returns a list of stages where the duration exceeds the given threshold.
|
||||
func (lt *LatencyTracker) PainPoints(threshold time.Duration) []string {
|
||||
var painPoints []string
|
||||
var lastStage string
|
||||
var lastTime time.Time
|
||||
|
||||
for _, ts := range lt.Timestamps {
|
||||
stage := ts.Stage
|
||||
t := ts.Time
|
||||
if lastStage == "" {
|
||||
lastStage = stage
|
||||
lastTime, _ = time.Parse(time.RFC3339, t)
|
||||
continue
|
||||
}
|
||||
|
||||
currentTime, _ := time.Parse(time.RFC3339, t)
|
||||
if currentTime.Sub(lastTime) > threshold {
|
||||
painPoints = append(painPoints, fmt.Sprintf("%s -> %s", lastStage, stage))
|
||||
}
|
||||
|
||||
lastStage = stage
|
||||
lastTime = currentTime
|
||||
}
|
||||
return painPoints
|
||||
}
|
||||
|
||||
// StageLatency calculates the time taken between two specific stages.
|
||||
func (lt *LatencyTracker) StageLatency(startStage, endStage string) (time.Duration, error) {
|
||||
startTime, endTime := "", ""
|
||||
for _, ts := range lt.Timestamps {
|
||||
if ts.Stage == startStage {
|
||||
startTime = ts.Time
|
||||
}
|
||||
if ts.Stage == endStage {
|
||||
endTime = ts.Time
|
||||
}
|
||||
}
|
||||
|
||||
if startTime == "" || endTime == "" {
|
||||
return 0, fmt.Errorf("missing timestamps for stages: %s -> %s", startStage, endStage)
|
||||
}
|
||||
|
||||
start, err := time.Parse(time.RFC3339, startTime)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
end, err := time.Parse(time.RFC3339, endTime)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return end.Sub(start), nil
|
||||
}
|
||||
227
packages/relay/internal/messages.go
Normal file
227
packages/relay/internal/messages.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OnMessageCallback is a callback for binary messages of given type
|
||||
type OnMessageCallback func(data []byte)
|
||||
|
||||
// MessageBase is the base type for WS/DC messages.
|
||||
type MessageBase struct {
|
||||
PayloadType string `json:"payload_type"`
|
||||
LatencyTracker LatencyTracker `json:"latency_tracker,omitempty"`
|
||||
}
|
||||
|
||||
// MessageInput represents an input message.
|
||||
type MessageInput struct {
|
||||
MessageBase
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// MessageLog represents a log message.
|
||||
type MessageLog struct {
|
||||
MessageBase
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Time string `json:"time"`
|
||||
}
|
||||
|
||||
// MessageMetrics represents a metrics/heartbeat message.
|
||||
type MessageMetrics struct {
|
||||
MessageBase
|
||||
UsageCPU float64 `json:"usage_cpu"`
|
||||
UsageMemory float64 `json:"usage_memory"`
|
||||
Uptime uint64 `json:"uptime"`
|
||||
PipelineLatency float64 `json:"pipeline_latency"`
|
||||
}
|
||||
|
||||
// MessageICECandidate represents an ICE candidate message.
|
||||
type MessageICECandidate struct {
|
||||
MessageBase
|
||||
Candidate webrtc.ICECandidateInit `json:"candidate"`
|
||||
}
|
||||
|
||||
// MessageSDP represents an SDP message.
|
||||
type MessageSDP struct {
|
||||
MessageBase
|
||||
SDP webrtc.SessionDescription `json:"sdp"`
|
||||
}
|
||||
|
||||
// JoinerType is an enum for the type of incoming room joiner
|
||||
type JoinerType int
|
||||
|
||||
const (
|
||||
JoinerNode JoinerType = iota
|
||||
JoinerClient
|
||||
)
|
||||
|
||||
func (jt *JoinerType) String() string {
|
||||
switch *jt {
|
||||
case JoinerNode:
|
||||
return "node"
|
||||
case JoinerClient:
|
||||
return "client"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// MessageJoin is used to tell us that either participant or ingest wants to join the room
|
||||
type MessageJoin struct {
|
||||
MessageBase
|
||||
JoinerType JoinerType `json:"joiner_type"`
|
||||
}
|
||||
|
||||
// AnswerType is an enum for the type of answer, signaling Room state for a joiner
|
||||
type AnswerType int
|
||||
|
||||
const (
|
||||
AnswerOffline AnswerType = iota // For participant/client, when the room is offline without stream
|
||||
AnswerInUse // For ingest/node joiner, when the room is already in use by another ingest/node
|
||||
AnswerOK // For both, when the join request is handled successfully
|
||||
)
|
||||
|
||||
// MessageAnswer is used to send the answer to a join request
|
||||
type MessageAnswer struct {
|
||||
MessageBase
|
||||
AnswerType AnswerType `json:"answer_type"`
|
||||
}
|
||||
|
||||
// EncodeMessage encodes a message to be sent with gzip compression
|
||||
func EncodeMessage(msg interface{}) ([]byte, error) {
|
||||
// Marshal the message to JSON
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode message: %w", err)
|
||||
}
|
||||
|
||||
// Gzip compress the JSON
|
||||
var compressedData bytes.Buffer
|
||||
writer := gzip.NewWriter(&compressedData)
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compress message: %w", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to finalize compression: %w", err)
|
||||
}
|
||||
|
||||
return compressedData.Bytes(), nil
|
||||
}
|
||||
|
||||
// DecodeMessage decodes a message received with gzip decompression
|
||||
func DecodeMessage(data []byte, target interface{}) error {
|
||||
// Gzip decompress the data
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize decompression: %w", err)
|
||||
}
|
||||
defer func(reader *gzip.Reader) {
|
||||
if err = reader.Close(); err != nil {
|
||||
fmt.Printf("failed to close reader: %v\n", err)
|
||||
}
|
||||
}(reader)
|
||||
|
||||
// Decode the JSON
|
||||
err = json.NewDecoder(reader).Decode(target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendLogMessageWS sends a log message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendLogMessageWS(level, message string) error {
|
||||
msg := MessageLog{
|
||||
MessageBase: MessageBase{PayloadType: "log"},
|
||||
Level: level,
|
||||
Message: message,
|
||||
Time: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode log message: %w", err)
|
||||
}
|
||||
|
||||
return ws.SendBinary(encoded)
|
||||
}
|
||||
|
||||
// SendMetricsMessageWS sends a metrics message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendMetricsMessageWS(usageCPU, usageMemory float64, uptime uint64, pipelineLatency float64) error {
|
||||
msg := MessageMetrics{
|
||||
MessageBase: MessageBase{PayloadType: "metrics"},
|
||||
UsageCPU: usageCPU,
|
||||
UsageMemory: usageMemory,
|
||||
Uptime: uptime,
|
||||
PipelineLatency: pipelineLatency,
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode metrics message: %w", err)
|
||||
}
|
||||
|
||||
return ws.SendBinary(encoded)
|
||||
}
|
||||
|
||||
// SendICECandidateMessageWS sends an ICE candidate message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendICECandidateMessageWS(candidate webrtc.ICECandidateInit) error {
|
||||
msg := MessageICECandidate{
|
||||
MessageBase: MessageBase{PayloadType: "ice"},
|
||||
Candidate: candidate,
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode ICE candidate message: %w", err)
|
||||
}
|
||||
|
||||
return ws.SendBinary(encoded)
|
||||
}
|
||||
|
||||
// SendSDPMessageWS sends an SDP message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendSDPMessageWS(sdp webrtc.SessionDescription) error {
|
||||
msg := MessageSDP{
|
||||
MessageBase: MessageBase{PayloadType: "sdp"},
|
||||
SDP: sdp,
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode SDP message: %w", err)
|
||||
}
|
||||
|
||||
return ws.SendBinary(encoded)
|
||||
}
|
||||
|
||||
// SendAnswerMessageWS sends an answer message to the given WebSocket connection.
|
||||
func (ws *SafeWebSocket) SendAnswerMessageWS(answer AnswerType) error {
|
||||
msg := MessageAnswer{
|
||||
MessageBase: MessageBase{PayloadType: "answer"},
|
||||
AnswerType: answer,
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode answer message: %w", err)
|
||||
}
|
||||
|
||||
return ws.SendBinary(encoded)
|
||||
}
|
||||
|
||||
// SendInputMessageDC sends an input message to the given DataChannel connection.
|
||||
func (ndc *NestriDataChannel) SendInputMessageDC(data string) error {
|
||||
msg := MessageInput{
|
||||
MessageBase: MessageBase{PayloadType: "input"},
|
||||
Data: data,
|
||||
}
|
||||
encoded, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode input message: %w", err)
|
||||
}
|
||||
|
||||
return ndc.SendBinary(encoded)
|
||||
}
|
||||
69
packages/relay/internal/participant.go
Normal file
69
packages/relay/internal/participant.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
type Participant struct {
|
||||
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track and not have conflicts later
|
||||
Name string
|
||||
WebSocket *SafeWebSocket
|
||||
PeerConnection *webrtc.PeerConnection
|
||||
DataChannel *NestriDataChannel
|
||||
}
|
||||
|
||||
func NewParticipant(ws *SafeWebSocket) *Participant {
|
||||
return &Participant{
|
||||
ID: uuid.New(),
|
||||
Name: createRandomName(),
|
||||
WebSocket: ws,
|
||||
}
|
||||
}
|
||||
|
||||
func (vw *Participant) addTrack(trackLocal *webrtc.TrackLocal) error {
|
||||
rtpSender, err := vw.PeerConnection.AddTrack(*trackLocal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
rtcpBuffer := make([]byte, 1400)
|
||||
for {
|
||||
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vw *Participant) signalOffer() error {
|
||||
if vw.PeerConnection == nil {
|
||||
return fmt.Errorf("peer connection is nil for participant: '%s' - cannot signal offer", vw.ID)
|
||||
}
|
||||
|
||||
offer, err := vw.PeerConnection.CreateOffer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = vw.PeerConnection.SetLocalDescription(offer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return vw.WebSocket.SendSDPMessageWS(offer)
|
||||
}
|
||||
|
||||
var namesFirst = []string{"Happy", "Sad", "Angry", "Calm", "Excited", "Bored", "Confused", "Confident", "Curious", "Depressed", "Disappointed", "Embarrassed", "Energetic", "Fearful", "Frustrated", "Glad", "Guilty", "Hopeful", "Impatient", "Jealous", "Lonely", "Motivated", "Nervous", "Optimistic", "Pessimistic", "Proud", "Relaxed", "Shy", "Stressed", "Surprised", "Tired", "Worried"}
|
||||
var namesSecond = []string{"Dragon", "Unicorn", "Troll", "Goblin", "Elf", "Dwarf", "Ogre", "Gnome", "Mermaid", "Siren", "Vampire", "Ghoul", "Werewolf", "Minotaur", "Centaur", "Griffin", "Phoenix", "Wyvern", "Hydra", "Kraken"}
|
||||
|
||||
func createRandomName() string {
|
||||
randomFirst := namesFirst[rand.Intn(len(namesFirst))]
|
||||
randomSecond := namesSecond[rand.Intn(len(namesSecond))]
|
||||
return randomFirst + " " + randomSecond
|
||||
}
|
||||
179
packages/relay/internal/room.go
Normal file
179
packages/relay/internal/room.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var Rooms = make(map[uuid.UUID]*Room) //< Room ID -> Room
|
||||
var RoomsMutex = sync.RWMutex{}
|
||||
|
||||
func GetRoomByID(id uuid.UUID) *Room {
|
||||
RoomsMutex.RLock()
|
||||
defer RoomsMutex.RUnlock()
|
||||
if room, ok := Rooms[id]; ok {
|
||||
return room
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetRoomByName(name string) *Room {
|
||||
RoomsMutex.RLock()
|
||||
defer RoomsMutex.RUnlock()
|
||||
for _, room := range Rooms {
|
||||
if room.Name == name {
|
||||
return room
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetOrCreateRoom(name string) *Room {
|
||||
if room := GetRoomByName(name); room != nil {
|
||||
return room
|
||||
}
|
||||
RoomsMutex.Lock()
|
||||
room := NewRoom(name)
|
||||
Rooms[room.ID] = room
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("New room: '%s'\n", room.Name)
|
||||
}
|
||||
RoomsMutex.Unlock()
|
||||
return room
|
||||
}
|
||||
|
||||
func DeleteRoomIfEmpty(room *Room) {
|
||||
room.ParticipantsMutex.RLock()
|
||||
defer room.ParticipantsMutex.RUnlock()
|
||||
if !room.Online && len(room.Participants) <= 0 {
|
||||
RoomsMutex.Lock()
|
||||
delete(Rooms, room.ID)
|
||||
RoomsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
type Room struct {
|
||||
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track
|
||||
Name string
|
||||
Online bool //< Whether the room is currently online, i.e. receiving data from a nestri-server
|
||||
WebSocket *SafeWebSocket
|
||||
PeerConnection *webrtc.PeerConnection
|
||||
AudioTrack webrtc.TrackLocal
|
||||
VideoTrack webrtc.TrackLocal
|
||||
DataChannel *NestriDataChannel
|
||||
Participants map[uuid.UUID]*Participant
|
||||
ParticipantsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRoom(name string) *Room {
|
||||
return &Room{
|
||||
ID: uuid.New(),
|
||||
Name: name,
|
||||
Online: false,
|
||||
Participants: make(map[uuid.UUID]*Participant),
|
||||
}
|
||||
}
|
||||
|
||||
// Assigns a WebSocket connection to a Room
|
||||
func (r *Room) assignWebSocket(ws *SafeWebSocket) {
|
||||
// If WS already assigned, warn
|
||||
if r.WebSocket != nil {
|
||||
log.Printf("Warning: Room '%s' already has a WebSocket assigned\n", r.Name)
|
||||
}
|
||||
r.WebSocket = ws
|
||||
}
|
||||
|
||||
// Adds a Participant to a Room
|
||||
func (r *Room) addParticipant(participant *Participant) {
|
||||
r.ParticipantsMutex.Lock()
|
||||
r.Participants[participant.ID] = participant
|
||||
r.ParticipantsMutex.Unlock()
|
||||
}
|
||||
|
||||
// Removes a Participant from a Room by participant's ID.
|
||||
// If Room is offline and this is the last participant, the room is deleted
|
||||
func (r *Room) removeParticipantByID(pID uuid.UUID) {
|
||||
r.ParticipantsMutex.Lock()
|
||||
delete(r.Participants, pID)
|
||||
r.ParticipantsMutex.Unlock()
|
||||
DeleteRoomIfEmpty(r)
|
||||
}
|
||||
|
||||
// Removes a Participant from a Room by participant's name.
|
||||
// If Room is offline and this is the last participant, the room is deleted
|
||||
func (r *Room) removeParticipantByName(pName string) {
|
||||
r.ParticipantsMutex.Lock()
|
||||
for id, p := range r.Participants {
|
||||
if p.Name == pName {
|
||||
delete(r.Participants, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
r.ParticipantsMutex.Unlock()
|
||||
DeleteRoomIfEmpty(r)
|
||||
}
|
||||
|
||||
// Signals all participants with offer and add tracks to their PeerConnections
|
||||
func (r *Room) signalParticipantsWithTracks() {
|
||||
r.ParticipantsMutex.RLock()
|
||||
for _, participant := range r.Participants {
|
||||
// Add tracks to participant's PeerConnection
|
||||
if r.AudioTrack != nil {
|
||||
if err := participant.addTrack(&r.AudioTrack); err != nil {
|
||||
log.Printf("Failed to add audio track to participant: '%s' - reason: %s\n", participant.ID, err)
|
||||
}
|
||||
}
|
||||
if r.VideoTrack != nil {
|
||||
if err := participant.addTrack(&r.VideoTrack); err != nil {
|
||||
log.Printf("Failed to add video track to participant: '%s' - reason: %s\n", participant.ID, err)
|
||||
}
|
||||
}
|
||||
// Signal participant with offer
|
||||
if err := participant.signalOffer(); err != nil {
|
||||
log.Printf("Error signaling participant: %v\n", err)
|
||||
}
|
||||
}
|
||||
r.ParticipantsMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Signals all participants that the Room is offline
|
||||
func (r *Room) signalParticipantsOffline() {
|
||||
r.ParticipantsMutex.RLock()
|
||||
for _, participant := range r.Participants {
|
||||
if err := participant.WebSocket.SendAnswerMessageWS(AnswerOffline); err != nil {
|
||||
log.Printf("Failed to send Offline answer for participant: '%s' - reason: %s\n", participant.ID, err)
|
||||
}
|
||||
}
|
||||
r.ParticipantsMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Broadcasts a message to Room's Participant's - excluding one given ID of
|
||||
func (r *Room) broadcastMessage(msg webrtc.DataChannelMessage, excludeID uuid.UUID) {
|
||||
r.ParticipantsMutex.RLock()
|
||||
for d, participant := range r.Participants {
|
||||
if participant.DataChannel != nil {
|
||||
if d != excludeID { // Don't send back to the sender
|
||||
if err := participant.DataChannel.SendText(string(msg.Data)); err != nil {
|
||||
log.Printf("Error broadcasting to %s: %v\n", participant.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.DataChannel != nil {
|
||||
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
|
||||
log.Printf("Error broadcasting to Room: %v\n", err)
|
||||
}
|
||||
}
|
||||
r.ParticipantsMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Sends message to Room (nestri-server)
|
||||
func (r *Room) sendToRoom(msg webrtc.DataChannelMessage) {
|
||||
if r.DataChannel != nil {
|
||||
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
|
||||
log.Printf("Error broadcasting to Room: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
114
packages/relay/internal/websocket.go
Normal file
114
packages/relay/internal/websocket.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SafeWebSocket is a websocket with a mutex
|
||||
type SafeWebSocket struct {
|
||||
*websocket.Conn
|
||||
sync.Mutex
|
||||
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
|
||||
}
|
||||
|
||||
// NewSafeWebSocket creates a new SafeWebSocket from *websocket.Conn
|
||||
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
|
||||
ws := &SafeWebSocket{
|
||||
Conn: conn,
|
||||
binaryCallbacks: make(map[string]OnMessageCallback),
|
||||
}
|
||||
|
||||
// Launch a goroutine to handle binary messages
|
||||
go func() {
|
||||
for {
|
||||
// Read binary message
|
||||
kind, data, err := ws.Conn.ReadMessage()
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
|
||||
// If unexpected close error, break
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("Unexpected WebSocket close error, reason: %s\n", err)
|
||||
}
|
||||
break
|
||||
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
|
||||
// If closing, just break
|
||||
if GetFlags().Verbose {
|
||||
log.Printf("WebSocket closing\n")
|
||||
}
|
||||
break
|
||||
} else if err != nil {
|
||||
log.Printf("Failed to read WebSocket message, reason: %s\n", err)
|
||||
break
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case websocket.TextMessage:
|
||||
// Ignore, we use binary messages
|
||||
continue
|
||||
case websocket.BinaryMessage:
|
||||
// Decode message
|
||||
var msg MessageBase
|
||||
if err = DecodeMessage(data, &msg); err != nil {
|
||||
log.Printf("Failed to decode binary WebSocket message, reason: %s\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle message type callback
|
||||
if callback, ok := ws.binaryCallbacks[msg.PayloadType]; ok {
|
||||
callback(data)
|
||||
} // TODO: Log unknown message type?
|
||||
default:
|
||||
log.Printf("Unknown WebSocket message type: %d\n", kind)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ws
|
||||
}
|
||||
|
||||
// SendJSON writes JSON to a websocket with a mutex
|
||||
func (ws *SafeWebSocket) SendJSON(v interface{}) error {
|
||||
ws.Lock()
|
||||
defer ws.Unlock()
|
||||
return ws.Conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
// SendBinary writes binary to a websocket with a mutex
|
||||
func (ws *SafeWebSocket) SendBinary(data []byte) error {
|
||||
ws.Lock()
|
||||
defer ws.Unlock()
|
||||
return ws.Conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// RegisterMessageCallback sets the callback for binary message of given type
|
||||
func (ws *SafeWebSocket) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
|
||||
ws.Lock()
|
||||
defer ws.Unlock()
|
||||
if ws.binaryCallbacks == nil {
|
||||
ws.binaryCallbacks = make(map[string]OnMessageCallback)
|
||||
}
|
||||
ws.binaryCallbacks[msgType] = callback
|
||||
}
|
||||
|
||||
// UnregisterMessageCallback removes the callback for binary message of given type
|
||||
func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
|
||||
ws.Lock()
|
||||
defer ws.Unlock()
|
||||
if ws.binaryCallbacks != nil {
|
||||
delete(ws.binaryCallbacks, msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOnClose sets the callback for websocket closing
|
||||
func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
|
||||
ws.SetCloseHandler(func(code int, text string) error {
|
||||
// Clear our callbacks
|
||||
ws.Lock()
|
||||
ws.binaryCallbacks = nil
|
||||
ws.Unlock()
|
||||
// Call the callback
|
||||
callback()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user