feat: Add streaming support (#125)

This adds:
- [x] Keyboard and mouse handling on the frontend
- [x] Video and audio streaming from the backend to the frontend
- [x] Input server that works with Websockets

Update - 17/11
- [ ] Master docker container to run this
- [ ] Steam runtime
- [ ] Entrypoint.sh

---------

Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Wanjohi
2024-12-08 14:54:56 +03:00
committed by GitHub
parent 5eb21eeadb
commit 379db1c87b
137 changed files with 12737 additions and 5234 deletions

View File

@@ -0,0 +1,100 @@
package relay
import (
"github.com/pion/interceptor"
"github.com/pion/webrtc/v4"
"log"
)
var globalWebRTCAPI *webrtc.API
var globalWebRTCConfig = webrtc.Configuration{
ICETransportPolicy: webrtc.ICETransportPolicyAll,
BundlePolicy: webrtc.BundlePolicyBalanced,
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
}
func InitWebRTCAPI() error {
var err error
flags := GetFlags()
// Media engine
mediaEngine := &webrtc.MediaEngine{}
// Default codecs cover most of our needs
err = mediaEngine.RegisterDefaultCodecs()
if err != nil {
return err
}
// Add H.265 for special cases
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
for _, codec := range []webrtc.RTPCodecParameters{
{
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH265, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback},
PayloadType: 48,
},
{
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeRTX, ClockRate: 90000, SDPFmtpLine: "apt=48"},
PayloadType: 49,
},
} {
if err := mediaEngine.RegisterCodec(codec, webrtc.RTPCodecTypeVideo); err != nil {
return err
}
}
// Interceptor registry
interceptorRegistry := &interceptor.Registry{}
// Use default set
err = webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry)
if err != nil {
return err
}
// Setting engine
settingEngine := webrtc.SettingEngine{}
// New in v4, reduces CPU usage and latency when enabled
settingEngine.EnableSCTPZeroChecksum(true)
// Set the UDP port range used by WebRTC
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
if err != nil {
return err
}
// Create a new API object with our customized settings
globalWebRTCAPI = webrtc.NewAPI(webrtc.WithMediaEngine(mediaEngine), webrtc.WithSettingEngine(settingEngine), webrtc.WithInterceptorRegistry(interceptorRegistry))
return nil
}
// GetWebRTCAPI returns the global WebRTC API
func GetWebRTCAPI() *webrtc.API {
return globalWebRTCAPI
}
// CreatePeerConnection sets up a new peer connection
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
if err != nil {
return nil, err
}
// Log connection state changes and handle failed/disconnected connections
pc.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
// Close PeerConnection in cases
if connectionState == webrtc.PeerConnectionStateFailed ||
connectionState == webrtc.PeerConnectionStateDisconnected ||
connectionState == webrtc.PeerConnectionStateClosed {
err := pc.Close()
if err != nil {
log.Printf("Error closing PeerConnection: %s\n", err.Error())
}
onClose()
}
})
return pc, nil
}

View File

@@ -0,0 +1,72 @@
package relay
import (
"github.com/pion/webrtc/v4"
"log"
)
// NestriDataChannel is a custom data channel with callbacks
type NestriDataChannel struct {
*webrtc.DataChannel
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
}
// NewNestriDataChannel creates a new NestriDataChannel from *webrtc.DataChannel
func NewNestriDataChannel(dc *webrtc.DataChannel) *NestriDataChannel {
ndc := &NestriDataChannel{
DataChannel: dc,
binaryCallbacks: make(map[string]OnMessageCallback),
}
// Handler for incoming messages
ndc.OnMessage(func(msg webrtc.DataChannelMessage) {
// If string type message, ignore
if msg.IsString {
return
}
// Decode message
var base MessageBase
if err := DecodeMessage(msg.Data, &base); err != nil {
log.Printf("Failed to decode binary DataChannel message, reason: %s\n", err)
return
}
// Handle message type callback
if callback, ok := ndc.binaryCallbacks[base.PayloadType]; ok {
go callback(msg.Data)
} // TODO: Log unknown message type?
})
return ndc
}
// SendBinary sends a binary message to the data channel
func (ndc *NestriDataChannel) SendBinary(data []byte) error {
return ndc.Send(data)
}
// RegisterMessageCallback registers a callback for a given binary message type
func (ndc *NestriDataChannel) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
if ndc.binaryCallbacks == nil {
ndc.binaryCallbacks = make(map[string]OnMessageCallback)
}
ndc.binaryCallbacks[msgType] = callback
}
// UnregisterMessageCallback removes the callback for a given binary message type
func (ndc *NestriDataChannel) UnregisterMessageCallback(msgType string) {
if ndc.binaryCallbacks != nil {
delete(ndc.binaryCallbacks, msgType)
}
}
// RegisterOnOpen registers a callback for the data channel opening
func (ndc *NestriDataChannel) RegisterOnOpen(callback func()) {
ndc.OnOpen(callback)
}
// RegisterOnClose registers a callback for the data channel closing
func (ndc *NestriDataChannel) RegisterOnClose(callback func()) {
ndc.OnClose(callback)
}

View File

