Fixed multi-controllers, optimize and improve code in relay and nestri-server

This commit is contained in:
DatCaptainHorse
2025-10-25 03:57:26 +03:00
parent 67f9a7d0a0
commit a54cf759fa
27 changed files with 837 additions and 644 deletions

View File

@@ -26,7 +26,7 @@ func InitWebRTCAPI() error {
mediaEngine := &webrtc.MediaEngine{}
// Register our extensions
if err := RegisterExtensions(mediaEngine); err != nil {
if err = RegisterExtensions(mediaEngine); err != nil {
return fmt.Errorf("failed to register extensions: %w", err)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"math"
"relay/internal/common"
"relay/internal/connections"
"relay/internal/shared"
@@ -176,7 +177,7 @@ func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
// Create participant for this viewer
participant, err := shared.NewParticipant(
"", // session ID will be set if this is a client session
"",
stream.Conn().RemotePeer(),
)
if err != nil {
@@ -189,102 +190,37 @@ func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
participant.SessionID = session.SessionID
}
// Assign peer connection
participant.PeerConnection = pc
// Create per-participant tracks
if room.VideoTrack != nil {
participant.VideoTrack, err = webrtc.NewTrackLocalStaticRTP(
room.VideoTrack.Codec(),
"video-"+participant.ID.String(),
"nestri-"+reqMsg.RoomName+"-video",
// Add audio/video tracks
{
localTrack, err := webrtc.NewTrackLocalStaticRTP(
room.AudioCodec,
"participant-"+participant.ID.String(),
"participant-"+participant.ID.String()+"-audio",
)
if err != nil {
slog.Error("Failed to create participant video track", "room", reqMsg.RoomName, "err", err)
continue
slog.Error("Failed to create track for stream request", "err", err)
return
}
rtpSender, err := pc.AddTrack(participant.VideoTrack)
if err != nil {
slog.Error("Failed to add participant video track", "room", reqMsg.RoomName, "err", err)
continue
}
slog.Info("Added video track for participant",
"room", reqMsg.RoomName,
"participant", participant.ID,
"sender_id", fmt.Sprintf("%p", rtpSender))
// Relay packets from channel to track (VIDEO)
go func() {
for pkt := range participant.VideoChan {
// Use a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
done := make(chan error, 1)
go func() {
done <- participant.VideoTrack.WriteRTP(pkt)
}()
select {
case err := <-done:
cancel()
if err != nil {
if !errors.Is(err, io.ErrClosedPipe) {
slog.Debug("Failed to write video", "room", reqMsg.RoomName, "err", err)
}
return
}
case <-ctx.Done():
cancel()
slog.Error("WriteRTP BLOCKED for >100ms!",
"participant", participant.ID,
"room", reqMsg.RoomName)
// Don't return, continue processing
}
}
}()
participant.SetTrack(webrtc.RTPCodecTypeAudio, localTrack)
slog.Debug("Set audio track for requested stream", "room", room.Name)
}
if room.AudioTrack != nil {
participant.AudioTrack, err = webrtc.NewTrackLocalStaticRTP(
room.AudioTrack.Codec(),
"audio-"+participant.ID.String(),
"nestri-"+reqMsg.RoomName+"-audio",
{
localTrack, err := webrtc.NewTrackLocalStaticRTP(
room.VideoCodec,
"participant-"+participant.ID.String(),
"participant-"+participant.ID.String()+"-video",
)
if err != nil {
slog.Error("Failed to create participant audio track", "room", reqMsg.RoomName, "err", err)
continue
slog.Error("Failed to create track for stream request", "err", err)
return
}
_, err := pc.AddTrack(participant.AudioTrack)
if err != nil {
slog.Error("Failed to add participant audio track", "room", reqMsg.RoomName, "err", err)
continue
}
// Relay packets from channel to track (AUDIO)
go func() {
for pkt := range participant.AudioChan {
start := time.Now()
if err := participant.AudioTrack.WriteRTP(pkt); err != nil {
if !errors.Is(err, io.ErrClosedPipe) {
slog.Debug("Failed to write audio to participant", "room", reqMsg.RoomName, "err", err)
}
return
}
duration := time.Since(start)
if duration > 50*time.Millisecond {
slog.Warn("Slow audio WriteRTP detected",
"duration", duration,
"participant", participant.ID,
"room", reqMsg.RoomName)
}
}
}()
participant.SetTrack(webrtc.RTPCodecTypeVideo, localTrack)
slog.Debug("Set video track for requested stream", "room", room.Name)
}
// Add participant to room
room.AddParticipant(participant)
// Cleanup on disconnect
cleanupParticipantID := participant.ID
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
@@ -294,6 +230,9 @@ func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
slog.Info("Participant disconnected from room", "room", reqMsg.RoomName, "participant", cleanupParticipantID)
room.RemoveParticipantByID(cleanupParticipantID)
participant.Close()
} else if state == webrtc.PeerConnectionStateConnected {
// Add participant to room when connection is established
room.AddParticipant(participant)
}
})
@@ -334,33 +273,33 @@ func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
peerID := stream.Conn().RemotePeer()
// Check if it's a controller attach with assigned slot
if attach := msgWrapper.GetControllerAttach(); attach != nil && attach.Slot >= 0 {
if attach := msgWrapper.GetControllerAttach(); attach != nil && attach.SessionSlot >= 0 {
if session, ok := sp.relay.ClientSessions.Get(peerID); ok {
// Check if slot already tracked
hasSlot := false
for _, slot := range session.ControllerSlots {
if slot == attach.Slot {
if slot == attach.SessionSlot {
hasSlot = true
break
}
}
if !hasSlot {
session.ControllerSlots = append(session.ControllerSlots, attach.Slot)
session.ControllerSlots = append(session.ControllerSlots, attach.SessionSlot)
session.LastActivity = time.Now()
slog.Info("Controller slot assigned to client session",
"session", session.SessionID,
"slot", attach.Slot,
"slot", attach.SessionSlot,
"total_slots", len(session.ControllerSlots))
}
}
}
// Check if it's a controller detach
if detach := msgWrapper.GetControllerDetach(); detach != nil && detach.Slot >= 0 {
if detach := msgWrapper.GetControllerDetach(); detach != nil && detach.SessionSlot >= 0 {
if session, ok := sp.relay.ClientSessions.Get(peerID); ok {
newSlots := make([]int32, 0, len(session.ControllerSlots))
for _, slot := range session.ControllerSlots {
if slot != detach.Slot {
if slot != detach.SessionSlot {
newSlots = append(newSlots, slot)
}
}
@@ -368,7 +307,7 @@ func (sp *StreamProtocol) handleStreamRequest(stream network.Stream) {
session.LastActivity = time.Now()
slog.Info("Controller slot removed from client session",
"session", session.SessionID,
"slot", detach.Slot,
"slot", detach.SessionSlot,
"remaining_slots", len(session.ControllerSlots))
}
}
@@ -537,19 +476,25 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) {
slog.Debug("Stream push connection closed by peer", "peer", stream.Conn().RemotePeer(), "error", err)
if room != nil {
room.Close()
sp.incomingConns.Set(room.Name, nil)
}
return
}
slog.Error("Failed to receive data for stream push", "err", err)
_ = stream.Reset()
if room != nil {
room.Close()
sp.incomingConns.Set(room.Name, nil)
}
return
}
if msgWrapper.MessageBase == nil {
slog.Error("No MessageBase in stream push")
_ = stream.Reset()
return
continue
}
switch msgWrapper.MessageBase.PayloadType {
@@ -606,7 +551,7 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
slog.Error("Failed to add ICE candidate for pushed stream", "err", err)
}
for _, heldIce := range iceHolder {
if err := conn.pc.AddICECandidate(heldIce); err != nil {
if err = conn.pc.AddICECandidate(heldIce); err != nil {
slog.Error("Failed to add held ICE candidate for pushed stream", "err", err)
}
}
@@ -645,6 +590,9 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
continue
}
// Assign room peer connection
room.PeerConnection = pc
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
// TODO: Is this the best way to handle DataChannel? Should we just use the map directly?
room.DataChannel = connections.NewNestriDataChannel(dc)
@@ -708,17 +656,6 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
})
pc.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
localTrack, err := webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, remoteTrack.Kind().String(), fmt.Sprintf("nestri-%s-%s", room.Name, remoteTrack.Kind().String()))
if err != nil {
slog.Error("Failed to create local track for pushed stream", "room", room.Name, "track_kind", remoteTrack.Kind().String(), "err", err)
return
}
slog.Debug("Received track for pushed stream", "room", room.Name, "track_kind", remoteTrack.Kind().String())
// Set track for Room
room.SetTrack(remoteTrack.Kind(), localTrack)
// Prepare PlayoutDelayExtension so we don't need to recreate it for each packet
playoutExt := &rtp.PlayoutDelayExtension{
MinDelay: 0,
@@ -730,6 +667,12 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
return
}
if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
room.AudioCodec = remoteTrack.Codec().RTPCodecCapability
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
room.VideoCodec = remoteTrack.Codec().RTPCodecCapability
}
for {
rtpPacket, _, err := remoteTrack.ReadRTP()
if err != nil {
@@ -741,19 +684,61 @@ func (sp *StreamProtocol) handleStreamPush(stream network.Stream) {
// Use PlayoutDelayExtension for low latency, if set for this track kind
if extID, ok := common.GetExtension(remoteTrack.Kind(), common.ExtensionPlayoutDelay); ok {
if err := rtpPacket.SetExtension(extID, playoutPayload); err != nil {
if err = rtpPacket.SetExtension(extID, playoutPayload); err != nil {
slog.Error("Failed to set PlayoutDelayExtension for room", "room", room.Name, "err", err)
continue
}
}
room.BroadcastPacket(remoteTrack.Kind(), rtpPacket)
// Calculate differences
var timeDiff int64
var sequenceDiff int
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
timeDiff = int64(rtpPacket.Timestamp) - int64(room.LastVideoTimestamp)
if !room.VideoTimestampSet {
timeDiff = 0
room.VideoTimestampSet = true
} else if timeDiff < -(math.MaxUint32 / 10) {
timeDiff += math.MaxUint32 + 1
}
sequenceDiff = int(rtpPacket.SequenceNumber) - int(room.LastVideoSequenceNumber)
if !room.VideoSequenceSet {
sequenceDiff = 0
room.VideoSequenceSet = true
} else if sequenceDiff < -(math.MaxUint16 / 10) {
sequenceDiff += math.MaxUint16 + 1
}
room.LastVideoTimestamp = rtpPacket.Timestamp
room.LastVideoSequenceNumber = rtpPacket.SequenceNumber
} else { // Audio
timeDiff = int64(rtpPacket.Timestamp) - int64(room.LastAudioTimestamp)
if !room.AudioTimestampSet {
timeDiff = 0
room.AudioTimestampSet = true
} else if timeDiff < -(math.MaxUint32 / 10) {
timeDiff += math.MaxUint32 + 1
}
sequenceDiff = int(rtpPacket.SequenceNumber) - int(room.LastAudioSequenceNumber)
if !room.AudioSequenceSet {
sequenceDiff = 0
room.AudioSequenceSet = true
} else if sequenceDiff < -(math.MaxUint16 / 10) {
sequenceDiff += math.MaxUint16 + 1
}
room.LastAudioTimestamp = rtpPacket.Timestamp
room.LastAudioSequenceNumber = rtpPacket.SequenceNumber
}
// Broadcast with differences
room.BroadcastPacketRetimed(remoteTrack.Kind(), rtpPacket, timeDiff, sequenceDiff)
}
slog.Debug("Track closed for room", "room", room.Name, "track_kind", remoteTrack.Kind().String())
// Cleanup the track from the room
room.SetTrack(remoteTrack.Kind(), nil)
})
// Set the remote description

View File

@@ -45,7 +45,7 @@ func (r *Relay) DeleteRoomIfEmpty(room *shared.Room) {
if room == nil {
return
}
if room.Participants.Len() == 0 && r.LocalRooms.Has(room.ID) {
if len(room.Participants) <= 0 && r.LocalRooms.Has(room.ID) {
slog.Debug("Deleting empty room without participants", "room", room.Name)
r.LocalRooms.Delete(room.ID)
err := room.PeerConnection.Close()

View File

@@ -195,18 +195,18 @@ func (r *Relay) updateMeshRoomStates(peerID peer.ID, states []shared.RoomInfo) {
}
// If previously did not exist, but does now, request a connection if participants exist for our room
existed := r.Rooms.Has(state.ID.String())
/*existed := r.Rooms.Has(state.ID.String())
if !existed {
// Request connection to this peer if we have participants in our local room
if room, ok := r.LocalRooms.Get(state.ID); ok {
if room.Participants.Len() > 0 {
if len(room.Participants) > 0 {
slog.Debug("Got new remote room state, we locally have participants for, requesting stream", "room_name", room.Name, "peer", peerID)
if err := r.StreamProtocol.RequestStream(context.Background(), room, peerID); err != nil {
slog.Error("Failed to request stream for new remote room state", "room_name", room.Name, "peer", peerID, "err", err)
}
}
}
}
}*/
r.Rooms.Set(state.ID.String(), state)
}

View File

@@ -363,9 +363,9 @@ func (x *ProtoKeyUp) GetKey() int32 {
// ControllerAttach message
type ProtoControllerAttach struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // One of the following enums: "ps", "xbox" or "switch"
Slot int32 `protobuf:"varint,2,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
SessionId string `protobuf:"bytes,3,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client attaching the controller
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // One of the following enums: "ps", "xbox" or "switch"
SessionSlot int32 `protobuf:"varint,2,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,3,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -407,9 +407,9 @@ func (x *ProtoControllerAttach) GetId() string {
return ""
}
func (x *ProtoControllerAttach) GetSlot() int32 {
func (x *ProtoControllerAttach) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
@@ -424,7 +424,8 @@ func (x *ProtoControllerAttach) GetSessionId() string {
// ControllerDetach message
type ProtoControllerDetach struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -459,19 +460,27 @@ func (*ProtoControllerDetach) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{8}
}
func (x *ProtoControllerDetach) GetSlot() int32 {
func (x *ProtoControllerDetach) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerDetach) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
// ControllerButton message
type ProtoControllerButton struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
Button int32 `protobuf:"varint,2,opt,name=button,proto3" json:"button,omitempty"` // Button code (linux input event code)
Pressed bool `protobuf:"varint,3,opt,name=pressed,proto3" json:"pressed,omitempty"` // true if pressed, false if released
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
Button int32 `protobuf:"varint,3,opt,name=button,proto3" json:"button,omitempty"` // Button code (linux input event code)
Pressed bool `protobuf:"varint,4,opt,name=pressed,proto3" json:"pressed,omitempty"` // true if pressed, false if released
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -506,13 +515,20 @@ func (*ProtoControllerButton) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{9}
}
func (x *ProtoControllerButton) GetSlot() int32 {
func (x *ProtoControllerButton) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerButton) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ProtoControllerButton) GetButton() int32 {
if x != nil {
return x.Button
@@ -530,9 +546,10 @@ func (x *ProtoControllerButton) GetPressed() bool {
// ControllerTriggers message
type ProtoControllerTrigger struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
Trigger int32 `protobuf:"varint,2,opt,name=trigger,proto3" json:"trigger,omitempty"` // Trigger number (0 for left, 1 for right)
Value int32 `protobuf:"varint,3,opt,name=value,proto3" json:"value,omitempty"` // trigger value (-32768 to 32767)
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
Trigger int32 `protobuf:"varint,3,opt,name=trigger,proto3" json:"trigger,omitempty"` // Trigger number (0 for left, 1 for right)
Value int32 `protobuf:"varint,4,opt,name=value,proto3" json:"value,omitempty"` // trigger value (-32768 to 32767)
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -567,13 +584,20 @@ func (*ProtoControllerTrigger) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{10}
}
func (x *ProtoControllerTrigger) GetSlot() int32 {
func (x *ProtoControllerTrigger) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerTrigger) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ProtoControllerTrigger) GetTrigger() int32 {
if x != nil {
return x.Trigger
@@ -591,10 +615,11 @@ func (x *ProtoControllerTrigger) GetValue() int32 {
// ControllerSticks message
type ProtoControllerStick struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
Stick int32 `protobuf:"varint,2,opt,name=stick,proto3" json:"stick,omitempty"` // Stick number (0 for left, 1 for right)
X int32 `protobuf:"varint,3,opt,name=x,proto3" json:"x,omitempty"` // X axis value (-32768 to 32767)
Y int32 `protobuf:"varint,4,opt,name=y,proto3" json:"y,omitempty"` // Y axis value (-32768 to 32767)
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
Stick int32 `protobuf:"varint,3,opt,name=stick,proto3" json:"stick,omitempty"` // Stick number (0 for left, 1 for right)
X int32 `protobuf:"varint,4,opt,name=x,proto3" json:"x,omitempty"` // X axis value (-32768 to 32767)
Y int32 `protobuf:"varint,5,opt,name=y,proto3" json:"y,omitempty"` // Y axis value (-32768 to 32767)
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -629,13 +654,20 @@ func (*ProtoControllerStick) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{11}
}
func (x *ProtoControllerStick) GetSlot() int32 {
func (x *ProtoControllerStick) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerStick) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ProtoControllerStick) GetStick() int32 {
if x != nil {
return x.Stick
@@ -660,9 +692,10 @@ func (x *ProtoControllerStick) GetY() int32 {
// ControllerAxis message
type ProtoControllerAxis struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
Axis int32 `protobuf:"varint,2,opt,name=axis,proto3" json:"axis,omitempty"` // Axis number (0 for d-pad horizontal, 1 for d-pad vertical)
Value int32 `protobuf:"varint,3,opt,name=value,proto3" json:"value,omitempty"` // axis value (-1 to 1)
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
Axis int32 `protobuf:"varint,3,opt,name=axis,proto3" json:"axis,omitempty"` // Axis number (0 for d-pad horizontal, 1 for d-pad vertical)
Value int32 `protobuf:"varint,4,opt,name=value,proto3" json:"value,omitempty"` // axis value (-1 to 1)
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -697,13 +730,20 @@ func (*ProtoControllerAxis) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{12}
}
func (x *ProtoControllerAxis) GetSlot() int32 {
func (x *ProtoControllerAxis) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerAxis) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ProtoControllerAxis) GetAxis() int32 {
if x != nil {
return x.Axis
@@ -721,10 +761,11 @@ func (x *ProtoControllerAxis) GetValue() int32 {
// ControllerRumble message
type ProtoControllerRumble struct {
state protoimpl.MessageState `protogen:"open.v1"`
Slot int32 `protobuf:"varint,1,opt,name=slot,proto3" json:"slot,omitempty"` // Slot number (0-3)
LowFrequency int32 `protobuf:"varint,2,opt,name=low_frequency,json=lowFrequency,proto3" json:"low_frequency,omitempty"` // Low frequency rumble (0-65535)
HighFrequency int32 `protobuf:"varint,3,opt,name=high_frequency,json=highFrequency,proto3" json:"high_frequency,omitempty"` // High frequency rumble (0-65535)
Duration int32 `protobuf:"varint,4,opt,name=duration,proto3" json:"duration,omitempty"` // Duration in milliseconds
SessionSlot int32 `protobuf:"varint,1,opt,name=session_slot,json=sessionSlot,proto3" json:"session_slot,omitempty"` // Session specific slot number (0-3)
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` // Session ID of the client
LowFrequency int32 `protobuf:"varint,3,opt,name=low_frequency,json=lowFrequency,proto3" json:"low_frequency,omitempty"` // Low frequency rumble (0-65535)
HighFrequency int32 `protobuf:"varint,4,opt,name=high_frequency,json=highFrequency,proto3" json:"high_frequency,omitempty"` // High frequency rumble (0-65535)
Duration int32 `protobuf:"varint,5,opt,name=duration,proto3" json:"duration,omitempty"` // Duration in milliseconds
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -759,13 +800,20 @@ func (*ProtoControllerRumble) Descriptor() ([]byte, []int) {
return file_types_proto_rawDescGZIP(), []int{13}
}
func (x *ProtoControllerRumble) GetSlot() int32 {
func (x *ProtoControllerRumble) GetSessionSlot() int32 {
if x != nil {
return x.Slot
return x.SessionSlot
}
return 0
}
func (x *ProtoControllerRumble) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ProtoControllerRumble) GetLowFrequency() int32 {
if x != nil {
return x.LowFrequency
@@ -1215,36 +1263,48 @@ const file_types_proto_rawDesc = "" +
"\x03key\x18\x01 \x01(\x05R\x03key\"\x1e\n" +
"\n" +
"ProtoKeyUp\x12\x10\n" +
"\x03key\x18\x01 \x01(\x05R\x03key\"Z\n" +
"\x03key\x18\x01 \x01(\x05R\x03key\"i\n" +
"\x15ProtoControllerAttach\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
"\x04slot\x18\x02 \x01(\x05R\x04slot\x12\x1d\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12!\n" +
"\fsession_slot\x18\x02 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x03 \x01(\tR\tsessionId\"+\n" +
"\x15ProtoControllerDetach\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\"]\n" +
"\x15ProtoControllerButton\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\x12\x16\n" +
"\x06button\x18\x02 \x01(\x05R\x06button\x12\x18\n" +
"\apressed\x18\x03 \x01(\bR\apressed\"\\\n" +
"\x16ProtoControllerTrigger\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\x12\x18\n" +
"\atrigger\x18\x02 \x01(\x05R\atrigger\x12\x14\n" +
"\x05value\x18\x03 \x01(\x05R\x05value\"\\\n" +
"\x14ProtoControllerStick\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\x12\x14\n" +
"\x05stick\x18\x02 \x01(\x05R\x05stick\x12\f\n" +
"\x01x\x18\x03 \x01(\x05R\x01x\x12\f\n" +
"\x01y\x18\x04 \x01(\x05R\x01y\"S\n" +
"\x13ProtoControllerAxis\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\x12\x12\n" +
"\x04axis\x18\x02 \x01(\x05R\x04axis\x12\x14\n" +
"\x05value\x18\x03 \x01(\x05R\x05value\"\x93\x01\n" +
"\x15ProtoControllerRumble\x12\x12\n" +
"\x04slot\x18\x01 \x01(\x05R\x04slot\x12#\n" +
"\rlow_frequency\x18\x02 \x01(\x05R\flowFrequency\x12%\n" +
"\x0ehigh_frequency\x18\x03 \x01(\x05R\rhighFrequency\x12\x1a\n" +
"\bduration\x18\x04 \x01(\x05R\bduration\"\xde\x01\n" +
"session_id\x18\x03 \x01(\tR\tsessionId\"Y\n" +
"\x15ProtoControllerDetach\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\"\x8b\x01\n" +
"\x15ProtoControllerButton\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12\x16\n" +
"\x06button\x18\x03 \x01(\x05R\x06button\x12\x18\n" +
"\apressed\x18\x04 \x01(\bR\apressed\"\x8a\x01\n" +
"\x16ProtoControllerTrigger\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12\x18\n" +
"\atrigger\x18\x03 \x01(\x05R\atrigger\x12\x14\n" +
"\x05value\x18\x04 \x01(\x05R\x05value\"\x8a\x01\n" +
"\x14ProtoControllerStick\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12\x14\n" +
"\x05stick\x18\x03 \x01(\x05R\x05stick\x12\f\n" +
"\x01x\x18\x04 \x01(\x05R\x01x\x12\f\n" +
"\x01y\x18\x05 \x01(\x05R\x01y\"\x81\x01\n" +
"\x13ProtoControllerAxis\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12\x12\n" +
"\x04axis\x18\x03 \x01(\x05R\x04axis\x12\x14\n" +
"\x05value\x18\x04 \x01(\x05R\x05value\"\xc1\x01\n" +
"\x15ProtoControllerRumble\x12!\n" +
"\fsession_slot\x18\x01 \x01(\x05R\vsessionSlot\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12#\n" +
"\rlow_frequency\x18\x03 \x01(\x05R\flowFrequency\x12%\n" +
"\x0ehigh_frequency\x18\x04 \x01(\x05R\rhighFrequency\x12\x1a\n" +
"\bduration\x18\x05 \x01(\x05R\bduration\"\xde\x01\n" +
"\x13RTCIceCandidateInit\x12\x1c\n" +
"\tcandidate\x18\x01 \x01(\tR\tcandidate\x12)\n" +
"\rsdpMLineIndex\x18\x02 \x01(\rH\x00R\rsdpMLineIndex\x88\x01\x01\x12\x1b\n" +

View File

@@ -1,14 +1,16 @@
package shared
import (
"errors"
"fmt"
"io"
"log/slog"
"relay/internal/common"
"relay/internal/connections"
"sync"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/oklog/ulid/v2"
"github.com/pion/rtp"
"github.com/pion/webrtc/v4"
)
@@ -22,8 +24,15 @@ type Participant struct {
// Per-viewer tracks and channels
VideoTrack *webrtc.TrackLocalStaticRTP
AudioTrack *webrtc.TrackLocalStaticRTP
VideoChan chan *rtp.Packet
AudioChan chan *rtp.Packet
// Per-viewer RTP state for retiming
VideoSequenceNumber uint16
VideoTimestamp uint32
AudioSequenceNumber uint16
AudioTimestamp uint32
packetQueue chan *participantPacket
closeOnce sync.Once
}
func NewParticipant(sessionID string, peerID peer.ID) (*Participant, error) {
@@ -31,24 +40,50 @@ func NewParticipant(sessionID string, peerID peer.ID) (*Participant, error) {
if err != nil {
return nil, fmt.Errorf("failed to create ULID for Participant: %w", err)
}
return &Participant{
ID: id,
SessionID: sessionID,
PeerID: peerID,
VideoChan: make(chan *rtp.Packet, 500),
AudioChan: make(chan *rtp.Packet, 100),
}, nil
p := &Participant{
ID: id,
SessionID: sessionID,
PeerID: peerID,
VideoSequenceNumber: 0,
VideoTimestamp: 0,
AudioSequenceNumber: 0,
AudioTimestamp: 0,
packetQueue: make(chan *participantPacket, 1000),
}
go p.packetWriter()
return p, nil
}
// SetTrack sets audio/video track for Participant
func (p *Participant) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalStaticRTP) {
switch trackType {
case webrtc.RTPCodecTypeAudio:
p.AudioTrack = track
_, err := p.PeerConnection.AddTrack(track)
if err != nil {
slog.Error("Failed to add Participant audio track", err)
}
case webrtc.RTPCodecTypeVideo:
p.VideoTrack = track
_, err := p.PeerConnection.AddTrack(track)
if err != nil {
slog.Error("Failed to add Participant video track", err)
}
default:
slog.Warn("Unknown track type", "participant", p.ID, "trackType", trackType)
}
}
// Close cleans up participant resources
func (p *Participant) Close() {
if p.VideoChan != nil {
close(p.VideoChan)
p.VideoChan = nil
}
if p.AudioChan != nil {
close(p.AudioChan)
p.AudioChan = nil
if p.DataChannel != nil {
err := p.DataChannel.Close()
if err != nil {
slog.Error("Failed to close Participant DataChannel", err)
}
p.DataChannel = nil
}
if p.PeerConnection != nil {
err := p.PeerConnection.Close()
@@ -57,4 +92,45 @@ func (p *Participant) Close() {
}
p.PeerConnection = nil
}
if p.VideoTrack != nil {
p.VideoTrack = nil
}
if p.AudioTrack != nil {
p.AudioTrack = nil
}
}
func (p *Participant) packetWriter() {
for pkt := range p.packetQueue {
var track *webrtc.TrackLocalStaticRTP
var sequenceNumber uint16
var timestamp uint32
// No mutex needed - only this goroutine modifies these
if pkt.kind == webrtc.RTPCodecTypeAudio {
track = p.AudioTrack
p.AudioSequenceNumber = uint16(int(p.AudioSequenceNumber) + pkt.sequenceDiff)
p.AudioTimestamp = uint32(int64(p.AudioTimestamp) + pkt.timeDiff)
sequenceNumber = p.AudioSequenceNumber
timestamp = p.AudioTimestamp
} else {
track = p.VideoTrack
p.VideoSequenceNumber = uint16(int(p.VideoSequenceNumber) + pkt.sequenceDiff)
p.VideoTimestamp = uint32(int64(p.VideoTimestamp) + pkt.timeDiff)
sequenceNumber = p.VideoSequenceNumber
timestamp = p.VideoTimestamp
}
if track != nil {
pkt.packet.SequenceNumber = sequenceNumber
pkt.packet.Timestamp = timestamp
if err := track.WriteRTP(pkt.packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
slog.Error("WriteRTP failed", "participant", p.ID, "kind", pkt.kind, "err", err)
}
}
// Return packet struct to pool
participantPacketPool.Put(pkt)
}
}

View File

@@ -2,9 +2,9 @@ package shared
import (
"log/slog"
"relay/internal/common"
"relay/internal/connections"
"time"
"sync"
"sync/atomic"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/oklog/ulid/v2"
@@ -12,6 +12,19 @@ import (
"github.com/pion/webrtc/v4"
)
var participantPacketPool = sync.Pool{
New: func() interface{} {
return &participantPacket{}
},
}
type participantPacket struct {
kind webrtc.RTPCodecType
packet *rtp.Packet
timeDiff int64
sequenceDiff int
}
type RoomInfo struct {
ID ulid.ULID `json:"id"`
Name string `json:"name"`
@@ -20,16 +33,27 @@ type RoomInfo struct {
type Room struct {
RoomInfo
AudioCodec webrtc.RTPCodecCapability
VideoCodec webrtc.RTPCodecCapability
PeerConnection *webrtc.PeerConnection
AudioTrack *webrtc.TrackLocalStaticRTP
VideoTrack *webrtc.TrackLocalStaticRTP
DataChannel *connections.NestriDataChannel
Participants *common.SafeMap[ulid.ULID, *Participant]
// Broadcast queues (unbuffered, fan-out happens async)
videoBroadcastChan chan *rtp.Packet
audioBroadcastChan chan *rtp.Packet
broadcastStop chan struct{}
// Atomic pointer to slice of participant channels
participantChannels atomic.Pointer[[]chan<- *participantPacket]
participantsMtx sync.Mutex // Use only for add/remove
Participants map[ulid.ULID]*Participant // Keep general track of Participant(s)
// Track last seen values to calculate diffs
LastVideoTimestamp uint32
LastVideoSequenceNumber uint16
LastAudioTimestamp uint32
LastAudioSequenceNumber uint16
VideoTimestampSet bool
VideoSequenceSet bool
AudioTimestampSet bool
AudioSequenceSet bool
}
func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
@@ -39,133 +63,109 @@ func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
Name: name,
OwnerID: ownerID,
},
Participants: common.NewSafeMap[ulid.ULID, *Participant](),
videoBroadcastChan: make(chan *rtp.Packet, 1000), // Large buffer for incoming packets
audioBroadcastChan: make(chan *rtp.Packet, 500),
broadcastStop: make(chan struct{}),
PeerConnection: nil,
DataChannel: nil,
Participants: make(map[ulid.ULID]*Participant),
}
// Start async broadcasters
go r.videoBroadcaster()
go r.audioBroadcaster()
emptyChannels := make([]chan<- *participantPacket, 0)
r.participantChannels.Store(&emptyChannels)
return r
}
// Close closes up Room (stream ended)
func (r *Room) Close() {
if r.DataChannel != nil {
err := r.DataChannel.Close()
if err != nil {
slog.Error("Failed to close Room DataChannel", err)
}
r.DataChannel = nil
}
if r.PeerConnection != nil {
err := r.PeerConnection.Close()
if err != nil {
slog.Error("Failed to close Room PeerConnection", err)
}
r.PeerConnection = nil
}
}
// AddParticipant adds a Participant to a Room
func (r *Room) AddParticipant(participant *Participant) {
slog.Debug("Adding participant to room", "participant", participant.ID, "room", r.Name)
r.Participants.Set(participant.ID, participant)
r.participantsMtx.Lock()
defer r.participantsMtx.Unlock()
r.Participants[participant.ID] = participant
// Update channel slice atomically
current := r.participantChannels.Load()
newChannels := make([]chan<- *participantPacket, len(*current)+1)
copy(newChannels, *current)
newChannels[len(*current)] = participant.packetQueue
r.participantChannels.Store(&newChannels)
slog.Debug("Added participant", "participant", participant.ID, "room", r.Name)
}
// RemoveParticipantByID removes a Participant from a Room by participant's ID
func (r *Room) RemoveParticipantByID(pID ulid.ULID) {
if _, ok := r.Participants.Get(pID); ok {
r.Participants.Delete(pID)
r.participantsMtx.Lock()
defer r.participantsMtx.Unlock()
participant, ok := r.Participants[pID]
if !ok {
return
}
delete(r.Participants, pID)
// Update channel slice
current := r.participantChannels.Load()
newChannels := make([]chan<- *participantPacket, 0, len(*current)-1)
for _, ch := range *current {
if ch != participant.packetQueue {
newChannels = append(newChannels, ch)
}
}
r.participantChannels.Store(&newChannels)
slog.Debug("Removed participant", "participant", pID, "room", r.Name)
}
// IsOnline checks if the room is online (has both audio and video tracks)
// IsOnline checks if the room is online
func (r *Room) IsOnline() bool {
return r.AudioTrack != nil && r.VideoTrack != nil
return r.PeerConnection != nil
}
func (r *Room) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalStaticRTP) {
switch trackType {
case webrtc.RTPCodecTypeAudio:
r.AudioTrack = track
case webrtc.RTPCodecTypeVideo:
r.VideoTrack = track
default:
slog.Warn("Unknown track type", "room", r.Name, "trackType", trackType)
func (r *Room) BroadcastPacketRetimed(kind webrtc.RTPCodecType, pkt *rtp.Packet, timeDiff int64, sequenceDiff int) {
// Lock-free load of channel slice
channels := r.participantChannels.Load()
// no participants..
if len(*channels) == 0 {
return
}
}
// BroadcastPacket enqueues packet for async broadcast (non-blocking)
func (r *Room) BroadcastPacket(kind webrtc.RTPCodecType, pkt *rtp.Packet) {
start := time.Now()
if kind == webrtc.RTPCodecTypeVideo {
// Send to each participant channel (non-blocking)
for i, ch := range *channels {
// Get packet struct from pool
pp := participantPacketPool.Get().(*participantPacket)
pp.kind = kind
pp.packet = pkt.Clone()
pp.timeDiff = timeDiff
pp.sequenceDiff = sequenceDiff
select {
case r.videoBroadcastChan <- pkt:
duration := time.Since(start)
if duration > 10*time.Millisecond {
slog.Warn("Slow video broadcast enqueue", "duration", duration, "room", r.Name)
}
case ch <- pp:
// Sent successfully
default:
// Broadcast queue full - system overload, drop packet globally
slog.Warn("Video broadcast queue full, dropping packet", "room", r.Name)
}
} else {
select {
case r.audioBroadcastChan <- pkt:
duration := time.Since(start)
if duration > 10*time.Millisecond {
slog.Warn("Slow audio broadcast enqueue", "duration", duration, "room", r.Name)
}
default:
slog.Warn("Audio broadcast queue full, dropping packet", "room", r.Name)
}
}
}
// Close stops the broadcasters
func (r *Room) Close() {
close(r.broadcastStop)
close(r.videoBroadcastChan)
close(r.audioBroadcastChan)
}
// videoBroadcaster runs async fan-out for video packets
func (r *Room) videoBroadcaster() {
for {
select {
case pkt := <-r.videoBroadcastChan:
// Fan out to all participants without blocking
r.Participants.Range(func(_ ulid.ULID, participant *Participant) bool {
if participant.VideoChan != nil {
// Clone packet for each participant to avoid shared pointer issues
clonedPkt := pkt.Clone()
select {
case participant.VideoChan <- clonedPkt:
// Sent
default:
// Participant slow, drop packet
slog.Debug("Dropped video packet for slow participant",
"room", r.Name,
"participant", participant.ID)
}
}
return true
})
case <-r.broadcastStop:
return
}
}
}
// audioBroadcaster runs async fan-out for audio packets
func (r *Room) audioBroadcaster() {
for {
select {
case pkt := <-r.audioBroadcastChan:
r.Participants.Range(func(_ ulid.ULID, participant *Participant) bool {
if participant.AudioChan != nil {
// Clone packet for each participant to avoid shared pointer issues
clonedPkt := pkt.Clone()
select {
case participant.AudioChan <- clonedPkt:
// Sent
default:
// Participant slow, drop packet
slog.Debug("Dropped audio packet for slow participant",
"room", r.Name,
"participant", participant.ID)
}
}
return true
})
case <-r.broadcastStop:
return
// Channel full, drop packet, log?
slog.Warn("Channel full, dropping packet", "channel_index", i)
participantPacketPool.Put(pp)
}
}
}