mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
Fixed multi-controllers, optimize and improve code in relay and nestri-server
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
@@ -22,8 +24,15 @@ type Participant struct {
|
||||
// Per-viewer tracks and channels
|
||||
VideoTrack *webrtc.TrackLocalStaticRTP
|
||||
AudioTrack *webrtc.TrackLocalStaticRTP
|
||||
VideoChan chan *rtp.Packet
|
||||
AudioChan chan *rtp.Packet
|
||||
|
||||
// Per-viewer RTP state for retiming
|
||||
VideoSequenceNumber uint16
|
||||
VideoTimestamp uint32
|
||||
AudioSequenceNumber uint16
|
||||
AudioTimestamp uint32
|
||||
|
||||
packetQueue chan *participantPacket
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewParticipant(sessionID string, peerID peer.ID) (*Participant, error) {
|
||||
@@ -31,24 +40,50 @@ func NewParticipant(sessionID string, peerID peer.ID) (*Participant, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ULID for Participant: %w", err)
|
||||
}
|
||||
return &Participant{
|
||||
ID: id,
|
||||
SessionID: sessionID,
|
||||
PeerID: peerID,
|
||||
VideoChan: make(chan *rtp.Packet, 500),
|
||||
AudioChan: make(chan *rtp.Packet, 100),
|
||||
}, nil
|
||||
p := &Participant{
|
||||
ID: id,
|
||||
SessionID: sessionID,
|
||||
PeerID: peerID,
|
||||
VideoSequenceNumber: 0,
|
||||
VideoTimestamp: 0,
|
||||
AudioSequenceNumber: 0,
|
||||
AudioTimestamp: 0,
|
||||
packetQueue: make(chan *participantPacket, 1000),
|
||||
}
|
||||
|
||||
go p.packetWriter()
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// SetTrack sets audio/video track for Participant
|
||||
func (p *Participant) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalStaticRTP) {
|
||||
switch trackType {
|
||||
case webrtc.RTPCodecTypeAudio:
|
||||
p.AudioTrack = track
|
||||
_, err := p.PeerConnection.AddTrack(track)
|
||||
if err != nil {
|
||||
slog.Error("Failed to add Participant audio track", err)
|
||||
}
|
||||
case webrtc.RTPCodecTypeVideo:
|
||||
p.VideoTrack = track
|
||||
_, err := p.PeerConnection.AddTrack(track)
|
||||
if err != nil {
|
||||
slog.Error("Failed to add Participant video track", err)
|
||||
}
|
||||
default:
|
||||
slog.Warn("Unknown track type", "participant", p.ID, "trackType", trackType)
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up participant resources
|
||||
func (p *Participant) Close() {
|
||||
if p.VideoChan != nil {
|
||||
close(p.VideoChan)
|
||||
p.VideoChan = nil
|
||||
}
|
||||
if p.AudioChan != nil {
|
||||
close(p.AudioChan)
|
||||
p.AudioChan = nil
|
||||
if p.DataChannel != nil {
|
||||
err := p.DataChannel.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close Participant DataChannel", err)
|
||||
}
|
||||
p.DataChannel = nil
|
||||
}
|
||||
if p.PeerConnection != nil {
|
||||
err := p.PeerConnection.Close()
|
||||
@@ -57,4 +92,45 @@ func (p *Participant) Close() {
|
||||
}
|
||||
p.PeerConnection = nil
|
||||
}
|
||||
if p.VideoTrack != nil {
|
||||
p.VideoTrack = nil
|
||||
}
|
||||
if p.AudioTrack != nil {
|
||||
p.AudioTrack = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Participant) packetWriter() {
|
||||
for pkt := range p.packetQueue {
|
||||
var track *webrtc.TrackLocalStaticRTP
|
||||
var sequenceNumber uint16
|
||||
var timestamp uint32
|
||||
|
||||
// No mutex needed - only this goroutine modifies these
|
||||
if pkt.kind == webrtc.RTPCodecTypeAudio {
|
||||
track = p.AudioTrack
|
||||
p.AudioSequenceNumber = uint16(int(p.AudioSequenceNumber) + pkt.sequenceDiff)
|
||||
p.AudioTimestamp = uint32(int64(p.AudioTimestamp) + pkt.timeDiff)
|
||||
sequenceNumber = p.AudioSequenceNumber
|
||||
timestamp = p.AudioTimestamp
|
||||
} else {
|
||||
track = p.VideoTrack
|
||||
p.VideoSequenceNumber = uint16(int(p.VideoSequenceNumber) + pkt.sequenceDiff)
|
||||
p.VideoTimestamp = uint32(int64(p.VideoTimestamp) + pkt.timeDiff)
|
||||
sequenceNumber = p.VideoSequenceNumber
|
||||
timestamp = p.VideoTimestamp
|
||||
}
|
||||
|
||||
if track != nil {
|
||||
pkt.packet.SequenceNumber = sequenceNumber
|
||||
pkt.packet.Timestamp = timestamp
|
||||
|
||||
if err := track.WriteRTP(pkt.packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
slog.Error("WriteRTP failed", "participant", p.ID, "kind", pkt.kind, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Return packet struct to pool
|
||||
participantPacketPool.Put(pkt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package shared
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"relay/internal/common"
|
||||
"relay/internal/connections"
|
||||
"time"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/oklog/ulid/v2"
|
||||
@@ -12,6 +12,19 @@ import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
var participantPacketPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &participantPacket{}
|
||||
},
|
||||
}
|
||||
|
||||
type participantPacket struct {
|
||||
kind webrtc.RTPCodecType
|
||||
packet *rtp.Packet
|
||||
timeDiff int64
|
||||
sequenceDiff int
|
||||
}
|
||||
|
||||
type RoomInfo struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -20,16 +33,27 @@ type RoomInfo struct {
|
||||
|
||||
type Room struct {
|
||||
RoomInfo
|
||||
AudioCodec webrtc.RTPCodecCapability
|
||||
VideoCodec webrtc.RTPCodecCapability
|
||||
PeerConnection *webrtc.PeerConnection
|
||||
AudioTrack *webrtc.TrackLocalStaticRTP
|
||||
VideoTrack *webrtc.TrackLocalStaticRTP
|
||||
DataChannel *connections.NestriDataChannel
|
||||
Participants *common.SafeMap[ulid.ULID, *Participant]
|
||||
|
||||
// Broadcast queues (unbuffered, fan-out happens async)
|
||||
videoBroadcastChan chan *rtp.Packet
|
||||
audioBroadcastChan chan *rtp.Packet
|
||||
broadcastStop chan struct{}
|
||||
// Atomic pointer to slice of participant channels
|
||||
participantChannels atomic.Pointer[[]chan<- *participantPacket]
|
||||
participantsMtx sync.Mutex // Use only for add/remove
|
||||
|
||||
Participants map[ulid.ULID]*Participant // Keep general track of Participant(s)
|
||||
|
||||
// Track last seen values to calculate diffs
|
||||
LastVideoTimestamp uint32
|
||||
LastVideoSequenceNumber uint16
|
||||
LastAudioTimestamp uint32
|
||||
LastAudioSequenceNumber uint16
|
||||
|
||||
VideoTimestampSet bool
|
||||
VideoSequenceSet bool
|
||||
AudioTimestampSet bool
|
||||
AudioSequenceSet bool
|
||||
}
|
||||
|
||||
func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
|
||||
@@ -39,133 +63,109 @@ func NewRoom(name string, roomID ulid.ULID, ownerID peer.ID) *Room {
|
||||
Name: name,
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
Participants: common.NewSafeMap[ulid.ULID, *Participant](),
|
||||
videoBroadcastChan: make(chan *rtp.Packet, 1000), // Large buffer for incoming packets
|
||||
audioBroadcastChan: make(chan *rtp.Packet, 500),
|
||||
broadcastStop: make(chan struct{}),
|
||||
PeerConnection: nil,
|
||||
DataChannel: nil,
|
||||
Participants: make(map[ulid.ULID]*Participant),
|
||||
}
|
||||
|
||||
// Start async broadcasters
|
||||
go r.videoBroadcaster()
|
||||
go r.audioBroadcaster()
|
||||
emptyChannels := make([]chan<- *participantPacket, 0)
|
||||
r.participantChannels.Store(&emptyChannels)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Close closes up Room (stream ended)
|
||||
func (r *Room) Close() {
|
||||
if r.DataChannel != nil {
|
||||
err := r.DataChannel.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close Room DataChannel", err)
|
||||
}
|
||||
r.DataChannel = nil
|
||||
}
|
||||
if r.PeerConnection != nil {
|
||||
err := r.PeerConnection.Close()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close Room PeerConnection", err)
|
||||
}
|
||||
r.PeerConnection = nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddParticipant adds a Participant to a Room
|
||||
func (r *Room) AddParticipant(participant *Participant) {
|
||||
slog.Debug("Adding participant to room", "participant", participant.ID, "room", r.Name)
|
||||
r.Participants.Set(participant.ID, participant)
|
||||
r.participantsMtx.Lock()
|
||||
defer r.participantsMtx.Unlock()
|
||||
|
||||
r.Participants[participant.ID] = participant
|
||||
|
||||
// Update channel slice atomically
|
||||
current := r.participantChannels.Load()
|
||||
newChannels := make([]chan<- *participantPacket, len(*current)+1)
|
||||
copy(newChannels, *current)
|
||||
newChannels[len(*current)] = participant.packetQueue
|
||||
|
||||
r.participantChannels.Store(&newChannels)
|
||||
|
||||
slog.Debug("Added participant", "participant", participant.ID, "room", r.Name)
|
||||
}
|
||||
|
||||
// RemoveParticipantByID removes a Participant from a Room by participant's ID
|
||||
func (r *Room) RemoveParticipantByID(pID ulid.ULID) {
|
||||
if _, ok := r.Participants.Get(pID); ok {
|
||||
r.Participants.Delete(pID)
|
||||
r.participantsMtx.Lock()
|
||||
defer r.participantsMtx.Unlock()
|
||||
|
||||
participant, ok := r.Participants[pID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.Participants, pID)
|
||||
|
||||
// Update channel slice
|
||||
current := r.participantChannels.Load()
|
||||
newChannels := make([]chan<- *participantPacket, 0, len(*current)-1)
|
||||
for _, ch := range *current {
|
||||
if ch != participant.packetQueue {
|
||||
newChannels = append(newChannels, ch)
|
||||
}
|
||||
}
|
||||
|
||||
r.participantChannels.Store(&newChannels)
|
||||
|
||||
slog.Debug("Removed participant", "participant", pID, "room", r.Name)
|
||||
}
|
||||
|
||||
// IsOnline checks if the room is online (has both audio and video tracks)
|
||||
// IsOnline checks if the room is online
|
||||
func (r *Room) IsOnline() bool {
|
||||
return r.AudioTrack != nil && r.VideoTrack != nil
|
||||
return r.PeerConnection != nil
|
||||
}
|
||||
|
||||
func (r *Room) SetTrack(trackType webrtc.RTPCodecType, track *webrtc.TrackLocalStaticRTP) {
|
||||
switch trackType {
|
||||
case webrtc.RTPCodecTypeAudio:
|
||||
r.AudioTrack = track
|
||||
case webrtc.RTPCodecTypeVideo:
|
||||
r.VideoTrack = track
|
||||
default:
|
||||
slog.Warn("Unknown track type", "room", r.Name, "trackType", trackType)
|
||||
func (r *Room) BroadcastPacketRetimed(kind webrtc.RTPCodecType, pkt *rtp.Packet, timeDiff int64, sequenceDiff int) {
|
||||
// Lock-free load of channel slice
|
||||
channels := r.participantChannels.Load()
|
||||
|
||||
// no participants..
|
||||
if len(*channels) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastPacket enqueues packet for async broadcast (non-blocking)
|
||||
func (r *Room) BroadcastPacket(kind webrtc.RTPCodecType, pkt *rtp.Packet) {
|
||||
start := time.Now()
|
||||
if kind == webrtc.RTPCodecTypeVideo {
|
||||
// Send to each participant channel (non-blocking)
|
||||
for i, ch := range *channels {
|
||||
// Get packet struct from pool
|
||||
pp := participantPacketPool.Get().(*participantPacket)
|
||||
pp.kind = kind
|
||||
pp.packet = pkt.Clone()
|
||||
pp.timeDiff = timeDiff
|
||||
pp.sequenceDiff = sequenceDiff
|
||||
|
||||
select {
|
||||
case r.videoBroadcastChan <- pkt:
|
||||
duration := time.Since(start)
|
||||
if duration > 10*time.Millisecond {
|
||||
slog.Warn("Slow video broadcast enqueue", "duration", duration, "room", r.Name)
|
||||
}
|
||||
case ch <- pp:
|
||||
// Sent successfully
|
||||
default:
|
||||
// Broadcast queue full - system overload, drop packet globally
|
||||
slog.Warn("Video broadcast queue full, dropping packet", "room", r.Name)
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case r.audioBroadcastChan <- pkt:
|
||||
duration := time.Since(start)
|
||||
if duration > 10*time.Millisecond {
|
||||
slog.Warn("Slow audio broadcast enqueue", "duration", duration, "room", r.Name)
|
||||
}
|
||||
default:
|
||||
slog.Warn("Audio broadcast queue full, dropping packet", "room", r.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the broadcasters
|
||||
func (r *Room) Close() {
|
||||
close(r.broadcastStop)
|
||||
close(r.videoBroadcastChan)
|
||||
close(r.audioBroadcastChan)
|
||||
}
|
||||
|
||||
// videoBroadcaster runs async fan-out for video packets
|
||||
func (r *Room) videoBroadcaster() {
|
||||
for {
|
||||
select {
|
||||
case pkt := <-r.videoBroadcastChan:
|
||||
// Fan out to all participants without blocking
|
||||
r.Participants.Range(func(_ ulid.ULID, participant *Participant) bool {
|
||||
if participant.VideoChan != nil {
|
||||
// Clone packet for each participant to avoid shared pointer issues
|
||||
clonedPkt := pkt.Clone()
|
||||
select {
|
||||
case participant.VideoChan <- clonedPkt:
|
||||
// Sent
|
||||
default:
|
||||
// Participant slow, drop packet
|
||||
slog.Debug("Dropped video packet for slow participant",
|
||||
"room", r.Name,
|
||||
"participant", participant.ID)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-r.broadcastStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// audioBroadcaster runs async fan-out for audio packets
|
||||
func (r *Room) audioBroadcaster() {
|
||||
for {
|
||||
select {
|
||||
case pkt := <-r.audioBroadcastChan:
|
||||
r.Participants.Range(func(_ ulid.ULID, participant *Participant) bool {
|
||||
if participant.AudioChan != nil {
|
||||
// Clone packet for each participant to avoid shared pointer issues
|
||||
clonedPkt := pkt.Clone()
|
||||
select {
|
||||
case participant.AudioChan <- clonedPkt:
|
||||
// Sent
|
||||
default:
|
||||
// Participant slow, drop packet
|
||||
slog.Debug("Dropped audio packet for slow participant",
|
||||
"room", r.Name,
|
||||
"participant", participant.ID)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-r.broadcastStop:
|
||||
return
|
||||
// Channel full, drop packet, log?
|
||||
slog.Warn("Channel full, dropping packet", "channel_index", i)
|
||||
participantPacketPool.Put(pp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user