@@ -0,0 +1,189 @@
package relay
import (
"github.com/pion/webrtc/v4"
"log"
)
func participantHandler(participant *Participant, room *Room) {
// Callback for closing PeerConnection
onPCClose := func() {
if GetFlags().Verbose {
log.Printf("Closed PeerConnection for participant: '%s'\n", participant.ID)
}
room.removeParticipantByID(participant.ID)
}
var err error
participant.PeerConnection, err = CreatePeerConnection(onPCClose)
if err != nil {
log.Printf("Failed to create PeerConnection for participant: '%s' - reason: %s\n", participant.ID, err)
return
}
// Data channel settings
settingOrdered := false
settingMaxRetransmits := uint16(0)
dc, err := participant.PeerConnection.CreateDataChannel("data", &webrtc.DataChannelInit{
Ordered: &settingOrdered,
MaxRetransmits: &settingMaxRetransmits,
})
if err != nil {
log.Printf("Failed to create data channel for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
participant.DataChannel = NewNestriDataChannel(dc)
// Register channel opening handling
participant.DataChannel.RegisterOnOpen(func() {
if GetFlags().Verbose {
log.Printf("DataChannel open for participant: %s\n", participant.ID)
}
})
// Register channel closing handling
participant.DataChannel.RegisterOnClose(func() {
if GetFlags().Verbose {
log.Printf("DataChannel closed for participant: %s\n", participant.ID)
}
})
// Register text message handling
participant.DataChannel.RegisterMessageCallback("input", func(data []byte) {
// Send to room if it has a DataChannel
if room.DataChannel != nil {
// If debug mode, decode and add our timestamp, otherwise just send to room
if GetFlags().Debug {
var inputMsg MessageInput
if err = DecodeMessage(data, &inputMsg); err != nil {
log.Printf("Failed to decode input message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
inputMsg.LatencyTracker.AddTimestamp("relay_to_node")
// Encode and send
if data, err = EncodeMessage(inputMsg); err != nil {
log.Printf("Failed to encode input message for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
if err = room.DataChannel.SendBinary(data); err != nil {
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
}
} else {
if err = room.DataChannel.SendBinary(data); err != nil {
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
}
}
}
})
participant.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
return
}
if GetFlags().Verbose {
log.Printf("ICE candidate for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
}
err = participant.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
if err != nil {
log.Printf("Failed to send ICE candidate for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
})
iceHolder := make([]webrtc.ICECandidateInit, 0)
// ICE callback
participant.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
var iceMsg MessageICECandidate
if err = DecodeMessage(data, &iceMsg); err != nil {
log.Printf("Failed to decode ICE message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
candidate := webrtc.ICECandidateInit{
Candidate: iceMsg.Candidate.Candidate,
}
if participant.PeerConnection.RemoteDescription() != nil {
if err = participant.PeerConnection.AddICECandidate(candidate); err != nil {
log.Printf("Failed to add ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
// Add held ICE candidates
for _, heldCandidate := range iceHolder {
if err = participant.PeerConnection.AddICECandidate(heldCandidate); err != nil {
log.Printf("Failed to add held ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
iceHolder = nil
} else {
iceHolder = append(iceHolder, candidate)
}
})
// SDP answer callback
participant.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
var sdpMsg MessageSDP
if err = DecodeMessage(data, &sdpMsg); err != nil {
log.Printf("Failed to decode SDP message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
handleParticipantSDP(participant, sdpMsg)
})
// Log callback
participant.WebSocket.RegisterMessageCallback("log", func(data []byte) {
var logMsg MessageLog
if err = DecodeMessage(data, &logMsg); err != nil {
log.Printf("Failed to decode log message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
// TODO: Handle log message sending to metrics server
})
// Metrics callback
participant.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
// Ignore for now
})
participant.WebSocket.RegisterOnClose(func() {
if GetFlags().Verbose {
log.Printf("WebSocket closed for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
}
// Remove from Room
room.removeParticipantByID(participant.ID)
})
log.Printf("Participant: '%s' in room: '%s' is now ready, sending an OK\n", participant.ID, room.Name)
if err = participant.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
log.Printf("Failed to send OK answer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
// If room is already online, send also offer
if room.Online {
if room.AudioTrack != nil {
if err = participant.addTrack(&room.AudioTrack); err != nil {
log.Printf("Failed to add audio track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
if room.VideoTrack != nil {
if err = participant.addTrack(&room.VideoTrack); err != nil {
log.Printf("Failed to add video track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
if err = participant.signalOffer(); err != nil {
log.Printf("Failed to signal offer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
}
// SDP answer handler for participants
func handleParticipantSDP(participant *Participant, answerMsg MessageSDP) {
// Get SDP offer
sdpAnswer := answerMsg.SDP.SDP
// Set remote description
err := participant.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer,
SDP: sdpAnswer,
})
if err != nil {
log.Printf("Failed to set remote description for participant: '%s' - reason: %s\n", participant.ID, err)
}
}

View File

@@ -0,0 +1,82 @@
package relay
import (
"flag"
"log"
"os"
"strconv"
"github.com/pion/webrtc/v4"
)
var globalFlags *Flags
type Flags struct {
Verbose bool
Debug bool
EndpointPort int
WebRTCUDPStart int
WebRTCUDPEnd int
STUNServer string
}
func (flags *Flags) DebugLog() {
log.Println("Relay Flags:")
log.Println("> Verbose: ", flags.Verbose)
log.Println("> Debug: ", flags.Debug)
log.Println("> Endpoint Port: ", flags.EndpointPort)
log.Println("> WebRTC UDP Range Start: ", flags.WebRTCUDPStart)
log.Println("> WebRTC UDP Range End: ", flags.WebRTCUDPEnd)
log.Println("> WebRTC STUN Server: ", flags.STUNServer)
}
func getEnvAsInt(name string, defaultVal int) int {
valueStr := os.Getenv(name)
if value, err := strconv.Atoi(valueStr); err != nil {
return defaultVal
} else {
return value
}
}
func getEnvAsBool(name string, defaultVal bool) bool {
valueStr := os.Getenv(name)
val, err := strconv.ParseBool(valueStr)
if err != nil {
return defaultVal
}
return val
}
func getEnvAsString(name string, defaultVal string) string {
valueStr := os.Getenv(name)
if len(valueStr) == 0 {
return defaultVal
}
return valueStr
}
func InitFlags() {
// Create Flags struct
globalFlags = &Flags{}
// Get flags
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
// Parse flags
flag.Parse()
// ICE STUN servers
globalWebRTCConfig.ICEServers = []webrtc.ICEServer{
{
URLs: []string{"stun:" + globalFlags.STUNServer},
},
}
}
func GetFlags() *Flags {
return globalFlags
}

View File

@@ -0,0 +1,123 @@
package relay
import (
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
)
var httpMux *http.ServeMux
func InitHTTPEndpoint() {
// Create HTTP mux which serves our WS endpoint
httpMux = http.NewServeMux()
// Endpoints themselves
httpMux.Handle("/", http.NotFoundHandler())
httpMux.HandleFunc("/api/ws/{roomName}", corsAnyHandler(wsHandler))
// Get our serving port
port := GetFlags().EndpointPort
// Log and start the endpoint server
log.Println("Starting HTTP endpoint server on :", strconv.Itoa(port))
go func() {
log.Fatal((&http.Server{
Handler: httpMux,
Addr: ":" + strconv.Itoa(port),
}).ListenAndServe())
}()
}
// logHTTPError logs (if verbose) and sends an error code to requester
func logHTTPError(w http.ResponseWriter, err string, code int) {
if GetFlags().Verbose {
log.Println(err)
}
http.Error(w, err, code)
}
// corsAnyHandler allows any origin to access the endpoint
func corsAnyHandler(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(res http.ResponseWriter, req *http.Request) {
// Allow all origins
res.Header().Set("Access-Control-Allow-Origin", "*")
res.Header().Set("Access-Control-Allow-Methods", "*")
res.Header().Set("Access-Control-Allow-Headers", "*")
if req.Method != http.MethodOptions {
next(res, req)
}
}
}
// wsHandler is the handler for the /api/ws/{roomName} endpoint
func wsHandler(w http.ResponseWriter, r *http.Request) {
// Get given room name now
roomName := r.PathValue("roomName")
if len(roomName) <= 0 {
logHTTPError(w, "no room name given", http.StatusBadRequest)
return
}
// Get or create room in any case
room := GetOrCreateRoom(roomName)
// Upgrade to WebSocket
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
wsConn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logHTTPError(w, err.Error(), http.StatusInternalServerError)
return
}
// Create SafeWebSocket
ws := NewSafeWebSocket(wsConn)
// Assign message handler for join request
ws.RegisterMessageCallback("join", func(data []byte) {
var joinMsg MessageJoin
if err = DecodeMessage(data, &joinMsg); err != nil {
log.Printf("Failed to decode join message: %s\n", err)
return
}
if GetFlags().Verbose {
log.Printf("Join request for room: '%s' from: '%s'\n", room.Name, joinMsg.JoinerType.String())
}
// Handle join request, depending if it's from ingest/node or participant/client
switch joinMsg.JoinerType {
case JoinerNode:
// If room already online, send InUse answer
if room.Online {
if err = ws.SendAnswerMessageWS(AnswerInUse); err != nil {
log.Printf("Failed to send InUse answer for Room: '%s' - reason: %s\n", room.Name, err)
}
return
}
room.assignWebSocket(ws)
go ingestHandler(room)
case JoinerClient:
// Create participant and add to room regardless of online status
participant := NewParticipant(ws)
room.addParticipant(participant)
// If room not online, send Offline answer
if !room.Online {
if err = ws.SendAnswerMessageWS(AnswerOffline); err != nil {
log.Printf("Failed to send Offline answer for Room: '%s' - reason: %s\n", room.Name, err)
}
}
go participantHandler(participant, room)
default:
log.Printf("Unknown joiner type: %d\n", joinMsg.JoinerType)
}
// Unregister ourselves, if something happens on the other side they should just reconnect?
ws.UnregisterMessageCallback("join")
})
}

View File

@@ -0,0 +1,251 @@
package relay
import (
"errors"
"fmt"
"github.com/pion/webrtc/v4"
"io"
"log"
"strings"
)
func ingestHandler(room *Room) {
// Callback for closing PeerConnection
onPCClose := func() {
if GetFlags().Verbose {
log.Printf("Closed PeerConnection for room: '%s'\n", room.Name)
}
room.Online = false
DeleteRoomIfEmpty(room)
}
var err error
room.PeerConnection, err = CreatePeerConnection(onPCClose)
if err != nil {
log.Printf("Failed to create PeerConnection for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.PeerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
var localTrack *webrtc.TrackLocalStaticRTP
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
if GetFlags().Verbose {
log.Printf("Received video track for room: '%s'\n", room.Name)
}
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "video", fmt.Sprint("nestri-", room.Name))
if err != nil {
log.Printf("Failed to create local video track for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.VideoTrack = localTrack
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
if GetFlags().Verbose {
log.Printf("Received audio track for room: '%s'\n", room.Name)
}
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "audio", fmt.Sprint("nestri-", room.Name))
if err != nil {
log.Printf("Failed to create local audio track for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.AudioTrack = localTrack
}
// If both audio and video tracks are set, set online state
if room.AudioTrack != nil && room.VideoTrack != nil {
room.Online = true
if GetFlags().Verbose {
log.Printf("Room online and receiving: '%s' - signaling participants\n", room.Name)
}
room.signalParticipantsWithTracks()
}
rtpBuffer := make([]byte, 1400)
for {
read, _, err := remoteTrack.Read(rtpBuffer)
if err != nil {
// EOF is expected when stopping room
if !errors.Is(err, io.EOF) {
log.Printf("RTP read error from room: '%s' - reason: %s\n", room.Name, err)
}
break
}
_, err = localTrack.Write(rtpBuffer[:read])
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
log.Printf("Failed to write RTP to local track for room: '%s' - reason: %s\n", room.Name, err)
break
}
}
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
room.VideoTrack = nil
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
room.AudioTrack = nil
}
if room.VideoTrack == nil && room.AudioTrack == nil {
room.Online = false
if GetFlags().Verbose {
log.Printf("Room offline and not receiving: '%s'\n", room.Name)
}
// Signal participants of room offline
room.signalParticipantsOffline()
DeleteRoomIfEmpty(room)
}
})
room.PeerConnection.OnDataChannel(func(dc *webrtc.DataChannel) {
room.DataChannel = NewNestriDataChannel(dc)
if GetFlags().Verbose {
log.Printf("New DataChannel for room: '%s' - '%s'\n", room.Name, room.DataChannel.Label())
}
// Register channel opening handling
room.DataChannel.RegisterOnOpen(func() {
if GetFlags().Verbose {
log.Printf("DataChannel for room: '%s' - '%s' open\n", room.Name, room.DataChannel.Label())
}
})
room.DataChannel.OnClose(func() {
if GetFlags().Verbose {
log.Printf("DataChannel for room: '%s' - '%s' closed\n", room.Name, room.DataChannel.Label())
}
})
// We do not handle any messages from ingest via DataChannel yet
})
room.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
return
}
if GetFlags().Verbose {
log.Printf("ICE candidate for room: '%s'\n", room.Name)
}
err = room.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
if err != nil {
log.Printf("Failed to send ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
})
iceHolder := make([]webrtc.ICECandidateInit, 0)
// ICE callback
room.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
var iceMsg MessageICECandidate
if err = DecodeMessage(data, &iceMsg); err != nil {
log.Printf("Failed to decode ICE candidate message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
candidate := webrtc.ICECandidateInit{
Candidate: iceMsg.Candidate.Candidate,
}
if room.PeerConnection != nil {
// If remote isn't set yet, store ICE candidates
if room.PeerConnection.RemoteDescription() != nil {
if err = room.PeerConnection.AddICECandidate(candidate); err != nil {
log.Printf("Failed to add ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
// Add any held ICE candidates
for _, heldCandidate := range iceHolder {
if err = room.PeerConnection.AddICECandidate(heldCandidate); err != nil {
log.Printf("Failed to add held ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
}
iceHolder = nil
} else {
iceHolder = append(iceHolder, candidate)
}
} else {
log.Printf("ICE candidate received before PeerConnection for room: '%s'\n", room.Name)
}
})
// SDP offer callback
room.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
var sdpMsg MessageSDP
if err = DecodeMessage(data, &sdpMsg); err != nil {
log.Printf("Failed to decode SDP message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
answer := handleIngestSDP(room, sdpMsg)
if answer != nil {
if err = room.WebSocket.SendSDPMessageWS(*answer); err != nil {
log.Printf("Failed to send SDP answer to ingest for room: '%s' - reason: %s\n", room.Name, err)
}
} else {
log.Printf("Failed to handle SDP message from ingest for room: '%s'\n", room.Name)
}
})
// Log callback
room.WebSocket.RegisterMessageCallback("log", func(data []byte) {
var logMsg MessageLog
if err = DecodeMessage(data, &logMsg); err != nil {
log.Printf("Failed to decode log message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
// TODO: Handle log message sending to metrics server
})
// Metrics callback
room.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
var metricsMsg MessageMetrics
if err = DecodeMessage(data, &metricsMsg); err != nil {
log.Printf("Failed to decode metrics message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
// TODO: Handle metrics message sending to metrics server
})
room.WebSocket.RegisterOnClose(func() {
// If PeerConnection is not open or does not exist, delete room
if (room.PeerConnection != nil && room.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateConnected) ||
room.PeerConnection == nil {
DeleteRoomIfEmpty(room)
}
})
log.Printf("Room: '%s' is ready, sending an OK\n", room.Name)
if err = room.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
log.Printf("Failed to send OK answer for room: '%s' - reason: %s\n", room.Name, err)
}
}
// SDP offer handler, returns SDP answer
func handleIngestSDP(room *Room, offerMsg MessageSDP) *webrtc.SessionDescription {
var err error
// Get SDP offer
sdpOffer := offerMsg.SDP.SDP
// Modify SDP offer to remove opus "sprop-maxcapturerate=24000" (fixes opus bad quality issue, present in GStreamer)
sdpOffer = strings.Replace(sdpOffer, ";sprop-maxcapturerate=24000", "", -1)
// Set new remote description
err = room.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeOffer,
SDP: sdpOffer,
})
if err != nil {
log.Printf("Failed to set remote description for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
// Create SDP answer
answer, err := room.PeerConnection.CreateAnswer(nil)
if err != nil {
log.Printf("Failed to create SDP answer for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
// Set local description
err = room.PeerConnection.SetLocalDescription(answer)
if err != nil {
log.Printf("Failed to set local description for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
return &answer
}

View File

@@ -0,0 +1,114 @@
package relay
import (
"fmt"
"time"
)
type TimestampEntry struct {
Stage string `json:"stage"`
Time string `json:"time"` // ISO 8601 string
}
// LatencyTracker provides a generic structure for measuring time taken at various stages in message processing.
// It can be embedded in message structs for tracking the flow of data and calculating round-trip latency.
type LatencyTracker struct {
SequenceID string `json:"sequence_id"`
Timestamps []TimestampEntry `json:"timestamps"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// NewLatencyTracker initializes a new LatencyTracker with the given sequence ID
func NewLatencyTracker(sequenceID string) *LatencyTracker {
return &LatencyTracker{
SequenceID: sequenceID,
Timestamps: make([]TimestampEntry, 0),
Metadata: make(map[string]string),
}
}
// AddTimestamp adds a new timestamp for a specific stage
func (lt *LatencyTracker) AddTimestamp(stage string) {
lt.Timestamps = append(lt.Timestamps, TimestampEntry{
Stage: stage,
// Ensure extremely precise UTC RFC3339 timestamps (down to nanoseconds)
Time: time.Now().UTC().Format(time.RFC3339Nano),
})
}
// TotalLatency calculates the total latency from the earliest to the latest timestamp
func (lt *LatencyTracker) TotalLatency() (int64, error) {
if len(lt.Timestamps) < 2 {
return 0, nil // Not enough timestamps to calculate latency
}
var earliest, latest time.Time
for _, ts := range lt.Timestamps {
t, err := time.Parse(time.RFC3339, ts.Time)
if err != nil {
return 0, err
}
if earliest.IsZero() || t.Before(earliest) {
earliest = t
}
if latest.IsZero() || t.After(latest) {
latest = t
}
}
return latest.Sub(earliest).Milliseconds(), nil
}
// PainPoints returns a list of stages where the duration exceeds the given threshold.
func (lt *LatencyTracker) PainPoints(threshold time.Duration) []string {
var painPoints []string
var lastStage string
var lastTime time.Time
for _, ts := range lt.Timestamps {
stage := ts.Stage
t := ts.Time
if lastStage == "" {
lastStage = stage
lastTime, _ = time.Parse(time.RFC3339, t)
continue
}
currentTime, _ := time.Parse(time.RFC3339, t)
if currentTime.Sub(lastTime) > threshold {
painPoints = append(painPoints, fmt.Sprintf("%s -> %s", lastStage, stage))
}
lastStage = stage
lastTime = currentTime
}
return painPoints
}
// StageLatency calculates the time taken between two specific stages.
func (lt *LatencyTracker) StageLatency(startStage, endStage string) (time.Duration, error) {
startTime, endTime := "", ""
for _, ts := range lt.Timestamps {
if ts.Stage == startStage {
startTime = ts.Time
}
if ts.Stage == endStage {
endTime = ts.Time
}
}
if startTime == "" || endTime == "" {
return 0, fmt.Errorf("missing timestamps for stages: %s -> %s", startStage, endStage)
}
start, err := time.Parse(time.RFC3339, startTime)
if err != nil {
return 0, err
}
end, err := time.Parse(time.RFC3339, endTime)
if err != nil {
return 0, err
}
return end.Sub(start), nil
}

View File

@@ -0,0 +1,227 @@
package relay
import (
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"github.com/pion/webrtc/v4"
"time"
)
// OnMessageCallback is a callback for binary messages of given type
type OnMessageCallback func(data []byte)
// MessageBase is the base type for WS/DC messages.
type MessageBase struct {
PayloadType string `json:"payload_type"`
LatencyTracker LatencyTracker `json:"latency_tracker,omitempty"`
}
// MessageInput represents an input message.
type MessageInput struct {
MessageBase
Data string `json:"data"`
}
// MessageLog represents a log message.
type MessageLog struct {
MessageBase
Level string `json:"level"`
Message string `json:"message"`
Time string `json:"time"`
}
// MessageMetrics represents a metrics/heartbeat message.
type MessageMetrics struct {
MessageBase
UsageCPU float64 `json:"usage_cpu"`
UsageMemory float64 `json:"usage_memory"`
Uptime uint64 `json:"uptime"`
PipelineLatency float64 `json:"pipeline_latency"`
}
// MessageICECandidate represents an ICE candidate message.
type MessageICECandidate struct {
MessageBase
Candidate webrtc.ICECandidateInit `json:"candidate"`
}
// MessageSDP represents an SDP message.
type MessageSDP struct {
MessageBase
SDP webrtc.SessionDescription `json:"sdp"`
}
// JoinerType is an enum for the type of incoming room joiner
type JoinerType int
const (
JoinerNode JoinerType = iota
JoinerClient
)
func (jt *JoinerType) String() string {
switch *jt {
case JoinerNode:
return "node"
case JoinerClient:
return "client"
default:
return "unknown"
}
}
// MessageJoin is used to tell us that either participant or ingest wants to join the room
type MessageJoin struct {
MessageBase
JoinerType JoinerType `json:"joiner_type"`
}
// AnswerType is an enum for the type of answer, signaling Room state for a joiner
type AnswerType int
const (
AnswerOffline AnswerType = iota // For participant/client, when the room is offline without stream
AnswerInUse // For ingest/node joiner, when the room is already in use by another ingest/node
AnswerOK // For both, when the join request is handled successfully
)
// MessageAnswer is used to send the answer to a join request
type MessageAnswer struct {
MessageBase
AnswerType AnswerType `json:"answer_type"`
}
// EncodeMessage encodes a message to be sent with gzip compression
func EncodeMessage(msg interface{}) ([]byte, error) {
// Marshal the message to JSON
data, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("failed to encode message: %w", err)
}
// Gzip compress the JSON
var compressedData bytes.Buffer
writer := gzip.NewWriter(&compressedData)
_, err = writer.Write(data)
if err != nil {
return nil, fmt.Errorf("failed to compress message: %w", err)
}
if err := writer.Close(); err != nil {
return nil, fmt.Errorf("failed to finalize compression: %w", err)
}
return compressedData.Bytes(), nil
}
// DecodeMessage decodes a message received with gzip decompression
func DecodeMessage(data []byte, target interface{}) error {
// Gzip decompress the data
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return fmt.Errorf("failed to initialize decompression: %w", err)
}
defer func(reader *gzip.Reader) {
if err = reader.Close(); err != nil {
fmt.Printf("failed to close reader: %v\n", err)
}
}(reader)
// Decode the JSON
err = json.NewDecoder(reader).Decode(target)
if err != nil {
return fmt.Errorf("failed to decode message: %w", err)
}
return nil
}
// SendLogMessageWS sends a log message to the given WebSocket connection.
func (ws *SafeWebSocket) SendLogMessageWS(level, message string) error {
msg := MessageLog{
MessageBase: MessageBase{PayloadType: "log"},
Level: level,
Message: message,
Time: time.Now().Format(time.RFC3339),
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode log message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendMetricsMessageWS sends a metrics message to the given WebSocket connection.
func (ws *SafeWebSocket) SendMetricsMessageWS(usageCPU, usageMemory float64, uptime uint64, pipelineLatency float64) error {
msg := MessageMetrics{
MessageBase: MessageBase{PayloadType: "metrics"},
UsageCPU: usageCPU,
UsageMemory: usageMemory,
Uptime: uptime,
PipelineLatency: pipelineLatency,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode metrics message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendICECandidateMessageWS sends an ICE candidate message to the given WebSocket connection.
func (ws *SafeWebSocket) SendICECandidateMessageWS(candidate webrtc.ICECandidateInit) error {
msg := MessageICECandidate{
MessageBase: MessageBase{PayloadType: "ice"},
Candidate: candidate,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode ICE candidate message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendSDPMessageWS sends an SDP message to the given WebSocket connection.
func (ws *SafeWebSocket) SendSDPMessageWS(sdp webrtc.SessionDescription) error {
msg := MessageSDP{
MessageBase: MessageBase{PayloadType: "sdp"},
SDP: sdp,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode SDP message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendAnswerMessageWS sends an answer message to the given WebSocket connection.
func (ws *SafeWebSocket) SendAnswerMessageWS(answer AnswerType) error {
msg := MessageAnswer{
MessageBase: MessageBase{PayloadType: "answer"},
AnswerType: answer,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode answer message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendInputMessageDC sends an input message to the given DataChannel connection.
func (ndc *NestriDataChannel) SendInputMessageDC(data string) error {
msg := MessageInput{
MessageBase: MessageBase{PayloadType: "input"},
Data: data,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode input message: %w", err)
}
return ndc.SendBinary(encoded)
}

View File

@@ -0,0 +1,69 @@
package relay
import (
"fmt"
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"math/rand"
)
type Participant struct {
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track and not have conflicts later
Name string
WebSocket *SafeWebSocket
PeerConnection *webrtc.PeerConnection
DataChannel *NestriDataChannel
}
func NewParticipant(ws *SafeWebSocket) *Participant {
return &Participant{
ID: uuid.New(),
Name: createRandomName(),
WebSocket: ws,
}
}
func (vw *Participant) addTrack(trackLocal *webrtc.TrackLocal) error {
rtpSender, err := vw.PeerConnection.AddTrack(*trackLocal)
if err != nil {
return err
}
go func() {
rtcpBuffer := make([]byte, 1400)
for {
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
return
}
}
}()
return nil
}
func (vw *Participant) signalOffer() error {
if vw.PeerConnection == nil {
return fmt.Errorf("peer connection is nil for participant: '%s' - cannot signal offer", vw.ID)
}
offer, err := vw.PeerConnection.CreateOffer(nil)
if err != nil {
return err
}
err = vw.PeerConnection.SetLocalDescription(offer)
if err != nil {
return err
}
return vw.WebSocket.SendSDPMessageWS(offer)
}
var namesFirst = []string{"Happy", "Sad", "Angry", "Calm", "Excited", "Bored", "Confused", "Confident", "Curious", "Depressed", "Disappointed", "Embarrassed", "Energetic", "Fearful", "Frustrated", "Glad", "Guilty", "Hopeful", "Impatient", "Jealous", "Lonely", "Motivated", "Nervous", "Optimistic", "Pessimistic", "Proud", "Relaxed", "Shy", "Stressed", "Surprised", "Tired", "Worried"}
var namesSecond = []string{"Dragon", "Unicorn", "Troll", "Goblin", "Elf", "Dwarf", "Ogre", "Gnome", "Mermaid", "Siren", "Vampire", "Ghoul", "Werewolf", "Minotaur", "Centaur", "Griffin", "Phoenix", "Wyvern", "Hydra", "Kraken"}
func createRandomName() string {
randomFirst := namesFirst[rand.Intn(len(namesFirst))]
randomSecond := namesSecond[rand.Intn(len(namesSecond))]
return randomFirst + " " + randomSecond
}

View File

@@ -0,0 +1,179 @@
package relay
import (
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"log"
"sync"
)
var Rooms = make(map[uuid.UUID]*Room) //< Room ID -> Room
var RoomsMutex = sync.RWMutex{}
func GetRoomByID(id uuid.UUID) *Room {
RoomsMutex.RLock()
defer RoomsMutex.RUnlock()
if room, ok := Rooms[id]; ok {
return room
}
return nil
}
func GetRoomByName(name string) *Room {
RoomsMutex.RLock()
defer RoomsMutex.RUnlock()
for _, room := range Rooms {
if room.Name == name {
return room
}
}
return nil
}
func GetOrCreateRoom(name string) *Room {
if room := GetRoomByName(name); room != nil {
return room
}
RoomsMutex.Lock()
room := NewRoom(name)
Rooms[room.ID] = room
if GetFlags().Verbose {
log.Printf("New room: '%s'\n", room.Name)
}
RoomsMutex.Unlock()
return room
}
func DeleteRoomIfEmpty(room *Room) {
room.ParticipantsMutex.RLock()
defer room.ParticipantsMutex.RUnlock()
if !room.Online && len(room.Participants) <= 0 {
RoomsMutex.Lock()
delete(Rooms, room.ID)
RoomsMutex.Unlock()
}
}
type Room struct {
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track
Name string
Online bool //< Whether the room is currently online, i.e. receiving data from a nestri-server
WebSocket *SafeWebSocket
PeerConnection *webrtc.PeerConnection
AudioTrack webrtc.TrackLocal
VideoTrack webrtc.TrackLocal
DataChannel *NestriDataChannel
Participants map[uuid.UUID]*Participant
ParticipantsMutex sync.RWMutex
}
func NewRoom(name string) *Room {
return &Room{
ID: uuid.New(),
Name: name,
Online: false,
Participants: make(map[uuid.UUID]*Participant),
}
}
// Assigns a WebSocket connection to a Room
func (r *Room) assignWebSocket(ws *SafeWebSocket) {
// If WS already assigned, warn
if r.WebSocket != nil {
log.Printf("Warning: Room '%s' already has a WebSocket assigned\n", r.Name)
}
r.WebSocket = ws
}
// Adds a Participant to a Room
func (r *Room) addParticipant(participant *Participant) {
r.ParticipantsMutex.Lock()
r.Participants[participant.ID] = participant
r.ParticipantsMutex.Unlock()
}
// Removes a Participant from a Room by participant's ID.
// If Room is offline and this is the last participant, the room is deleted
func (r *Room) removeParticipantByID(pID uuid.UUID) {
r.ParticipantsMutex.Lock()
delete(r.Participants, pID)
r.ParticipantsMutex.Unlock()
DeleteRoomIfEmpty(r)
}
// Removes a Participant from a Room by participant's name.
// If Room is offline and this is the last participant, the room is deleted
func (r *Room) removeParticipantByName(pName string) {
r.ParticipantsMutex.Lock()
for id, p := range r.Participants {
if p.Name == pName {
delete(r.Participants, id)
break
}
}
r.ParticipantsMutex.Unlock()
DeleteRoomIfEmpty(r)
}
// Signals all participants with offer and add tracks to their PeerConnections
func (r *Room) signalParticipantsWithTracks() {
r.ParticipantsMutex.RLock()
for _, participant := range r.Participants {
// Add tracks to participant's PeerConnection
if r.AudioTrack != nil {
if err := participant.addTrack(&r.AudioTrack); err != nil {
log.Printf("Failed to add audio track to participant: '%s' - reason: %s\n", participant.ID, err)
}
}
if r.VideoTrack != nil {
if err := participant.addTrack(&r.VideoTrack); err != nil {
log.Printf("Failed to add video track to participant: '%s' - reason: %s\n", participant.ID, err)
}
}
// Signal participant with offer
if err := participant.signalOffer(); err != nil {
log.Printf("Error signaling participant: %v\n", err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Signals all participants that the Room is offline
func (r *Room) signalParticipantsOffline() {
r.ParticipantsMutex.RLock()
for _, participant := range r.Participants {
if err := participant.WebSocket.SendAnswerMessageWS(AnswerOffline); err != nil {
log.Printf("Failed to send Offline answer for participant: '%s' - reason: %s\n", participant.ID, err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Broadcasts a message to Room's Participant's - excluding one given ID of
func (r *Room) broadcastMessage(msg webrtc.DataChannelMessage, excludeID uuid.UUID) {
r.ParticipantsMutex.RLock()
for d, participant := range r.Participants {
if participant.DataChannel != nil {
if d != excludeID { // Don't send back to the sender
if err := participant.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to %s: %v\n", participant.Name, err)
}
}
}
}
if r.DataChannel != nil {
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to Room: %v\n", err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Sends message to Room (nestri-server)
func (r *Room) sendToRoom(msg webrtc.DataChannelMessage) {
if r.DataChannel != nil {
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to Room: %v\n", err)
}
}
}

View File

@@ -0,0 +1,114 @@
package relay
import (
"github.com/gorilla/websocket"
"log"
"sync"
)
// SafeWebSocket is a websocket with a mutex
type SafeWebSocket struct {
*websocket.Conn
sync.Mutex
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
}
// NewSafeWebSocket creates a new SafeWebSocket from *websocket.Conn
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
ws := &SafeWebSocket{
Conn: conn,
binaryCallbacks: make(map[string]OnMessageCallback),
}
// Launch a goroutine to handle binary messages
go func() {
for {
// Read binary message
kind, data, err := ws.Conn.ReadMessage()
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If unexpected close error, break
if GetFlags().Verbose {
log.Printf("Unexpected WebSocket close error, reason: %s\n", err)
}
break
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If closing, just break
if GetFlags().Verbose {
log.Printf("WebSocket closing\n")
}
break
} else if err != nil {
log.Printf("Failed to read WebSocket message, reason: %s\n", err)
break
}
switch kind {
case websocket.TextMessage:
// Ignore, we use binary messages
continue
case websocket.BinaryMessage:
// Decode message
var msg MessageBase
if err = DecodeMessage(data, &msg); err != nil {
log.Printf("Failed to decode binary WebSocket message, reason: %s\n", err)
continue
}
// Handle message type callback
if callback, ok := ws.binaryCallbacks[msg.PayloadType]; ok {
callback(data)
} // TODO: Log unknown message type?
default:
log.Printf("Unknown WebSocket message type: %d\n", kind)
}
}
}()
return ws
}
// SendJSON writes JSON to a websocket with a mutex
func (ws *SafeWebSocket) SendJSON(v interface{}) error {
ws.Lock()
defer ws.Unlock()
return ws.Conn.WriteJSON(v)
}
// SendBinary writes binary to a websocket with a mutex
func (ws *SafeWebSocket) SendBinary(data []byte) error {
ws.Lock()
defer ws.Unlock()
return ws.Conn.WriteMessage(websocket.BinaryMessage, data)
}
// RegisterMessageCallback sets the callback for binary message of given type
func (ws *SafeWebSocket) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
ws.Lock()
defer ws.Unlock()
if ws.binaryCallbacks == nil {
ws.binaryCallbacks = make(map[string]OnMessageCallback)
}
ws.binaryCallbacks[msgType] = callback
}
// UnregisterMessageCallback removes the callback for binary message of given type
func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
ws.Lock()
defer ws.Unlock()
if ws.binaryCallbacks != nil {
delete(ws.binaryCallbacks, msgType)
}
}
// RegisterOnClose sets the callback for websocket closing
func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
ws.SetCloseHandler(func(code int, text string) error {
// Clear our callbacks
ws.Lock()
ws.binaryCallbacks = nil
ws.Unlock()
// Call the callback
callback()
return nil
})
}