feat: Add streaming support (#125)

This adds:
- [x] Keyboard and mouse handling on the frontend
- [x] Video and audio streaming from the backend to the frontend
- [x] Input server that works with Websockets

Update - 17/11
- [ ] Master docker container to run this
- [ ] Steam runtime
- [ ] Entrypoint.sh

---------

Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Wanjohi
2024-12-08 14:54:56 +03:00
committed by GitHub
parent 5eb21eeadb
commit 379db1c87b
137 changed files with 12737 additions and 5234 deletions

View File

@@ -0,0 +1,16 @@
[package]
name = "dev"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "nestri-test-server"
path = "src/main.rs"
[dependencies]
webrtc = "0.11.0"
tokio = { version = "1.41.1", features = ["full"] }
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1.0.215", features = ["derive"]}
serde_json = "1.0.133"

View File

@@ -0,0 +1,11 @@
FROM archlinux:latest
RUN pacman -Syu --noconfirm
RUN pacman -Su --noconfirm \
gstreamer gst-plugins-base gst-plugins-good gst-plugin-rswebrtc
RUN pacman -Syu --noconfirm \
mesa mesa-utils xorg-xwayland vulkan-intel vpl-gpu-rt intel-media-driver gst-plugin-va gst-plugins-bad gst-plugin-fmp4 gst-plugin-qsv gst-plugin-pipewire
CMD [ "bash","-c", "gst-launch-1.0 videotestsrc ! openh264enc ! whip0. audiotestsrc ! opusenc ! whip0. whipclientsink name=whip0 signaller::whip-endpoint=http://localhost:8088/api/whip/test" ]

17
packages/relay/dev/server.sh Executable file
View File

@@ -0,0 +1,17 @@
#! /bin/bash -e
# sudo apt install build-essential -y
# To run tests, run the relay first - go run main.go.
# Run the docker container next - docker run --rm --init -d --device /dev/dri --network=host test-server
# Then run the nestri-test-server - cd packages/relay/dev cargo run
# Then run the frontend site, and navigate to http://localhost:5713/play/test
# Expected behavior, see some random messages on the browser's console tab
# And if you input works correctly, it should be logged to the console on the server-side of things
# docker build -t test-server -f Containerfile .
# docker run --rm --init -d --device /dev/dri --network=host test-server
# echo -e "Navigate to http://localhost:5713/play/test"

View File

@@ -0,0 +1,11 @@
mod room;
#[tokio::main]
async fn main() -> std::io::Result<()> {
let room = "test";
let base_url = "http://localhost:8088";
let mut room_handler = room::Room::new(room, base_url).await?;
room_handler.run().await?;
Ok(())
}

View File

@@ -0,0 +1,292 @@
use reqwest;
use std::collections::HashSet;
use std::io;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio::time::Duration;
use webrtc::api::interceptor_registry::register_default_interceptors;
// use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::math_rand_alpha;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
pub struct Room {
peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
data_channel: Arc<webrtc::data_channel::RTCDataChannel>,
done_tx: mpsc::Sender<()>,
done_rx: mpsc::Receiver<()>,
base_url: String,
stream_name: String,
// pipeline: Arc<Mutex<gst::Pipeline>>,
}
impl Room {
pub async fn new(
stream_name: &str,
base_url: &str,
// pipeline: Arc<Mutex<gst::Pipeline>>,
) -> io::Result<Self> {
// Create a MediaEngine object to configure the supported codec
let mut m = MediaEngine::default();
// Register default codecs
let _ = m.register_default_codecs().map_err(map_to_io_error)?;
let mut registry = Registry::new();
// Use the default set of Interceptors
registry = register_default_interceptors(registry, &mut m).map_err(map_to_io_error)?;
// Create the API object with the MediaEngine
let api = APIBuilder::new()
.with_media_engine(m)
.with_interceptor_registry(registry)
.build();
// Prepare the configuration
let config = RTCConfiguration {
ice_servers: vec![RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".to_owned()],
..Default::default()
}],
..Default::default()
};
// Create a new RTCPeerConnection
let peer_connection = Arc::new(
api.new_peer_connection(config)
.await
.map_err(map_to_io_error)?,
);
// Create a datachannel with label 'data'
let data_channel = peer_connection
.create_data_channel("input", None)
.await
.map_err(map_to_io_error)?;
let (done_tx, done_rx) = mpsc::channel::<()>(1);
let done_tx_clone = done_tx.clone();
// Peer connection state change handler
peer_connection.on_peer_connection_state_change(Box::new(
move |s: RTCPeerConnectionState| {
println!("Peer Connection State has changed: {s}");
if s == RTCPeerConnectionState::Failed {
println!("Peer Connection has gone to failed exiting");
let _ = done_tx_clone.try_send(());
}
Box::pin(async {})
},
));
Ok(Self {
peer_connection,
// pipeline,
data_channel,
done_tx,
done_rx,
base_url: base_url.to_string(),
stream_name: stream_name.to_string(),
})
}
pub async fn run(&mut self) -> io::Result<()> {
// Create an async channel for sending events to the pipeline
let (event_tx, mut event_rx) = mpsc::channel(10);
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(tokio::sync::Mutex::new(HashSet::new()));
// Spawn a task to process events for the pipeline
let pipeline_task = {
// let pipeline = Arc::clone(self.pipeline);
// let pipeline_clone = self.pipeline.clone();
tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
// let pipeline = pipeline_clone.lock().await;
// pipeline.send_event(event);
println!("Invoked an event: {}", event)
}
})
};
let data_channel = self.data_channel.clone();
//TODO: Handle heartbeats here
let d1 = Arc::clone(&self.data_channel);
data_channel.on_open(Box::new(move || {
println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d1.label(), d1.id());
let d2 = Arc::clone(&d1);
Box::pin(async move {
let mut result = std::io::Result::<usize>::Ok(0);
while result.is_ok() {
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
tokio::select! {
_ = timeout.as_mut() =>{
let message = math_rand_alpha(15);
println!("Sending '{message}'");
result = d2.send_text(message).await.map_err(map_to_io_error);
}
};
}
})
}));
// Data channel message handler
let d_label = data_channel.label().to_owned();
data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
let msg_str = String::from_utf8(msg.data.to_vec()).unwrap();
println!("Message from DataChannel '{d_label}': '{msg_str}'");
let event_tx = event_tx.clone();
let pressed_keys = Arc::clone(&pressed_keys);
tokio::spawn(async move {
if let Ok(input_msg) = from_str::<InputMessage>(&msg_str) {
if let Some(event) = handle_input_message(input_msg, &pressed_keys).await {
event_tx.send(event).await.unwrap();
}
}
});
Box::pin(async {})
}));
// Create an offer to send to the browser
let offer = self
.peer_connection
.create_offer(None)
.await
.map_err(map_to_io_error)?;
// Create channel that is blocked until ICE Gathering is complete
let mut gather_complete = self.peer_connection.gathering_complete_promise().await;
// Sets the LocalDescription, and starts our UDP listeners
self.peer_connection
.set_local_description(offer)
.await
.map_err(map_to_io_error)?;
// Block until ICE Gathering is complete, disabling trickle ICE
// we do this because we only can exchange one signaling message
// in a production application you should exchange ICE Candidates via OnICECandidate
let _ = gather_complete.recv().await;
if let Some(local_description) = self.peer_connection.local_description().await {
let url = format!("{}/api/whep/{}", self.base_url, self.stream_name);
let response = reqwest::Client::new()
.post(&url)
.header("Content-Type", "application/sdp")
.body(local_description.sdp.clone()) // clone if you don't want to move offer.sdp
.send()
.await
.map_err(map_to_io_error)?;
let answer = response
.json::<RTCSessionDescription>()
.await
.map_err(map_to_io_error)?;
self.peer_connection
.set_remote_description(answer)
.await
.map_err(map_to_io_error)?;
} else {
println!("generate local_description failed!");
};
println!("Press ctrl-c to stop");
tokio::select! {
_ = self.done_rx.recv() => {
println!("received done signal!");
}
_ = tokio::signal::ctrl_c() => {
println!();
}
};
self.peer_connection
.close()
.await
.map_err(map_to_io_error)?;
//FIXME: Ctr + C is not working... i suspect it has something to do with this guy -- Do not forget to fix packages/server/room.rs as well
pipeline_task.await?;
Ok(())
}
}
fn map_to_io_error<E: std::fmt::Display>(e: E) -> io::Error {
io::Error::new(io::ErrorKind::Other, format!("{}", e))
}
async fn handle_input_message(
input_msg: InputMessage,
pressed_keys: &Arc<tokio::sync::Mutex<HashSet<i32>>>,
) -> Option<String> {
match input_msg {
InputMessage::MouseMove { x, y } => Some("MouseMoved".to_string()),
InputMessage::MouseMoveAbs { x, y } => Some("MouseMoveAbsolute".to_string()),
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().await;
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
}
keys.insert(key);
Some("KeyDown".to_string())
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().await;
// Remove the key from the pressed state when released
keys.remove(&key);
Some("KeyUp".to_string())
}
InputMessage::Wheel { x, y } => Some("Wheel".to_string()),
InputMessage::MouseDown { key } => Some("MouseDown".to_string()),
InputMessage::MouseUp { key } => Some("MouseUp".to_string()),
}
}

31
packages/relay/go.mod Normal file
View File

@@ -0,0 +1,31 @@
module relay
go 1.23
require (
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/pion/interceptor v0.1.37
github.com/pion/webrtc/v4 v4.0.2
)
require (
github.com/pion/datachannel v1.5.9 // indirect
github.com/pion/dtls/v3 v3.0.4 // indirect
github.com/pion/ice/v4 v4.0.2 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.14 // indirect
github.com/pion/rtp v1.8.9 // indirect
github.com/pion/sctp v1.8.34 // indirect
github.com/pion/sdp/v3 v3.0.9 // indirect
github.com/pion/srtp/v3 v3.0.4 // indirect
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/net v0.31.0 // indirect
golang.org/x/sys v0.27.0 // indirect
)

62
packages/relay/go.sum Normal file
View File

@@ -0,0 +1,62 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA=
github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE=
github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U=
github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg=
github.com/pion/ice/v4 v4.0.2 h1:1JhBRX8iQLi0+TfcavTjPjI6GO41MFn4CeTBX+Y9h5s=
github.com/pion/ice/v4 v4.0.2/go.mod h1:DCdqyzgtsDNYN6/3U8044j3U7qsJ9KFJC92VnOWHvXg=
github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI=
github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE=
github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4=
github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk=
github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU=
github.com/pion/sctp v1.8.34 h1:rCuD3m53i0oGxCSp7FLQKvqVx0Nf5AUAHhMRXTTQjBc=
github.com/pion/sctp v1.8.34/go.mod h1:yWkCClkXlzVW7BXfI2PjrUGBwUI0CjXJBkhLt+sdo4U=
github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY=
github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M=
github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M=
github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
github.com/pion/webrtc/v4 v4.0.2 h1:fBwm5/hqSUybrCWl0DDBSTDrpbkcgkqpeLmXw9CsBQA=
github.com/pion/webrtc/v4 v4.0.2/go.mod h1:moylBT2A4dNoEaYBCdV1nThM3TLwRHzWszIG+eSPaqQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,100 @@
package relay
import (
"github.com/pion/interceptor"
"github.com/pion/webrtc/v4"
"log"
)
var globalWebRTCAPI *webrtc.API
var globalWebRTCConfig = webrtc.Configuration{
ICETransportPolicy: webrtc.ICETransportPolicyAll,
BundlePolicy: webrtc.BundlePolicyBalanced,
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
}
func InitWebRTCAPI() error {
var err error
flags := GetFlags()
// Media engine
mediaEngine := &webrtc.MediaEngine{}
// Default codecs cover most of our needs
err = mediaEngine.RegisterDefaultCodecs()
if err != nil {
return err
}
// Add H.265 for special cases
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
for _, codec := range []webrtc.RTPCodecParameters{
{
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH265, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback},
PayloadType: 48,
},
{
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeRTX, ClockRate: 90000, SDPFmtpLine: "apt=48"},
PayloadType: 49,
},
} {
if err := mediaEngine.RegisterCodec(codec, webrtc.RTPCodecTypeVideo); err != nil {
return err
}
}
// Interceptor registry
interceptorRegistry := &interceptor.Registry{}
// Use default set
err = webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry)
if err != nil {
return err
}
// Setting engine
settingEngine := webrtc.SettingEngine{}
// New in v4, reduces CPU usage and latency when enabled
settingEngine.EnableSCTPZeroChecksum(true)
// Set the UDP port range used by WebRTC
err = settingEngine.SetEphemeralUDPPortRange(uint16(flags.WebRTCUDPStart), uint16(flags.WebRTCUDPEnd))
if err != nil {
return err
}
// Create a new API object with our customized settings
globalWebRTCAPI = webrtc.NewAPI(webrtc.WithMediaEngine(mediaEngine), webrtc.WithSettingEngine(settingEngine), webrtc.WithInterceptorRegistry(interceptorRegistry))
return nil
}
// GetWebRTCAPI returns the global WebRTC API
func GetWebRTCAPI() *webrtc.API {
return globalWebRTCAPI
}
// CreatePeerConnection sets up a new peer connection
func CreatePeerConnection(onClose func()) (*webrtc.PeerConnection, error) {
pc, err := globalWebRTCAPI.NewPeerConnection(globalWebRTCConfig)
if err != nil {
return nil, err
}
// Log connection state changes and handle failed/disconnected connections
pc.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
// Close PeerConnection in cases
if connectionState == webrtc.PeerConnectionStateFailed ||
connectionState == webrtc.PeerConnectionStateDisconnected ||
connectionState == webrtc.PeerConnectionStateClosed {
err := pc.Close()
if err != nil {
log.Printf("Error closing PeerConnection: %s\n", err.Error())
}
onClose()
}
})
return pc, nil
}

View File

@@ -0,0 +1,72 @@
package relay
import (
"github.com/pion/webrtc/v4"
"log"
)
// NestriDataChannel is a custom data channel with callbacks
type NestriDataChannel struct {
*webrtc.DataChannel
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
}
// NewNestriDataChannel creates a new NestriDataChannel from *webrtc.DataChannel
func NewNestriDataChannel(dc *webrtc.DataChannel) *NestriDataChannel {
ndc := &NestriDataChannel{
DataChannel: dc,
binaryCallbacks: make(map[string]OnMessageCallback),
}
// Handler for incoming messages
ndc.OnMessage(func(msg webrtc.DataChannelMessage) {
// If string type message, ignore
if msg.IsString {
return
}
// Decode message
var base MessageBase
if err := DecodeMessage(msg.Data, &base); err != nil {
log.Printf("Failed to decode binary DataChannel message, reason: %s\n", err)
return
}
// Handle message type callback
if callback, ok := ndc.binaryCallbacks[base.PayloadType]; ok {
go callback(msg.Data)
} // TODO: Log unknown message type?
})
return ndc
}
// SendBinary sends a binary message to the data channel
func (ndc *NestriDataChannel) SendBinary(data []byte) error {
return ndc.Send(data)
}
// RegisterMessageCallback registers a callback for a given binary message type
func (ndc *NestriDataChannel) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
if ndc.binaryCallbacks == nil {
ndc.binaryCallbacks = make(map[string]OnMessageCallback)
}
ndc.binaryCallbacks[msgType] = callback
}
// UnregisterMessageCallback removes the callback for a given binary message type
func (ndc *NestriDataChannel) UnregisterMessageCallback(msgType string) {
if ndc.binaryCallbacks != nil {
delete(ndc.binaryCallbacks, msgType)
}
}
// RegisterOnOpen registers a callback for the data channel opening
func (ndc *NestriDataChannel) RegisterOnOpen(callback func()) {
ndc.OnOpen(callback)
}
// RegisterOnClose registers a callback for the data channel closing
func (ndc *NestriDataChannel) RegisterOnClose(callback func()) {
ndc.OnClose(callback)
}

View File

@@ -0,0 +1,189 @@
package relay
import (
"github.com/pion/webrtc/v4"
"log"
)
func participantHandler(participant *Participant, room *Room) {
// Callback for closing PeerConnection
onPCClose := func() {
if GetFlags().Verbose {
log.Printf("Closed PeerConnection for participant: '%s'\n", participant.ID)
}
room.removeParticipantByID(participant.ID)
}
var err error
participant.PeerConnection, err = CreatePeerConnection(onPCClose)
if err != nil {
log.Printf("Failed to create PeerConnection for participant: '%s' - reason: %s\n", participant.ID, err)
return
}
// Data channel settings
settingOrdered := false
settingMaxRetransmits := uint16(0)
dc, err := participant.PeerConnection.CreateDataChannel("data", &webrtc.DataChannelInit{
Ordered: &settingOrdered,
MaxRetransmits: &settingMaxRetransmits,
})
if err != nil {
log.Printf("Failed to create data channel for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
participant.DataChannel = NewNestriDataChannel(dc)
// Register channel opening handling
participant.DataChannel.RegisterOnOpen(func() {
if GetFlags().Verbose {
log.Printf("DataChannel open for participant: %s\n", participant.ID)
}
})
// Register channel closing handling
participant.DataChannel.RegisterOnClose(func() {
if GetFlags().Verbose {
log.Printf("DataChannel closed for participant: %s\n", participant.ID)
}
})
// Register text message handling
participant.DataChannel.RegisterMessageCallback("input", func(data []byte) {
// Send to room if it has a DataChannel
if room.DataChannel != nil {
// If debug mode, decode and add our timestamp, otherwise just send to room
if GetFlags().Debug {
var inputMsg MessageInput
if err = DecodeMessage(data, &inputMsg); err != nil {
log.Printf("Failed to decode input message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
inputMsg.LatencyTracker.AddTimestamp("relay_to_node")
// Encode and send
if data, err = EncodeMessage(inputMsg); err != nil {
log.Printf("Failed to encode input message for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
if err = room.DataChannel.SendBinary(data); err != nil {
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
}
} else {
if err = room.DataChannel.SendBinary(data); err != nil {
log.Printf("Failed to send input message to room: '%s' - reason: %s\n", room.Name, err)
}
}
}
})
participant.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
return
}
if GetFlags().Verbose {
log.Printf("ICE candidate for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
}
err = participant.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
if err != nil {
log.Printf("Failed to send ICE candidate for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
})
iceHolder := make([]webrtc.ICECandidateInit, 0)
// ICE callback
participant.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
var iceMsg MessageICECandidate
if err = DecodeMessage(data, &iceMsg); err != nil {
log.Printf("Failed to decode ICE message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
candidate := webrtc.ICECandidateInit{
Candidate: iceMsg.Candidate.Candidate,
}
if participant.PeerConnection.RemoteDescription() != nil {
if err = participant.PeerConnection.AddICECandidate(candidate); err != nil {
log.Printf("Failed to add ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
// Add held ICE candidates
for _, heldCandidate := range iceHolder {
if err = participant.PeerConnection.AddICECandidate(heldCandidate); err != nil {
log.Printf("Failed to add held ICE candidate from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
iceHolder = nil
} else {
iceHolder = append(iceHolder, candidate)
}
})
// SDP answer callback
participant.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
var sdpMsg MessageSDP
if err = DecodeMessage(data, &sdpMsg); err != nil {
log.Printf("Failed to decode SDP message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
handleParticipantSDP(participant, sdpMsg)
})
// Log callback
participant.WebSocket.RegisterMessageCallback("log", func(data []byte) {
var logMsg MessageLog
if err = DecodeMessage(data, &logMsg); err != nil {
log.Printf("Failed to decode log message from participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
return
}
// TODO: Handle log message sending to metrics server
})
// Metrics callback
participant.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
// Ignore for now
})
participant.WebSocket.RegisterOnClose(func() {
if GetFlags().Verbose {
log.Printf("WebSocket closed for participant: '%s' in room: '%s'\n", participant.ID, room.Name)
}
// Remove from Room
room.removeParticipantByID(participant.ID)
})
log.Printf("Participant: '%s' in room: '%s' is now ready, sending an OK\n", participant.ID, room.Name)
if err = participant.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
log.Printf("Failed to send OK answer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
// If room is already online, send also offer
if room.Online {
if room.AudioTrack != nil {
if err = participant.addTrack(&room.AudioTrack); err != nil {
log.Printf("Failed to add audio track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
if room.VideoTrack != nil {
if err = participant.addTrack(&room.VideoTrack); err != nil {
log.Printf("Failed to add video track for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
if err = participant.signalOffer(); err != nil {
log.Printf("Failed to signal offer for participant: '%s' in room: '%s' - reason: %s\n", participant.ID, room.Name, err)
}
}
}
// SDP answer handler for participants
func handleParticipantSDP(participant *Participant, answerMsg MessageSDP) {
// Get SDP offer
sdpAnswer := answerMsg.SDP.SDP
// Set remote description
err := participant.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer,
SDP: sdpAnswer,
})
if err != nil {
log.Printf("Failed to set remote description for participant: '%s' - reason: %s\n", participant.ID, err)
}
}

View File

@@ -0,0 +1,82 @@
package relay
import (
"flag"
"log"
"os"
"strconv"
"github.com/pion/webrtc/v4"
)
var globalFlags *Flags
type Flags struct {
Verbose bool
Debug bool
EndpointPort int
WebRTCUDPStart int
WebRTCUDPEnd int
STUNServer string
}
func (flags *Flags) DebugLog() {
log.Println("Relay Flags:")
log.Println("> Verbose: ", flags.Verbose)
log.Println("> Debug: ", flags.Debug)
log.Println("> Endpoint Port: ", flags.EndpointPort)
log.Println("> WebRTC UDP Range Start: ", flags.WebRTCUDPStart)
log.Println("> WebRTC UDP Range End: ", flags.WebRTCUDPEnd)
log.Println("> WebRTC STUN Server: ", flags.STUNServer)
}
func getEnvAsInt(name string, defaultVal int) int {
valueStr := os.Getenv(name)
if value, err := strconv.Atoi(valueStr); err != nil {
return defaultVal
} else {
return value
}
}
func getEnvAsBool(name string, defaultVal bool) bool {
valueStr := os.Getenv(name)
val, err := strconv.ParseBool(valueStr)
if err != nil {
return defaultVal
}
return val
}
func getEnvAsString(name string, defaultVal string) string {
valueStr := os.Getenv(name)
if len(valueStr) == 0 {
return defaultVal
}
return valueStr
}
func InitFlags() {
// Create Flags struct
globalFlags = &Flags{}
// Get flags
flag.BoolVar(&globalFlags.Verbose, "verbose", getEnvAsBool("VERBOSE", false), "Verbose mode")
flag.BoolVar(&globalFlags.Debug, "debug", getEnvAsBool("DEBUG", false), "Debug mode")
flag.IntVar(&globalFlags.EndpointPort, "endpointPort", getEnvAsInt("ENDPOINT_PORT", 8088), "HTTP endpoint port")
flag.IntVar(&globalFlags.WebRTCUDPStart, "webrtcUDPStart", getEnvAsInt("WEBRTC_UDP_START", 10000), "WebRTC UDP port range start")
flag.IntVar(&globalFlags.WebRTCUDPEnd, "webrtcUDPEnd", getEnvAsInt("WEBRTC_UDP_END", 20000), "WebRTC UDP port range end")
flag.StringVar(&globalFlags.STUNServer, "stunServer", getEnvAsString("STUN_SERVER", "stun.l.google.com:19302"), "WebRTC STUN server")
// Parse flags
flag.Parse()
// ICE STUN servers
globalWebRTCConfig.ICEServers = []webrtc.ICEServer{
{
URLs: []string{"stun:" + globalFlags.STUNServer},
},
}
}
func GetFlags() *Flags {
return globalFlags
}

View File

@@ -0,0 +1,123 @@
package relay
import (
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
)
var httpMux *http.ServeMux
func InitHTTPEndpoint() {
// Create HTTP mux which serves our WS endpoint
httpMux = http.NewServeMux()
// Endpoints themselves
httpMux.Handle("/", http.NotFoundHandler())
httpMux.HandleFunc("/api/ws/{roomName}", corsAnyHandler(wsHandler))
// Get our serving port
port := GetFlags().EndpointPort
// Log and start the endpoint server
log.Println("Starting HTTP endpoint server on :", strconv.Itoa(port))
go func() {
log.Fatal((&http.Server{
Handler: httpMux,
Addr: ":" + strconv.Itoa(port),
}).ListenAndServe())
}()
}
// logHTTPError logs (if verbose) and sends an error code to requester
func logHTTPError(w http.ResponseWriter, err string, code int) {
if GetFlags().Verbose {
log.Println(err)
}
http.Error(w, err, code)
}
// corsAnyHandler allows any origin to access the endpoint
func corsAnyHandler(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(res http.ResponseWriter, req *http.Request) {
// Allow all origins
res.Header().Set("Access-Control-Allow-Origin", "*")
res.Header().Set("Access-Control-Allow-Methods", "*")
res.Header().Set("Access-Control-Allow-Headers", "*")
if req.Method != http.MethodOptions {
next(res, req)
}
}
}
// wsHandler is the handler for the /api/ws/{roomName} endpoint
func wsHandler(w http.ResponseWriter, r *http.Request) {
// Get given room name now
roomName := r.PathValue("roomName")
if len(roomName) <= 0 {
logHTTPError(w, "no room name given", http.StatusBadRequest)
return
}
// Get or create room in any case
room := GetOrCreateRoom(roomName)
// Upgrade to WebSocket
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
wsConn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logHTTPError(w, err.Error(), http.StatusInternalServerError)
return
}
// Create SafeWebSocket
ws := NewSafeWebSocket(wsConn)
// Assign message handler for join request
ws.RegisterMessageCallback("join", func(data []byte) {
var joinMsg MessageJoin
if err = DecodeMessage(data, &joinMsg); err != nil {
log.Printf("Failed to decode join message: %s\n", err)
return
}
if GetFlags().Verbose {
log.Printf("Join request for room: '%s' from: '%s'\n", room.Name, joinMsg.JoinerType.String())
}
// Handle join request, depending if it's from ingest/node or participant/client
switch joinMsg.JoinerType {
case JoinerNode:
// If room already online, send InUse answer
if room.Online {
if err = ws.SendAnswerMessageWS(AnswerInUse); err != nil {
log.Printf("Failed to send InUse answer for Room: '%s' - reason: %s\n", room.Name, err)
}
return
}
room.assignWebSocket(ws)
go ingestHandler(room)
case JoinerClient:
// Create participant and add to room regardless of online status
participant := NewParticipant(ws)
room.addParticipant(participant)
// If room not online, send Offline answer
if !room.Online {
if err = ws.SendAnswerMessageWS(AnswerOffline); err != nil {
log.Printf("Failed to send Offline answer for Room: '%s' - reason: %s\n", room.Name, err)
}
}
go participantHandler(participant, room)
default:
log.Printf("Unknown joiner type: %d\n", joinMsg.JoinerType)
}
// Unregister ourselves, if something happens on the other side they should just reconnect?
ws.UnregisterMessageCallback("join")
})
}

View File

@@ -0,0 +1,251 @@
package relay
import (
"errors"
"fmt"
"github.com/pion/webrtc/v4"
"io"
"log"
"strings"
)
func ingestHandler(room *Room) {
// Callback for closing PeerConnection
onPCClose := func() {
if GetFlags().Verbose {
log.Printf("Closed PeerConnection for room: '%s'\n", room.Name)
}
room.Online = false
DeleteRoomIfEmpty(room)
}
var err error
room.PeerConnection, err = CreatePeerConnection(onPCClose)
if err != nil {
log.Printf("Failed to create PeerConnection for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.PeerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
var localTrack *webrtc.TrackLocalStaticRTP
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
if GetFlags().Verbose {
log.Printf("Received video track for room: '%s'\n", room.Name)
}
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "video", fmt.Sprint("nestri-", room.Name))
if err != nil {
log.Printf("Failed to create local video track for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.VideoTrack = localTrack
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
if GetFlags().Verbose {
log.Printf("Received audio track for room: '%s'\n", room.Name)
}
localTrack, err = webrtc.NewTrackLocalStaticRTP(remoteTrack.Codec().RTPCodecCapability, "audio", fmt.Sprint("nestri-", room.Name))
if err != nil {
log.Printf("Failed to create local audio track for room: '%s' - reason: %s\n", room.Name, err)
return
}
room.AudioTrack = localTrack
}
// If both audio and video tracks are set, set online state
if room.AudioTrack != nil && room.VideoTrack != nil {
room.Online = true
if GetFlags().Verbose {
log.Printf("Room online and receiving: '%s' - signaling participants\n", room.Name)
}
room.signalParticipantsWithTracks()
}
rtpBuffer := make([]byte, 1400)
for {
read, _, err := remoteTrack.Read(rtpBuffer)
if err != nil {
// EOF is expected when stopping room
if !errors.Is(err, io.EOF) {
log.Printf("RTP read error from room: '%s' - reason: %s\n", room.Name, err)
}
break
}
_, err = localTrack.Write(rtpBuffer[:read])
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
log.Printf("Failed to write RTP to local track for room: '%s' - reason: %s\n", room.Name, err)
break
}
}
if remoteTrack.Kind() == webrtc.RTPCodecTypeVideo {
room.VideoTrack = nil
} else if remoteTrack.Kind() == webrtc.RTPCodecTypeAudio {
room.AudioTrack = nil
}
if room.VideoTrack == nil && room.AudioTrack == nil {
room.Online = false
if GetFlags().Verbose {
log.Printf("Room offline and not receiving: '%s'\n", room.Name)
}
// Signal participants of room offline
room.signalParticipantsOffline()
DeleteRoomIfEmpty(room)
}
})
room.PeerConnection.OnDataChannel(func(dc *webrtc.DataChannel) {
room.DataChannel = NewNestriDataChannel(dc)
if GetFlags().Verbose {
log.Printf("New DataChannel for room: '%s' - '%s'\n", room.Name, room.DataChannel.Label())
}
// Register channel opening handling
room.DataChannel.RegisterOnOpen(func() {
if GetFlags().Verbose {
log.Printf("DataChannel for room: '%s' - '%s' open\n", room.Name, room.DataChannel.Label())
}
})
room.DataChannel.OnClose(func() {
if GetFlags().Verbose {
log.Printf("DataChannel for room: '%s' - '%s' closed\n", room.Name, room.DataChannel.Label())
}
})
// We do not handle any messages from ingest via DataChannel yet
})
room.PeerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
return
}
if GetFlags().Verbose {
log.Printf("ICE candidate for room: '%s'\n", room.Name)
}
err = room.WebSocket.SendICECandidateMessageWS(candidate.ToJSON())
if err != nil {
log.Printf("Failed to send ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
})
iceHolder := make([]webrtc.ICECandidateInit, 0)
// ICE callback
room.WebSocket.RegisterMessageCallback("ice", func(data []byte) {
var iceMsg MessageICECandidate
if err = DecodeMessage(data, &iceMsg); err != nil {
log.Printf("Failed to decode ICE candidate message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
candidate := webrtc.ICECandidateInit{
Candidate: iceMsg.Candidate.Candidate,
}
if room.PeerConnection != nil {
// If remote isn't set yet, store ICE candidates
if room.PeerConnection.RemoteDescription() != nil {
if err = room.PeerConnection.AddICECandidate(candidate); err != nil {
log.Printf("Failed to add ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
// Add any held ICE candidates
for _, heldCandidate := range iceHolder {
if err = room.PeerConnection.AddICECandidate(heldCandidate); err != nil {
log.Printf("Failed to add held ICE candidate for room: '%s' - reason: %s\n", room.Name, err)
}
}
iceHolder = nil
} else {
iceHolder = append(iceHolder, candidate)
}
} else {
log.Printf("ICE candidate received before PeerConnection for room: '%s'\n", room.Name)
}
})
// SDP offer callback
room.WebSocket.RegisterMessageCallback("sdp", func(data []byte) {
var sdpMsg MessageSDP
if err = DecodeMessage(data, &sdpMsg); err != nil {
log.Printf("Failed to decode SDP message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
answer := handleIngestSDP(room, sdpMsg)
if answer != nil {
if err = room.WebSocket.SendSDPMessageWS(*answer); err != nil {
log.Printf("Failed to send SDP answer to ingest for room: '%s' - reason: %s\n", room.Name, err)
}
} else {
log.Printf("Failed to handle SDP message from ingest for room: '%s'\n", room.Name)
}
})
// Log callback
room.WebSocket.RegisterMessageCallback("log", func(data []byte) {
var logMsg MessageLog
if err = DecodeMessage(data, &logMsg); err != nil {
log.Printf("Failed to decode log message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
// TODO: Handle log message sending to metrics server
})
// Metrics callback
room.WebSocket.RegisterMessageCallback("metrics", func(data []byte) {
var metricsMsg MessageMetrics
if err = DecodeMessage(data, &metricsMsg); err != nil {
log.Printf("Failed to decode metrics message from ingest for room: '%s' - reason: %s\n", room.Name, err)
return
}
// TODO: Handle metrics message sending to metrics server
})
room.WebSocket.RegisterOnClose(func() {
// If PeerConnection is not open or does not exist, delete room
if (room.PeerConnection != nil && room.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateConnected) ||
room.PeerConnection == nil {
DeleteRoomIfEmpty(room)
}
})
log.Printf("Room: '%s' is ready, sending an OK\n", room.Name)
if err = room.WebSocket.SendAnswerMessageWS(AnswerOK); err != nil {
log.Printf("Failed to send OK answer for room: '%s' - reason: %s\n", room.Name, err)
}
}
// SDP offer handler, returns SDP answer
func handleIngestSDP(room *Room, offerMsg MessageSDP) *webrtc.SessionDescription {
var err error
// Get SDP offer
sdpOffer := offerMsg.SDP.SDP
// Modify SDP offer to remove opus "sprop-maxcapturerate=24000" (fixes opus bad quality issue, present in GStreamer)
sdpOffer = strings.Replace(sdpOffer, ";sprop-maxcapturerate=24000", "", -1)
// Set new remote description
err = room.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeOffer,
SDP: sdpOffer,
})
if err != nil {
log.Printf("Failed to set remote description for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
// Create SDP answer
answer, err := room.PeerConnection.CreateAnswer(nil)
if err != nil {
log.Printf("Failed to create SDP answer for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
// Set local description
err = room.PeerConnection.SetLocalDescription(answer)
if err != nil {
log.Printf("Failed to set local description for room: '%s' - reason: %s\n", room.Name, err)
return nil
}
return &answer
}

View File

@@ -0,0 +1,114 @@
package relay
import (
"fmt"
"time"
)
type TimestampEntry struct {
Stage string `json:"stage"`
Time string `json:"time"` // ISO 8601 string
}
// LatencyTracker provides a generic structure for measuring time taken at various stages in message processing.
// It can be embedded in message structs for tracking the flow of data and calculating round-trip latency.
type LatencyTracker struct {
SequenceID string `json:"sequence_id"`
Timestamps []TimestampEntry `json:"timestamps"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// NewLatencyTracker initializes a new LatencyTracker with the given sequence ID
func NewLatencyTracker(sequenceID string) *LatencyTracker {
return &LatencyTracker{
SequenceID: sequenceID,
Timestamps: make([]TimestampEntry, 0),
Metadata: make(map[string]string),
}
}
// AddTimestamp adds a new timestamp for a specific stage
func (lt *LatencyTracker) AddTimestamp(stage string) {
lt.Timestamps = append(lt.Timestamps, TimestampEntry{
Stage: stage,
// Ensure extremely precise UTC RFC3339 timestamps (down to nanoseconds)
Time: time.Now().UTC().Format(time.RFC3339Nano),
})
}
// TotalLatency calculates the total latency from the earliest to the latest timestamp
func (lt *LatencyTracker) TotalLatency() (int64, error) {
if len(lt.Timestamps) < 2 {
return 0, nil // Not enough timestamps to calculate latency
}
var earliest, latest time.Time
for _, ts := range lt.Timestamps {
t, err := time.Parse(time.RFC3339, ts.Time)
if err != nil {
return 0, err
}
if earliest.IsZero() || t.Before(earliest) {
earliest = t
}
if latest.IsZero() || t.After(latest) {
latest = t
}
}
return latest.Sub(earliest).Milliseconds(), nil
}
// PainPoints returns a list of stages where the duration exceeds the given threshold.
func (lt *LatencyTracker) PainPoints(threshold time.Duration) []string {
var painPoints []string
var lastStage string
var lastTime time.Time
for _, ts := range lt.Timestamps {
stage := ts.Stage
t := ts.Time
if lastStage == "" {
lastStage = stage
lastTime, _ = time.Parse(time.RFC3339, t)
continue
}
currentTime, _ := time.Parse(time.RFC3339, t)
if currentTime.Sub(lastTime) > threshold {
painPoints = append(painPoints, fmt.Sprintf("%s -> %s", lastStage, stage))
}
lastStage = stage
lastTime = currentTime
}
return painPoints
}
// StageLatency calculates the time taken between two specific stages.
func (lt *LatencyTracker) StageLatency(startStage, endStage string) (time.Duration, error) {
startTime, endTime := "", ""
for _, ts := range lt.Timestamps {
if ts.Stage == startStage {
startTime = ts.Time
}
if ts.Stage == endStage {
endTime = ts.Time
}
}
if startTime == "" || endTime == "" {
return 0, fmt.Errorf("missing timestamps for stages: %s -> %s", startStage, endStage)
}
start, err := time.Parse(time.RFC3339, startTime)
if err != nil {
return 0, err
}
end, err := time.Parse(time.RFC3339, endTime)
if err != nil {
return 0, err
}
return end.Sub(start), nil
}

View File

@@ -0,0 +1,227 @@
package relay
import (
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"github.com/pion/webrtc/v4"
"time"
)
// OnMessageCallback is a callback for binary messages of given type
type OnMessageCallback func(data []byte)
// MessageBase is the base type for WS/DC messages.
type MessageBase struct {
PayloadType string `json:"payload_type"`
LatencyTracker LatencyTracker `json:"latency_tracker,omitempty"`
}
// MessageInput represents an input message.
type MessageInput struct {
MessageBase
Data string `json:"data"`
}
// MessageLog represents a log message.
type MessageLog struct {
MessageBase
Level string `json:"level"`
Message string `json:"message"`
Time string `json:"time"`
}
// MessageMetrics represents a metrics/heartbeat message.
type MessageMetrics struct {
MessageBase
UsageCPU float64 `json:"usage_cpu"`
UsageMemory float64 `json:"usage_memory"`
Uptime uint64 `json:"uptime"`
PipelineLatency float64 `json:"pipeline_latency"`
}
// MessageICECandidate represents an ICE candidate message.
type MessageICECandidate struct {
MessageBase
Candidate webrtc.ICECandidateInit `json:"candidate"`
}
// MessageSDP represents an SDP message.
type MessageSDP struct {
MessageBase
SDP webrtc.SessionDescription `json:"sdp"`
}
// JoinerType is an enum for the type of incoming room joiner
type JoinerType int
const (
JoinerNode JoinerType = iota
JoinerClient
)
func (jt *JoinerType) String() string {
switch *jt {
case JoinerNode:
return "node"
case JoinerClient:
return "client"
default:
return "unknown"
}
}
// MessageJoin is used to tell us that either participant or ingest wants to join the room
type MessageJoin struct {
MessageBase
JoinerType JoinerType `json:"joiner_type"`
}
// AnswerType is an enum for the type of answer, signaling Room state for a joiner
type AnswerType int
const (
AnswerOffline AnswerType = iota // For participant/client, when the room is offline without stream
AnswerInUse // For ingest/node joiner, when the room is already in use by another ingest/node
AnswerOK // For both, when the join request is handled successfully
)
// MessageAnswer is used to send the answer to a join request
type MessageAnswer struct {
MessageBase
AnswerType AnswerType `json:"answer_type"`
}
// EncodeMessage encodes a message to be sent with gzip compression
func EncodeMessage(msg interface{}) ([]byte, error) {
// Marshal the message to JSON
data, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("failed to encode message: %w", err)
}
// Gzip compress the JSON
var compressedData bytes.Buffer
writer := gzip.NewWriter(&compressedData)
_, err = writer.Write(data)
if err != nil {
return nil, fmt.Errorf("failed to compress message: %w", err)
}
if err := writer.Close(); err != nil {
return nil, fmt.Errorf("failed to finalize compression: %w", err)
}
return compressedData.Bytes(), nil
}
// DecodeMessage decodes a message received with gzip decompression
func DecodeMessage(data []byte, target interface{}) error {
// Gzip decompress the data
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return fmt.Errorf("failed to initialize decompression: %w", err)
}
defer func(reader *gzip.Reader) {
if err = reader.Close(); err != nil {
fmt.Printf("failed to close reader: %v\n", err)
}
}(reader)
// Decode the JSON
err = json.NewDecoder(reader).Decode(target)
if err != nil {
return fmt.Errorf("failed to decode message: %w", err)
}
return nil
}
// SendLogMessageWS sends a log message to the given WebSocket connection.
func (ws *SafeWebSocket) SendLogMessageWS(level, message string) error {
msg := MessageLog{
MessageBase: MessageBase{PayloadType: "log"},
Level: level,
Message: message,
Time: time.Now().Format(time.RFC3339),
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode log message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendMetricsMessageWS sends a metrics message to the given WebSocket connection.
func (ws *SafeWebSocket) SendMetricsMessageWS(usageCPU, usageMemory float64, uptime uint64, pipelineLatency float64) error {
msg := MessageMetrics{
MessageBase: MessageBase{PayloadType: "metrics"},
UsageCPU: usageCPU,
UsageMemory: usageMemory,
Uptime: uptime,
PipelineLatency: pipelineLatency,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode metrics message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendICECandidateMessageWS sends an ICE candidate message to the given WebSocket connection.
func (ws *SafeWebSocket) SendICECandidateMessageWS(candidate webrtc.ICECandidateInit) error {
msg := MessageICECandidate{
MessageBase: MessageBase{PayloadType: "ice"},
Candidate: candidate,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode ICE candidate message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendSDPMessageWS sends an SDP message to the given WebSocket connection.
func (ws *SafeWebSocket) SendSDPMessageWS(sdp webrtc.SessionDescription) error {
msg := MessageSDP{
MessageBase: MessageBase{PayloadType: "sdp"},
SDP: sdp,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode SDP message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendAnswerMessageWS sends an answer message to the given WebSocket connection.
func (ws *SafeWebSocket) SendAnswerMessageWS(answer AnswerType) error {
msg := MessageAnswer{
MessageBase: MessageBase{PayloadType: "answer"},
AnswerType: answer,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode answer message: %w", err)
}
return ws.SendBinary(encoded)
}
// SendInputMessageDC sends an input message to the given DataChannel connection.
func (ndc *NestriDataChannel) SendInputMessageDC(data string) error {
msg := MessageInput{
MessageBase: MessageBase{PayloadType: "input"},
Data: data,
}
encoded, err := EncodeMessage(msg)
if err != nil {
return fmt.Errorf("failed to encode input message: %w", err)
}
return ndc.SendBinary(encoded)
}

View File

@@ -0,0 +1,69 @@
package relay
import (
"fmt"
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"math/rand"
)
type Participant struct {
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track and not have conflicts later
Name string
WebSocket *SafeWebSocket
PeerConnection *webrtc.PeerConnection
DataChannel *NestriDataChannel
}
func NewParticipant(ws *SafeWebSocket) *Participant {
return &Participant{
ID: uuid.New(),
Name: createRandomName(),
WebSocket: ws,
}
}
func (vw *Participant) addTrack(trackLocal *webrtc.TrackLocal) error {
rtpSender, err := vw.PeerConnection.AddTrack(*trackLocal)
if err != nil {
return err
}
go func() {
rtcpBuffer := make([]byte, 1400)
for {
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
return
}
}
}()
return nil
}
func (vw *Participant) signalOffer() error {
if vw.PeerConnection == nil {
return fmt.Errorf("peer connection is nil for participant: '%s' - cannot signal offer", vw.ID)
}
offer, err := vw.PeerConnection.CreateOffer(nil)
if err != nil {
return err
}
err = vw.PeerConnection.SetLocalDescription(offer)
if err != nil {
return err
}
return vw.WebSocket.SendSDPMessageWS(offer)
}
var namesFirst = []string{"Happy", "Sad", "Angry", "Calm", "Excited", "Bored", "Confused", "Confident", "Curious", "Depressed", "Disappointed", "Embarrassed", "Energetic", "Fearful", "Frustrated", "Glad", "Guilty", "Hopeful", "Impatient", "Jealous", "Lonely", "Motivated", "Nervous", "Optimistic", "Pessimistic", "Proud", "Relaxed", "Shy", "Stressed", "Surprised", "Tired", "Worried"}
var namesSecond = []string{"Dragon", "Unicorn", "Troll", "Goblin", "Elf", "Dwarf", "Ogre", "Gnome", "Mermaid", "Siren", "Vampire", "Ghoul", "Werewolf", "Minotaur", "Centaur", "Griffin", "Phoenix", "Wyvern", "Hydra", "Kraken"}
func createRandomName() string {
randomFirst := namesFirst[rand.Intn(len(namesFirst))]
randomSecond := namesSecond[rand.Intn(len(namesSecond))]
return randomFirst + " " + randomSecond
}

View File

@@ -0,0 +1,179 @@
package relay
import (
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"log"
"sync"
)
var Rooms = make(map[uuid.UUID]*Room) //< Room ID -> Room
var RoomsMutex = sync.RWMutex{}
func GetRoomByID(id uuid.UUID) *Room {
RoomsMutex.RLock()
defer RoomsMutex.RUnlock()
if room, ok := Rooms[id]; ok {
return room
}
return nil
}
func GetRoomByName(name string) *Room {
RoomsMutex.RLock()
defer RoomsMutex.RUnlock()
for _, room := range Rooms {
if room.Name == name {
return room
}
}
return nil
}
func GetOrCreateRoom(name string) *Room {
if room := GetRoomByName(name); room != nil {
return room
}
RoomsMutex.Lock()
room := NewRoom(name)
Rooms[room.ID] = room
if GetFlags().Verbose {
log.Printf("New room: '%s'\n", room.Name)
}
RoomsMutex.Unlock()
return room
}
func DeleteRoomIfEmpty(room *Room) {
room.ParticipantsMutex.RLock()
defer room.ParticipantsMutex.RUnlock()
if !room.Online && len(room.Participants) <= 0 {
RoomsMutex.Lock()
delete(Rooms, room.ID)
RoomsMutex.Unlock()
}
}
type Room struct {
ID uuid.UUID //< Internal IDs are useful to keeping unique internal track
Name string
Online bool //< Whether the room is currently online, i.e. receiving data from a nestri-server
WebSocket *SafeWebSocket
PeerConnection *webrtc.PeerConnection
AudioTrack webrtc.TrackLocal
VideoTrack webrtc.TrackLocal
DataChannel *NestriDataChannel
Participants map[uuid.UUID]*Participant
ParticipantsMutex sync.RWMutex
}
func NewRoom(name string) *Room {
return &Room{
ID: uuid.New(),
Name: name,
Online: false,
Participants: make(map[uuid.UUID]*Participant),
}
}
// Assigns a WebSocket connection to a Room
func (r *Room) assignWebSocket(ws *SafeWebSocket) {
// If WS already assigned, warn
if r.WebSocket != nil {
log.Printf("Warning: Room '%s' already has a WebSocket assigned\n", r.Name)
}
r.WebSocket = ws
}
// Adds a Participant to a Room
func (r *Room) addParticipant(participant *Participant) {
r.ParticipantsMutex.Lock()
r.Participants[participant.ID] = participant
r.ParticipantsMutex.Unlock()
}
// Removes a Participant from a Room by participant's ID.
// If Room is offline and this is the last participant, the room is deleted
func (r *Room) removeParticipantByID(pID uuid.UUID) {
r.ParticipantsMutex.Lock()
delete(r.Participants, pID)
r.ParticipantsMutex.Unlock()
DeleteRoomIfEmpty(r)
}
// Removes a Participant from a Room by participant's name.
// If Room is offline and this is the last participant, the room is deleted
func (r *Room) removeParticipantByName(pName string) {
r.ParticipantsMutex.Lock()
for id, p := range r.Participants {
if p.Name == pName {
delete(r.Participants, id)
break
}
}
r.ParticipantsMutex.Unlock()
DeleteRoomIfEmpty(r)
}
// Signals all participants with offer and add tracks to their PeerConnections
func (r *Room) signalParticipantsWithTracks() {
r.ParticipantsMutex.RLock()
for _, participant := range r.Participants {
// Add tracks to participant's PeerConnection
if r.AudioTrack != nil {
if err := participant.addTrack(&r.AudioTrack); err != nil {
log.Printf("Failed to add audio track to participant: '%s' - reason: %s\n", participant.ID, err)
}
}
if r.VideoTrack != nil {
if err := participant.addTrack(&r.VideoTrack); err != nil {
log.Printf("Failed to add video track to participant: '%s' - reason: %s\n", participant.ID, err)
}
}
// Signal participant with offer
if err := participant.signalOffer(); err != nil {
log.Printf("Error signaling participant: %v\n", err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Signals all participants that the Room is offline
func (r *Room) signalParticipantsOffline() {
r.ParticipantsMutex.RLock()
for _, participant := range r.Participants {
if err := participant.WebSocket.SendAnswerMessageWS(AnswerOffline); err != nil {
log.Printf("Failed to send Offline answer for participant: '%s' - reason: %s\n", participant.ID, err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Broadcasts a message to Room's Participant's - excluding one given ID of
func (r *Room) broadcastMessage(msg webrtc.DataChannelMessage, excludeID uuid.UUID) {
r.ParticipantsMutex.RLock()
for d, participant := range r.Participants {
if participant.DataChannel != nil {
if d != excludeID { // Don't send back to the sender
if err := participant.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to %s: %v\n", participant.Name, err)
}
}
}
}
if r.DataChannel != nil {
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to Room: %v\n", err)
}
}
r.ParticipantsMutex.RUnlock()
}
// Sends message to Room (nestri-server)
func (r *Room) sendToRoom(msg webrtc.DataChannelMessage) {
if r.DataChannel != nil {
if err := r.DataChannel.SendText(string(msg.Data)); err != nil {
log.Printf("Error broadcasting to Room: %v\n", err)
}
}
}

View File

@@ -0,0 +1,114 @@
package relay
import (
"github.com/gorilla/websocket"
"log"
"sync"
)
// SafeWebSocket is a websocket with a mutex
type SafeWebSocket struct {
*websocket.Conn
sync.Mutex
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
}
// NewSafeWebSocket creates a new SafeWebSocket from *websocket.Conn
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
ws := &SafeWebSocket{
Conn: conn,
binaryCallbacks: make(map[string]OnMessageCallback),
}
// Launch a goroutine to handle binary messages
go func() {
for {
// Read binary message
kind, data, err := ws.Conn.ReadMessage()
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If unexpected close error, break
if GetFlags().Verbose {
log.Printf("Unexpected WebSocket close error, reason: %s\n", err)
}
break
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If closing, just break
if GetFlags().Verbose {
log.Printf("WebSocket closing\n")
}
break
} else if err != nil {
log.Printf("Failed to read WebSocket message, reason: %s\n", err)
break
}
switch kind {
case websocket.TextMessage:
// Ignore, we use binary messages
continue
case websocket.BinaryMessage:
// Decode message
var msg MessageBase
if err = DecodeMessage(data, &msg); err != nil {
log.Printf("Failed to decode binary WebSocket message, reason: %s\n", err)
continue
}
// Handle message type callback
if callback, ok := ws.binaryCallbacks[msg.PayloadType]; ok {
callback(data)
} // TODO: Log unknown message type?
default:
log.Printf("Unknown WebSocket message type: %d\n", kind)
}
}
}()
return ws
}
// SendJSON writes JSON to a websocket with a mutex
func (ws *SafeWebSocket) SendJSON(v interface{}) error {
ws.Lock()
defer ws.Unlock()
return ws.Conn.WriteJSON(v)
}
// SendBinary writes binary to a websocket with a mutex
func (ws *SafeWebSocket) SendBinary(data []byte) error {
ws.Lock()
defer ws.Unlock()
return ws.Conn.WriteMessage(websocket.BinaryMessage, data)
}
// RegisterMessageCallback sets the callback for binary message of given type
func (ws *SafeWebSocket) RegisterMessageCallback(msgType string, callback OnMessageCallback) {
ws.Lock()
defer ws.Unlock()
if ws.binaryCallbacks == nil {
ws.binaryCallbacks = make(map[string]OnMessageCallback)
}
ws.binaryCallbacks[msgType] = callback
}
// UnregisterMessageCallback removes the callback for binary message of given type
func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
ws.Lock()
defer ws.Unlock()
if ws.binaryCallbacks != nil {
delete(ws.binaryCallbacks, msgType)
}
}
// RegisterOnClose sets the callback for websocket closing
func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
ws.SetCloseHandler(func(code int, text string) error {
// Clear our callbacks
ws.Lock()
ws.binaryCallbacks = nil
ws.Unlock()
// Call the callback
callback()
return nil
})
}

32
packages/relay/main.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"log"
"os"
"os/signal"
relay "relay/internal"
"syscall"
)
func main() {
var err error
stopCh := make(chan os.Signal, 1)
signal.Notify(stopCh, os.Interrupt, syscall.SIGTERM)
// Get flags and log them
relay.InitFlags()
relay.GetFlags().DebugLog()
// Init WebRTC API
err = relay.InitWebRTCAPI()
if err != nil {
log.Fatal("Failed to initialize WebRTC API: ", err)
}
// Start our HTTP endpoints
relay.InitHTTPEndpoint()
// Wait for exit signal
<-stopCh
log.Println("Shutting down gracefully by signal...")
}