feat: Add streaming support (#125)

This adds:
- [x] Keyboard and mouse handling on the frontend
- [x] Video and audio streaming from the backend to the frontend
- [x] Input server that works with Websockets

Update - 17/11
- [ ] Master docker container to run this
- [ ] Steam runtime
- [ ] Entrypoint.sh

---------

Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Wanjohi
2024-12-08 14:54:56 +03:00
committed by GitHub
parent 5eb21eeadb
commit 379db1c87b
137 changed files with 12737 additions and 5234 deletions

205
packages/server/src/args.rs Normal file
View File

@@ -0,0 +1,205 @@
use clap::{Arg, Command};
pub mod app_args;
pub mod device_args;
pub mod encoding_args;
/// Aggregated server configuration parsed from command-line arguments and
/// environment variables, split into logical groups.
pub struct Args {
/// General application options (verbosity, resolution, relay/room, debug flags).
pub app: app_args::AppArgs,
/// GPU/device selection options.
pub device: device_args::DeviceArgs,
/// Video and audio encoding options.
pub encoding: encoding_args::EncodingArgs,
}
impl Args {
/// Declares all command-line arguments (each with an environment-variable
/// fallback) via clap, parses them, and groups the results into app,
/// device and encoding sections.
///
/// Note: the boolean-ish flags ("verbose", "debug-feed", "debug-latency")
/// are declared as plain string arguments with a "false" default; they are
/// interpreted downstream by `AppArgs::from_matches`.
pub fn new() -> Self {
let matches = Command::new("nestri-server")
// --- Application / debugging flags ---
.arg(
Arg::new("verbose")
.short('v')
.long("verbose")
.env("VERBOSE")
.help("Enable verbose output")
.default_value("false"),
)
.arg(
Arg::new("debug-feed")
.short('d')
.long("debug-feed")
.env("DEBUG_FEED")
.help("Debug by showing a window on host")
.default_value("false"),
)
.arg(
Arg::new("debug-latency")
.short('l')
.long("debug-latency")
.env("DEBUG_LATENCY")
.help("Debug latency by showing time on feed")
.default_value("false"),
)
// --- Relay / room (no defaults; missing values are handled by AppArgs) ---
.arg(
Arg::new("relay-url")
.short('u')
.long("relay-url")
.env("RELAY_URL")
.help("Nestri relay URL")
)
// --- Display settings ---
.arg(
Arg::new("resolution")
.short('r')
.long("resolution")
.env("RESOLUTION")
.help("Display/stream resolution in 'WxH' format")
.default_value("1280x720"),
)
.arg(
Arg::new("framerate")
.short('f')
.long("framerate")
.env("FRAMERATE")
.help("Display/stream framerate")
.default_value("60"),
)
.arg(
Arg::new("room")
.long("room")
.env("NESTRI_ROOM")
.help("Nestri room name/identifier")
)
// --- GPU selection (vendor/name/index, or an explicit card path) ---
.arg(
Arg::new("gpu-vendor")
.short('g')
.long("gpu-vendor")
.env("GPU_VENDOR")
.help("GPU to find by vendor (e.g. 'nvidia')")
.required(false),
)
.arg(
Arg::new("gpu-name")
.short('n')
.long("gpu-name")
.env("GPU_NAME")
.help("GPU to find by name (e.g. 'rtx 3060')")
.required(false),
)
.arg(
Arg::new("gpu-index")
.short('i')
.long("gpu-index")
.env("GPU_INDEX")
.help("GPU index, if multiple similar GPUs are present")
.default_value("0"),
)
.arg(
Arg::new("gpu-card-path")
.long("gpu-card-path")
.env("GPU_CARD_PATH")
.help("Force a specific GPU by card/render path (e.g. '/dev/dri/card0')")
.required(false)
// An explicit card path overrides the search-based selectors
.conflicts_with_all(["gpu-vendor", "gpu-name", "gpu-index"]),
)
// --- Video encoding ---
.arg(
Arg::new("video-codec")
.short('c')
.long("video-codec")
.env("VIDEO_CODEC")
.help("Preferred video codec ('h264', 'h265', 'av1')")
.default_value("h264"),
)
.arg(
Arg::new("video-encoder")
.long("video-encoder")
.env("VIDEO_ENCODER")
.help("Override video encoder (e.g. 'vah264enc')")
)
.arg(
Arg::new("video-rate-control")
.long("video-rate-control")
.env("VIDEO_RATE_CONTROL")
.help("Rate control method ('cqp', 'vbr', 'cbr')")
.default_value("vbr"),
)
.arg(
Arg::new("video-cqp")
.long("video-cqp")
.env("VIDEO_CQP")
.help("Constant Quantization Parameter (CQP) quality")
.default_value("26"),
)
.arg(
Arg::new("video-bitrate")
.long("video-bitrate")
.env("VIDEO_BITRATE")
.help("Target bitrate in kbps")
.default_value("6000"),
)
.arg(
Arg::new("video-bitrate-max")
.long("video-bitrate-max")
.env("VIDEO_BITRATE_MAX")
.help("Maximum bitrate in kbps")
.default_value("8000"),
)
.arg(
Arg::new("video-encoder-type")
.long("video-encoder-type")
.env("VIDEO_ENCODER_TYPE")
.help("Encoder type ('hardware', 'software')")
.default_value("hardware"),
)
// --- Audio capture & encoding ---
.arg(
Arg::new("audio-capture-method")
.long("audio-capture-method")
.env("AUDIO_CAPTURE_METHOD")
.help("Audio capture method ('pipewire', 'pulseaudio', 'alsa')")
.default_value("pulseaudio"),
)
.arg(
Arg::new("audio-codec")
.long("audio-codec")
.env("AUDIO_CODEC")
.help("Preferred audio codec ('opus', 'aac')")
.default_value("opus"),
)
.arg(
Arg::new("audio-encoder")
.long("audio-encoder")
.env("AUDIO_ENCODER")
.help("Override audio encoder (e.g. 'opusenc')")
)
.arg(
Arg::new("audio-rate-control")
.long("audio-rate-control")
.env("AUDIO_RATE_CONTROL")
.help("Rate control method ('cqp', 'vbr', 'cbr')")
.default_value("vbr"),
)
.arg(
Arg::new("audio-bitrate")
.long("audio-bitrate")
.env("AUDIO_BITRATE")
.help("Target bitrate in kbps")
.default_value("128"),
)
.arg(
Arg::new("audio-bitrate-max")
.long("audio-bitrate-max")
.env("AUDIO_BITRATE_MAX")
.help("Maximum bitrate in kbps")
.default_value("192"),
)
.get_matches();
Self {
app: app_args::AppArgs::from_matches(&matches),
device: device_args::DeviceArgs::from_matches(&matches),
encoding: encoding_args::EncodingArgs::from_matches(&matches),
}
}
/// Prints every argument group to stdout for debugging.
pub fn debug_print(&self) {
self.app.debug_print();
self.device.debug_print();
self.encoding.debug_print();
}
}

View File

@@ -0,0 +1,59 @@
/// General application options parsed from the command line / environment.
pub struct AppArgs {
    /// Verbose output mode
    pub verbose: bool,
    /// Debug the pipeline by showing a window on host
    pub debug_feed: bool,
    /// Debug the latency by showing time in stream
    pub debug_latency: bool,
    /// Virtual display resolution as (width, height)
    pub resolution: (u32, u32),
    /// Virtual display framerate
    pub framerate: u32,
    /// Nestri relay url
    pub relay_url: String,
    /// Nestri room name/identifier
    pub room: String,
}

impl AppArgs {
    /// Builds `AppArgs` from parsed clap matches.
    ///
    /// # Panics
    /// Panics with a descriptive message when `--relay-url` is missing or when
    /// `--resolution` / `--framerate` are malformed (instead of an opaque
    /// `unwrap` or index-out-of-bounds panic).
    pub fn from_matches(matches: &clap::ArgMatches) -> Self {
        // Boolean flags are declared as string arguments with a "false"
        // default, so "true" or "1" means enabled; anything else is off.
        fn flag(matches: &clap::ArgMatches, name: &str) -> bool {
            matches
                .get_one::<String>(name)
                .map(|v| v == "true" || v == "1")
                .unwrap_or(false)
        }

        // Parse "WxH" into a (width, height) pair with clear error messages.
        let resolution = {
            let res = matches
                .get_one::<String>("resolution")
                .expect("resolution has a default value");
            let (w, h) = res
                .split_once('x')
                .unwrap_or_else(|| panic!("invalid resolution '{}', expected 'WxH'", res));
            (
                w.parse::<u32>()
                    .unwrap_or_else(|e| panic!("invalid resolution width '{}': {}", w, e)),
                h.parse::<u32>()
                    .unwrap_or_else(|e| panic!("invalid resolution height '{}': {}", h, e)),
            )
        };

        Self {
            verbose: flag(matches, "verbose"),
            debug_feed: flag(matches, "debug-feed"),
            debug_latency: flag(matches, "debug-latency"),
            resolution,
            framerate: matches
                .get_one::<String>("framerate")
                .expect("framerate has a default value")
                .parse::<u32>()
                .expect("framerate must be a non-negative integer"),
            relay_url: matches
                .get_one::<String>("relay-url")
                .expect("relay-url is required (set --relay-url or RELAY_URL)")
                .clone(),
            // Generate random room name if not provided; `unwrap_or_else`
            // avoids computing the random fallback when a room was given.
            room: matches
                .get_one::<String>("room")
                .cloned()
                .unwrap_or_else(|| rand::random::<u32>().to_string()),
        }
    }

    /// Prints all application options to stdout for debugging.
    pub fn debug_print(&self) {
        println!("AppArgs:");
        println!("> verbose: {}", self.verbose);
        println!("> debug_feed: {}", self.debug_feed);
        println!("> debug_latency: {}", self.debug_latency);
        println!("> resolution: {}x{}", self.resolution.0, self.resolution.1);
        println!("> framerate: {}", self.framerate);
        println!("> relay_url: {}", self.relay_url);
        println!("> room: {}", self.room);
    }
}

View File

@@ -0,0 +1,41 @@
/// GPU/device selection options parsed from the command line / environment.
pub struct DeviceArgs {
    /// GPU vendor (e.g. "intel"); empty when not specified
    pub gpu_vendor: String,
    /// GPU name (e.g. "a770"); empty when not specified
    pub gpu_name: String,
    /// GPU index, if multiple same GPUs are present
    pub gpu_index: u32,
    /// GPU card/render path, sets card explicitly from such path; empty when not specified
    pub gpu_card_path: String,
}

impl DeviceArgs {
    /// Builds `DeviceArgs` from parsed clap matches.
    ///
    /// # Panics
    /// Panics with a descriptive message when `--gpu-index` is not a
    /// non-negative integer.
    pub fn from_matches(matches: &clap::ArgMatches) -> Self {
        Self {
            // `cloned().unwrap_or_default()` avoids allocating a throwaway
            // empty String when the argument IS present (the old
            // `unwrap_or(&"".to_string())` allocated unconditionally).
            gpu_vendor: matches
                .get_one::<String>("gpu-vendor")
                .cloned()
                .unwrap_or_default(),
            gpu_name: matches
                .get_one::<String>("gpu-name")
                .cloned()
                .unwrap_or_default(),
            gpu_index: matches
                .get_one::<String>("gpu-index")
                .expect("gpu-index has a default value")
                .parse::<u32>()
                .expect("gpu-index must be a non-negative integer"),
            gpu_card_path: matches
                .get_one::<String>("gpu-card-path")
                .cloned()
                .unwrap_or_default(),
        }
    }

    /// Prints all device options to stdout for debugging.
    pub fn debug_print(&self) {
        println!("DeviceArgs:");
        println!("> gpu_vendor: {}", self.gpu_vendor);
        println!("> gpu_name: {}", self.gpu_name);
        println!("> gpu_index: {}", self.gpu_index);
        println!("> gpu_card_path: {}", self.gpu_card_path);
    }
}

View File

@@ -0,0 +1,190 @@
use std::ops::Deref;
/// Rate-control settings for Constant Quantization Parameter (CQP) mode.
#[derive(Debug, PartialEq, Eq)]
pub struct RateControlCQP {
/// Constant Quantization Parameter (CQP) quality level
pub quality: u32,
}
/// Rate-control settings for Variable Bitrate (VBR) mode.
#[derive(Debug, PartialEq, Eq)]
pub struct RateControlVBR {
/// Target bitrate in kbps
pub target_bitrate: i32,
/// Maximum bitrate in kbps
pub max_bitrate: i32,
}
/// Rate-control settings for Constant Bitrate (CBR) mode.
#[derive(Debug, PartialEq, Eq)]
pub struct RateControlCBR {
/// Target bitrate in kbps
pub target_bitrate: i32,
}
/// Supported rate-control methods together with their mode-specific settings.
#[derive(Debug, PartialEq, Eq)]
pub enum RateControl {
/// Constant Quantization Parameter
CQP(RateControlCQP),
/// Variable Bitrate
VBR(RateControlVBR),
/// Constant Bitrate
CBR(RateControlCBR),
}
pub struct EncodingOptionsBase {
/// Codec (e.g. "h264", "opus" etc.)
pub codec: String,
/// Overridable encoder (e.g. "vah264lpenc", "opusenc" etc.)
pub encoder: String,
/// Rate control method (e.g. "cqp", "vbr", "cbr")
pub rate_control: RateControl,
}
impl EncodingOptionsBase {
pub fn debug_print(&self) {
println!("> Codec: {}", self.codec);
println!("> Encoder: {}", self.encoder);
match &self.rate_control {
RateControl::CQP(cqp) => {
println!("> Rate Control: CQP");
println!("-> Quality: {}", cqp.quality);
}
RateControl::VBR(vbr) => {
println!("> Rate Control: VBR");
println!("-> Target Bitrate: {}", vbr.target_bitrate);
println!("-> Max Bitrate: {}", vbr.max_bitrate);
}
RateControl::CBR(cbr) => {
println!("> Rate Control: CBR");
println!("-> Target Bitrate: {}", cbr.target_bitrate);
}
}
}
}
/// Video-specific encoding options.
pub struct VideoEncodingOptions {
    pub base: EncodingOptionsBase,
    /// Encoder type (e.g. "hardware", "software")
    pub encoder_type: String,
}

impl VideoEncodingOptions {
    /// Builds the video encoding options from parsed clap matches.
    ///
    /// # Panics
    /// Panics when `--video-rate-control` is not one of 'cqp', 'vbr', 'cbr',
    /// or when a numeric argument fails to parse (with a message naming the
    /// offending argument, instead of a bare `unwrap`).
    pub fn from_matches(matches: &clap::ArgMatches) -> Self {
        // Parse a numeric argument with a descriptive panic message.
        fn parsed<T: std::str::FromStr>(matches: &clap::ArgMatches, name: &str) -> T
        where
            T::Err: std::fmt::Display,
        {
            let raw = matches
                .get_one::<String>(name)
                .unwrap_or_else(|| panic!("missing argument '{}'", name));
            raw.parse::<T>()
                .unwrap_or_else(|e| panic!("invalid value '{}' for '{}': {}", raw, name, e))
        }

        Self {
            base: EncodingOptionsBase {
                codec: matches.get_one::<String>("video-codec").unwrap().clone(),
                // `cloned().unwrap_or_default()` avoids allocating a throwaway
                // String when an override IS given.
                encoder: matches
                    .get_one::<String>("video-encoder")
                    .cloned()
                    .unwrap_or_default(),
                rate_control: match matches
                    .get_one::<String>("video-rate-control")
                    .unwrap()
                    .as_str()
                {
                    "cqp" => RateControl::CQP(RateControlCQP {
                        quality: parsed(matches, "video-cqp"),
                    }),
                    "cbr" => RateControl::CBR(RateControlCBR {
                        target_bitrate: parsed(matches, "video-bitrate"),
                    }),
                    "vbr" => RateControl::VBR(RateControlVBR {
                        target_bitrate: parsed(matches, "video-bitrate"),
                        max_bitrate: parsed(matches, "video-bitrate-max"),
                    }),
                    other => panic!("Invalid rate control method for video: '{}'", other),
                },
            },
            // The argument has a "hardware" default; fall back to it anyway in
            // case the declaration ever changes.
            encoder_type: matches
                .get_one::<String>("video-encoder-type")
                .cloned()
                .unwrap_or_else(|| "hardware".to_string()),
        }
    }

    /// Prints all video encoding options to stdout for debugging.
    pub fn debug_print(&self) {
        println!("Video Encoding Options:");
        self.base.debug_print();
        println!("> Encoder Type: {}", self.encoder_type);
    }
}

impl Deref for VideoEncodingOptions {
    type Target = EncodingOptionsBase;
    fn deref(&self) -> &Self::Target {
        &self.base
    }
}
/// Supported audio capture backends.
#[derive(Debug, PartialEq, Eq)]
pub enum AudioCaptureMethod {
    PulseAudio,
    PipeWire,
    ALSA,
}

impl AudioCaptureMethod {
    /// Returns the lowercase identifier used on the command line.
    pub fn as_str(&self) -> &str {
        match self {
            Self::PulseAudio => "pulseaudio",
            Self::PipeWire => "pipewire",
            Self::ALSA => "alsa",
        }
    }
}
/// Audio-specific encoding options.
pub struct AudioEncodingOptions {
    pub base: EncodingOptionsBase,
    /// Capture backend used to grab audio from the system
    pub capture_method: AudioCaptureMethod,
}

impl AudioEncodingOptions {
    /// Builds the audio encoding options from parsed clap matches.
    ///
    /// # Panics
    /// Panics when `--audio-rate-control` is not 'vbr' or 'cbr' (note: unlike
    /// video, 'cqp' is NOT supported for audio), or when a numeric argument
    /// fails to parse (with a message naming the offending argument).
    pub fn from_matches(matches: &clap::ArgMatches) -> Self {
        // Parse a numeric argument with a descriptive panic message.
        fn parsed<T: std::str::FromStr>(matches: &clap::ArgMatches, name: &str) -> T
        where
            T::Err: std::fmt::Display,
        {
            let raw = matches
                .get_one::<String>(name)
                .unwrap_or_else(|| panic!("missing argument '{}'", name));
            raw.parse::<T>()
                .unwrap_or_else(|e| panic!("invalid value '{}' for '{}': {}", raw, name, e))
        }

        Self {
            base: EncodingOptionsBase {
                codec: matches.get_one::<String>("audio-codec").unwrap().clone(),
                // Avoid allocating a throwaway String when an override IS given.
                encoder: matches
                    .get_one::<String>("audio-encoder")
                    .cloned()
                    .unwrap_or_default(),
                rate_control: match matches
                    .get_one::<String>("audio-rate-control")
                    .unwrap()
                    .as_str()
                {
                    "cbr" => RateControl::CBR(RateControlCBR {
                        target_bitrate: parsed(matches, "audio-bitrate"),
                    }),
                    "vbr" => RateControl::VBR(RateControlVBR {
                        target_bitrate: parsed(matches, "audio-bitrate"),
                        max_bitrate: parsed(matches, "audio-bitrate-max"),
                    }),
                    other => panic!("Invalid rate control method for audio: '{}'", other),
                },
            },
            capture_method: match matches
                .get_one::<String>("audio-capture-method")
                .unwrap()
                .as_str()
            {
                "pulseaudio" => AudioCaptureMethod::PulseAudio,
                "pipewire" => AudioCaptureMethod::PipeWire,
                "alsa" => AudioCaptureMethod::ALSA,
                // Default to PulseAudio
                _ => AudioCaptureMethod::PulseAudio,
            },
        }
    }

    /// Prints all audio encoding options to stdout for debugging.
    pub fn debug_print(&self) {
        println!("Audio Encoding Options:");
        self.base.debug_print();
        println!("> Capture Method: {}", self.capture_method.as_str());
    }
}

impl Deref for AudioEncodingOptions {
    type Target = EncodingOptionsBase;
    fn deref(&self) -> &Self::Target {
        &self.base
    }
}
/// Top-level grouping of the video and audio encoding options.
pub struct EncodingArgs {
/// Video encoder options
pub video: VideoEncodingOptions,
/// Audio encoder options
pub audio: AudioEncodingOptions,
}
impl EncodingArgs {
/// Builds both the video and audio option groups from the same parsed matches.
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
Self {
video: VideoEncodingOptions::from_matches(matches),
audio: AudioEncodingOptions::from_matches(matches),
}
}
/// Prints all encoding options to stdout for debugging.
pub fn debug_print(&self) {
println!("Encoding Arguments:");
self.video.debug_print();
self.audio.debug_print();
}
}

View File

@@ -0,0 +1,598 @@
use gst::prelude::*;
/// Video codecs the server knows how to negotiate.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum VideoCodec {
    H264,
    H265,
    AV1,
    UNKNOWN,
}

impl VideoCodec {
    /// Human-readable codec name (e.g. "H.264").
    pub fn to_str(&self) -> &'static str {
        match self {
            Self::H264 => "H.264",
            Self::H265 => "H.265",
            Self::AV1 => "AV1",
            Self::UNKNOWN => "Unknown",
        }
    }

    /// GStreamer-friendly lowercase codec name (unlike `to_str`).
    pub fn to_gst_str(&self) -> &'static str {
        match self {
            Self::H264 => "h264",
            Self::H265 => "h265",
            Self::AV1 => "av1",
            Self::UNKNOWN => "unknown",
        }
    }

    /// MIME-type string for the codec.
    pub fn to_mime_str(&self) -> &'static str {
        match self {
            Self::H264 => "video/H264",
            Self::H265 => "video/H265",
            Self::AV1 => "video/AV1",
            Self::UNKNOWN => "unknown",
        }
    }

    /// Parses a codec from a case-insensitive name or common alias.
    pub fn from_str(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "h264" | "h.264" | "avc" => Self::H264,
            "h265" | "h.265" | "hevc" | "hev1" => Self::H265,
            "av1" => Self::AV1,
            _ => Self::UNKNOWN,
        }
    }
}
/// Hardware/software APIs an encoder element may be backed by.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum EncoderAPI {
    QSV,
    VAAPI,
    NVENC,
    AMF,
    SOFTWARE,
    UNKNOWN,
}

impl EncoderAPI {
    /// Human-readable name of the API.
    pub fn to_str(&self) -> &'static str {
        match self {
            Self::QSV => "Intel QuickSync Video",
            Self::VAAPI => "Video Acceleration API",
            Self::NVENC => "NVIDIA NVENC",
            Self::AMF => "AMD Media Framework",
            Self::SOFTWARE => "Software",
            Self::UNKNOWN => "Unknown",
        }
    }
}
/// Whether an encoder runs on dedicated hardware or on the CPU.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum EncoderType {
    SOFTWARE,
    HARDWARE,
    UNKNOWN,
}

impl EncoderType {
    /// Human-readable name of the encoder type.
    pub fn to_str(&self) -> &'static str {
        match self {
            Self::SOFTWARE => "Software",
            Self::HARDWARE => "Hardware",
            Self::UNKNOWN => "Unknown",
        }
    }

    /// Parses an encoder type from a case-insensitive name.
    pub fn from_str(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "software" => Self::SOFTWARE,
            "hardware" => Self::HARDWARE,
            _ => Self::UNKNOWN,
        }
    }
}
/// Description of a video encoder element plus the parameters to apply to it.
#[derive(Debug, Clone)]
pub struct VideoEncoderInfo {
    pub name: String,
    pub codec: VideoCodec,
    pub encoder_type: EncoderType,
    pub encoder_api: EncoderAPI,
    pub parameters: Vec<(String, String)>,
}

impl VideoEncoderInfo {
    /// Creates a new encoder description with an empty parameter list.
    pub fn new(
        name: String,
        codec: VideoCodec,
        encoder_type: EncoderType,
        encoder_api: EncoderAPI,
    ) -> Self {
        Self {
            name,
            codec,
            encoder_type,
            encoder_api,
            parameters: Vec::new(),
        }
    }

    /// Renders the parameter list as a space-separated "key=value" string.
    pub fn get_parameters_string(&self) -> String {
        let mut rendered = Vec::with_capacity(self.parameters.len());
        for (key, value) in &self.parameters {
            rendered.push(format!("{}={}", key, value));
        }
        rendered.join(" ")
    }

    /// Appends a key/value pair to the parameter list.
    pub fn set_parameter(&mut self, key: &str, value: &str) {
        self.parameters.push((key.to_string(), value.to_string()));
    }

    /// Applies every stored parameter to the given GStreamer element,
    /// skipping keys the element does not expose as properties.
    pub fn apply_parameters(&self, element: &gst::Element, verbose: &bool) {
        for (key, value) in &self.parameters {
            if !element.has_property(key) {
                continue;
            }
            // If verbose, log property sets
            if *verbose {
                println!("Setting property {} to {}", key, value);
            }
            element.set_property_from_str(key, value);
        }
    }
}
/// Converts VA-API encoder name to low-power variant.
/// # Arguments
/// * `encoder` - The name of the VA-API encoder.
/// # Returns
/// * `String` - The low-power variant (e.g. "vah264enc" -> "vah264lpenc"),
///   or the name unchanged if it is not a VA-API encoder name ending in
///   "enc", or is already a low-power ("...lpenc") variant.
fn get_low_power_encoder(encoder: &String) -> String {
    // Bug fix: the original condition required the name to NOT end with
    // "enc", so "vah264enc" was never converted (and the truncate below
    // assumes the "enc" suffix is present). Only rewrite names that actually
    // end in "enc" and are not already low-power variants.
    if encoder.starts_with("va") && encoder.ends_with("enc") && !encoder.ends_with("lpenc") {
        // Replace the trailing "enc" with "lpenc"
        let mut low_power = encoder[..encoder.len() - 3].to_string();
        low_power.push_str("lpenc");
        low_power
    } else {
        encoder.to_string()
    }
}
/// Returns best guess for encoder API based on the encoder name.
/// # Arguments
/// * `encoder` - The name of the encoder.
/// * `encoder_type` - Whether the encoder is hardware- or software-based.
/// # Returns
/// * `EncoderAPI` - The best guess for the encoder API.
fn get_encoder_api(encoder: &String, encoder_type: &EncoderType) -> EncoderAPI {
    match encoder_type {
        // Hardware APIs are identified by well-known element-name prefixes.
        EncoderType::HARDWARE => {
            if encoder.starts_with("qsv") {
                EncoderAPI::QSV
            } else if encoder.starts_with("va") {
                EncoderAPI::VAAPI
            } else if encoder.starts_with("nv") {
                EncoderAPI::NVENC
            } else if encoder.starts_with("amf") {
                EncoderAPI::AMF
            } else {
                EncoderAPI::UNKNOWN
            }
        }
        EncoderType::SOFTWARE => EncoderAPI::SOFTWARE,
        EncoderType::UNKNOWN => EncoderAPI::UNKNOWN,
    }
}
/// Returns true if system supports given encoder.
/// # Arguments
/// * `encoder` - The name of the encoder element factory.
/// # Returns
/// * `bool` - True if encoder is supported, false otherwise.
// "Supported" here means the element factory exists in the GStreamer
// registry; no runtime capability probing is performed.
fn is_encoder_supported(encoder: &String) -> bool {
gst::ElementFactory::find(encoder.as_str()).is_some()
}
/// Helper to set CQP value of known encoder
/// # Arguments
/// * `encoder` - Information about the encoder.
/// * `quality` - Constant quantization parameter (CQP) quality, recommended values are between 20-30.
/// # Returns
/// * `EncoderInfo` - Encoder with maybe updated parameters.
/// # Panics
/// Panics if the encoder element cannot be instantiated from its factory.
pub fn encoder_cqp_params(encoder: &VideoEncoderInfo, quality: u32) -> VideoEncoderInfo {
let mut encoder_optz = encoder.clone();
// Look for known keys by factory creation
// (the element is created only to enumerate its property names)
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
.build()
.unwrap();
// Get properties of the encoder
for prop in encoder.list_properties() {
let prop_name = prop.name();
// Heuristic substring matching on property names: anything containing
// "qp" plus "i"/"min" gets the base quality; "qp" plus "p"/"max" gets
// quality + 2. NOTE(review): "i"/"p" match very broadly (almost any
// name containing those letters), so this relies on encoders only
// exposing QP-related properties containing "qp" — verify per encoder.
if prop_name.to_lowercase().contains("qp")
&& (prop_name.to_lowercase().contains("i") || prop_name.to_lowercase().contains("min"))
{
encoder_optz.set_parameter(prop_name, &quality.to_string());
} else if prop_name.to_lowercase().contains("qp")
&& (prop_name.to_lowercase().contains("p") || prop_name.to_lowercase().contains("max"))
{
// Slightly higher QP (lower quality) for the upper bound
encoder_optz.set_parameter(prop_name, &(quality + 2).to_string());
}
}
encoder_optz
}
/// Helper to set VBR values of known encoder
/// # Arguments
/// * `encoder` - Information about the encoder.
/// * `bitrate` - Target bitrate in bits per second.
/// * `max_bitrate` - Maximum bitrate in bits per second.
/// # Returns
/// * `EncoderInfo` - Encoder with maybe updated parameters.
/// # Panics
/// Panics if the encoder element cannot be instantiated from its factory.
// NOTE(review): the values are passed through verbatim to the encoder
// properties, and callers upstream collect them as kbps — confirm the
// unit ("bits per second" above) against the actual encoder properties.
pub fn encoder_vbr_params(encoder: &VideoEncoderInfo, bitrate: u32, max_bitrate: u32) -> VideoEncoderInfo {
let mut encoder_optz = encoder.clone();
// Look for known keys by factory creation
// (the element is created only to enumerate its property names)
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
.build()
.unwrap();
// Get properties of the encoder
for prop in encoder.list_properties() {
let prop_name = prop.name();
// Heuristic: "bitrate" without "max" is the target bitrate,
// "bitrate" with "max" is the maximum bitrate.
if prop_name.to_lowercase().contains("bitrate")
&& !prop_name.to_lowercase().contains("max")
{
encoder_optz.set_parameter(prop_name, &bitrate.to_string());
} else if prop_name.to_lowercase().contains("bitrate")
&& prop_name.to_lowercase().contains("max")
{
// If SVT-AV1, don't set max bitrate
if encoder_optz.name == "svtav1enc" {
continue;
}
encoder_optz.set_parameter(prop_name, &max_bitrate.to_string());
}
}
encoder_optz
}
/// Helper to set CBR value of known encoder
/// # Arguments
/// * `encoder` - Information about the encoder.
/// * `bitrate` - Target bitrate in bits per second.
/// # Returns
/// * `EncoderInfo` - Encoder with maybe updated parameters.
/// # Panics
/// Panics if the encoder element cannot be instantiated from its factory.
// NOTE(review): as with the VBR helper, the unit is passed through verbatim;
// callers upstream collect bitrates as kbps — confirm against encoder docs.
pub fn encoder_cbr_params(encoder: &VideoEncoderInfo, bitrate: u32) -> VideoEncoderInfo {
let mut encoder_optz = encoder.clone();
// Look for known keys by factory creation
// (the element is created only to enumerate its property names)
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
.build()
.unwrap();
// Get properties of the encoder
for prop in encoder.list_properties() {
let prop_name = prop.name();
// Heuristic: any "bitrate" property that is not a "max" bound
// receives the constant target bitrate.
if prop_name.to_lowercase().contains("bitrate")
&& !prop_name.to_lowercase().contains("max")
{
encoder_optz.set_parameter(prop_name, &bitrate.to_string());
}
}
encoder_optz
}
/// Helper to set GOP size of known encoder
/// # Arguments
/// * `encoder` - Information about the encoder.
/// * `gop_size` - Group of pictures (GOP) size.
/// # Returns
/// * `EncoderInfo` - Encoder with maybe updated parameters.
/// # Panics
/// Panics if the encoder element cannot be instantiated from its factory.
pub fn encoder_gop_params(encoder: &VideoEncoderInfo, gop_size: u32) -> VideoEncoderInfo {
let mut encoder_optz = encoder.clone();
// Look for known keys by factory creation
// (the element is created only to enumerate its property names)
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
.build()
.unwrap();
// Get properties of the encoder
for prop in encoder.list_properties() {
let prop_name = prop.name();
// Heuristic: the GOP/keyframe-interval property goes by different
// names across encoders ("gop-size", "int-max", "max-dist",
// "intra-period-length"); match any of them.
if prop_name.to_lowercase().contains("gop")
|| prop_name.to_lowercase().contains("int-max")
|| prop_name.to_lowercase().contains("max-dist")
|| prop_name.to_lowercase().contains("intra-period-length")
{
encoder_optz.set_parameter(prop_name, &gop_size.to_string());
}
}
encoder_optz
}
/// Sets parameters of known encoders for low latency operation.
/// # Arguments
/// * `encoder` - Information about the encoder.
/// # Returns
/// * `EncoderInfo` - Encoder with maybe updated parameters.
pub fn encoder_low_latency_params(encoder: &VideoEncoderInfo) -> VideoEncoderInfo {
    // Start from a short GOP, then layer on API-specific tweaks.
    let mut optimized = encoder_gop_params(encoder, 30);
    match optimized.encoder_api {
        EncoderAPI::QSV => {
            optimized.set_parameter("low-latency", "true");
            optimized.set_parameter("target-usage", "7");
        }
        EncoderAPI::VAAPI => {
            optimized.set_parameter("target-usage", "7");
        }
        EncoderAPI::NVENC => match optimized.codec {
            // nvcudah264enc and nvcudah265enc support newer presets and tunes
            VideoCodec::H264 | VideoCodec::H265 => {
                optimized.set_parameter("multi-pass", "disabled");
                optimized.set_parameter("preset", "p1");
                optimized.set_parameter("tune", "ultra-low-latency");
            }
            // nvav1enc only supports older presets
            VideoCodec::AV1 => {
                optimized.set_parameter("preset", "low-latency-hp");
            }
            _ => {}
        },
        EncoderAPI::AMF => {
            optimized.set_parameter("preset", "speed");
            match optimized.codec {
                // Only H.264 and H.265 support the "ultra-low-latency" usage
                VideoCodec::H264 | VideoCodec::H265 => {
                    optimized.set_parameter("usage", "ultra-low-latency");
                }
                VideoCodec::AV1 => {
                    optimized.set_parameter("usage", "low-latency");
                }
                _ => {}
            }
        }
        // Software encoders are distinguished by element name instead of API.
        EncoderAPI::SOFTWARE => match optimized.name.as_str() {
            "openh264enc" => {
                optimized.set_parameter("complexity", "low");
                optimized.set_parameter("usage-type", "screen");
            }
            "x264enc" => {
                optimized.set_parameter("rc-lookahead", "0");
                optimized.set_parameter("speed-preset", "ultrafast");
                optimized.set_parameter("tune", "zerolatency");
            }
            "svtav1enc" => {
                optimized.set_parameter("preset", "12");
                // Add ":pred-struct=1" only in CBR mode
                let extra = if optimized.get_parameters_string().contains("cbr") {
                    ":pred-struct=1"
                } else {
                    ""
                };
                let params_string = format!("lookahead=0{}", extra);
                optimized.set_parameter("parameters-string", params_string.as_str());
            }
            "av1enc" => {
                optimized.set_parameter("usage-profile", "realtime");
                optimized.set_parameter("cpu-used", "10");
                optimized.set_parameter("lag-in-frames", "0");
            }
            _ => {}
        },
        _ => {}
    }
    optimized
}
/// Returns all compatible encoders for the system.
/// # Returns
/// * `Vec<EncoderInfo>` - List of compatible encoders.
pub fn get_compatible_encoders() -> Vec<VideoEncoderInfo> {
let mut encoders: Vec<VideoEncoderInfo> = Vec::new();
let registry = gst::Registry::get();
let plugins = registry.plugins();
for plugin in plugins {
let features = registry.features_by_plugin(plugin.plugin_name().as_str());
for feature in features {
let encoder = feature.name().to_string();
let factory = gst::ElementFactory::find(encoder.as_str());
if factory.is_some() {
let factory = factory.unwrap();
// Get klass metadata
let klass = factory.metadata("klass");
if klass.is_some() {
// Make sure klass contains "Encoder/Video/..."
let klass = klass.unwrap().to_string();
if !klass.to_lowercase().contains("encoder/video") {
continue;
}
// If contains "/hardware" in klass, it's a hardware encoder
let encoder_type = if klass.to_lowercase().contains("/hardware") {
EncoderType::HARDWARE
} else {
EncoderType::SOFTWARE
};
let api = get_encoder_api(&encoder, &encoder_type);
if is_encoder_supported(&encoder) {
// Match codec by looking for "264" or "av1" in encoder name
let codec = if encoder.contains("264") {
VideoCodec::H264
} else if encoder.contains("265") {
VideoCodec::H265
} else if encoder.contains("av1") {
VideoCodec::AV1
} else {
continue;
};
let encoder_info = VideoEncoderInfo::new(encoder, codec, encoder_type, api);
encoders.push(encoder_info);
} else if api == EncoderAPI::VAAPI {
// Try low-power variant of VA-API encoder
let low_power_encoder = get_low_power_encoder(&encoder);
if is_encoder_supported(&low_power_encoder) {
let codec = if low_power_encoder.contains("264") {
VideoCodec::H264
} else if low_power_encoder.contains("265") {
VideoCodec::H265
} else if low_power_encoder.contains("av1") {
VideoCodec::AV1
} else {
continue;
};
let encoder_info =
VideoEncoderInfo::new(low_power_encoder, codec, encoder_type, api);
encoders.push(encoder_info);
}
}
}
}
}
}
encoders
}
/// Helper to return encoder from vector by name (case-insensitive).
/// # Arguments
/// * `encoders` - A vector containing information about each encoder.
/// * `name` - A string slice that holds the encoder name.
/// # Returns
/// * `Option<EncoderInfo>` - A clone of the matching encoder, if found.
pub fn get_encoder_by_name(
    encoders: &Vec<VideoEncoderInfo>,
    name: &str,
) -> Option<VideoEncoderInfo> {
    let wanted = name.to_lowercase();
    for candidate in encoders {
        if candidate.name.to_lowercase() == wanted {
            return Some(candidate.clone());
        }
    }
    None
}
/// Helper to get encoders from vector by video codec.
/// # Arguments
/// * `encoders` - A vector containing information about each encoder.
/// * `codec` - The codec of the encoder.
/// # Returns
/// * `Vec<EncoderInfo>` - Clones of all encoders producing that codec.
pub fn get_encoders_by_videocodec(
    encoders: &Vec<VideoEncoderInfo>,
    codec: &VideoCodec,
) -> Vec<VideoEncoderInfo> {
    let mut matching = Vec::new();
    for candidate in encoders {
        if candidate.codec == *codec {
            matching.push(candidate.clone());
        }
    }
    matching
}
/// Helper to get encoders from vector by encoder type.
/// # Arguments
/// * `encoders` - A vector containing information about each encoder.
/// * `encoder_type` - The type of the encoder.
/// # Returns
/// * `Vec<EncoderInfo>` - Clones of all encoders of that type.
pub fn get_encoders_by_type(
    encoders: &Vec<VideoEncoderInfo>,
    encoder_type: &EncoderType,
) -> Vec<VideoEncoderInfo> {
    let mut matching = Vec::new();
    for candidate in encoders {
        if candidate.encoder_type == *encoder_type {
            matching.push(candidate.clone());
        }
    }
    matching
}
/// Returns best-case compatible encoder given desired codec and encoder type.
/// # Arguments
/// * `encoders` - List of encoders to pick from.
/// * `codec` - Desired codec.
/// * `encoder_type` - Desired encoder type.
/// # Returns
/// * `Option<EncoderInfo>` - Best-case compatible encoder.
pub fn get_best_compatible_encoder(
    encoders: &Vec<VideoEncoderInfo>,
    codec: VideoCodec,
    encoder_type: EncoderType,
) -> Option<VideoEncoderInfo> {
    // Narrow the candidates down to the requested codec and type first.
    let candidates =
        get_encoders_by_type(&get_encoders_by_videocodec(encoders, &codec), &encoder_type);
    let mut best_encoder: Option<VideoEncoderInfo> = None;
    let mut best_score: i32 = 0;
    for candidate in candidates {
        // Rank the backing API; vendor-specific hardware APIs score highest.
        let api_score = match candidate.encoder_api {
            EncoderAPI::NVENC | EncoderAPI::QSV | EncoderAPI::AMF => 3,
            EncoderAPI::VAAPI => 2,
            EncoderAPI::SOFTWARE => 1,
            EncoderAPI::UNKNOWN => 0,
        };
        // Software encoders get a name-based bonus to favor the most
        // latency-friendly implementations.
        let name_score = if candidate.encoder_type == EncoderType::SOFTWARE {
            match candidate.name.as_str() {
                "openh264enc" | "svtav1enc" => 2,
                "x264enc" | "av1enc" => 1,
                _ => 0,
            }
        } else {
            0
        };
        let score = api_score + name_score;
        // Strictly greater: on a tie the earliest candidate wins.
        if score > best_score {
            best_encoder = Some(candidate);
            best_score = score;
        }
    }
    best_encoder
}

233
packages/server/src/gpu.rs Normal file
View File

@@ -0,0 +1,233 @@
use regex::Regex;
use std::fs;
use std::path::Path;
use std::process::Command;
use std::str;
/// GPU vendors recognized by the server (derived from PCI vendor IDs).
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum GPUVendor {
// Vendor could not be identified from the PCI vendor ID
UNKNOWN,
INTEL,
NVIDIA,
AMD,
}
/// Information about a single GPU discovered on the system.
#[derive(Debug, Clone)]
pub struct GPUInfo {
    vendor: GPUVendor,
    card_path: String,
    render_path: String,
    device_name: String,
}

impl GPUInfo {
    /// Vendor of this GPU.
    pub fn vendor(&self) -> &GPUVendor {
        &self.vendor
    }

    /// Human-readable vendor name.
    pub fn vendor_string(&self) -> &str {
        match &self.vendor {
            GPUVendor::INTEL => "Intel",
            GPUVendor::NVIDIA => "NVIDIA",
            GPUVendor::AMD => "AMD",
            GPUVendor::UNKNOWN => "Unknown",
        }
    }

    /// DRM card node path (e.g. "/dev/dri/card0").
    pub fn card_path(&self) -> &str {
        self.card_path.as_str()
    }

    /// DRM render node path (e.g. "/dev/dri/renderD128").
    pub fn render_path(&self) -> &str {
        self.render_path.as_str()
    }

    /// Device name as reported by lspci.
    pub fn device_name(&self) -> &str {
        self.device_name.as_str()
    }
}
/// Maps a PCI vendor ID (hex string as printed by lspci) to a known GPU vendor.
/// # Arguments
/// * `vendor_id` - PCI vendor ID, e.g. "8086".
/// # Returns
/// * `GPUVendor` - The matching vendor, or `UNKNOWN` for unrecognized IDs.
fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
match vendor_id {
"8086" => GPUVendor::INTEL, // Intel
"10de" => GPUVendor::NVIDIA, // NVIDIA
"1002" => GPUVendor::AMD, // AMD/ATI
_ => GPUVendor::UNKNOWN,
}
}
/// Retrieves a list of GPUs available on the system.
///
/// Shells out to `lspci -mmnn` and keeps entries whose class mentions a
/// VGA/3D/display controller and that expose DRI device nodes in sysfs.
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing information about each GPU
///   (empty if `lspci` cannot be executed or produced no usable output).
pub fn get_gpus() -> Vec<GPUInfo> {
    let mut gpus = Vec::new();
    // Run lspci to get PCI devices related to GPUs. A missing or broken
    // lspci should not crash the server, so degrade to "no GPUs found"
    // instead of panicking.
    let lspci_output = match Command::new("lspci")
        .arg("-mmnn") // Get machine-readable output with IDs
        .output()
    {
        Ok(output) => output,
        Err(e) => {
            eprintln!("Failed to execute lspci: {}", e);
            return gpus;
        }
    };
    // Tolerate non-UTF8 bytes in the output instead of panicking on them.
    let output = String::from_utf8_lossy(&lspci_output.stdout);
    // Filter lines that mention VGA or 3D controller
    for line in output.lines() {
        // Lowercase once per line (was computed three times before).
        let lowered = line.to_lowercase();
        if lowered.contains("vga compatible controller")
            || lowered.contains("3d controller")
            || lowered.contains("display controller")
        {
            if let Some((pci_addr, vendor_id, device_name)) = parse_pci_device(line) {
                // The sysfs lookup doubles as validation that this PCI
                // device really is a GPU with DRI nodes.
                if let Some((card_path, render_path)) = get_dri_device_path(&pci_addr) {
                    gpus.push(GPUInfo {
                        vendor: get_gpu_vendor(&vendor_id),
                        card_path,
                        render_path,
                        device_name,
                    });
                }
            }
        }
    }
    gpus
}
/// Parses a line from the lspci output to extract PCI device information.
/// # Arguments
/// * `line` - A string slice that holds a line from the 'lspci -mmnn' output.
/// # Returns
/// * `Option<(String, String, String)>` - A tuple containing the PCI address,
///   vendor ID and device name if parsing is successful.
fn parse_pci_device(line: &str) -> Option<(String, String, String)> {
    // Matches: `addr "class [id]" "vendor [id]" "device [id]" ...`
    let re = Regex::new(r#"(?P<pci_addr>[0-9a-fA-F]{1,2}:[0-9a-fA-F]{2}\.[0-9]) "[^"]+" "(?P<vendor_name>[^"]+)" "(?P<device_name>[^"]+)"#).unwrap();
    // Only the first match on the line is relevant, so use `captures`
    // directly instead of collecting every match and taking the first
    // (the old `captures_iter(..).filter_map(|c| Some(..)).collect()` built
    // a Vec and ran a no-op filter_map for nothing).
    let cap = re.captures(line)?;
    let pci_addr = cap["pci_addr"].to_string();
    let vendor_name = cap["vendor_name"].to_string();
    let mut device_name = cap["device_name"].to_string();
    // Strip the trailing square-bracketed id (e.g. "[1002:744c]") from the
    // device name: everything from the LAST '[' on is dropped, provided a
    // ']' exists somewhere in the name.
    if let Some(start) = device_name.rfind('[') {
        if device_name.rfind(']').is_some() {
            device_name = device_name[..start].trim().to_string();
        }
    }
    // Extract vendor ID from vendor name (e.g., "Intel Corporation [8086]" or
    // "Advanced Micro Devices, Inc. [AMD/ATI] [1002]")
    let vendor_id = vendor_name
        .split_whitespace()
        .last()
        .unwrap_or_default()
        .trim_matches(|c: char| !c.is_ascii_hexdigit())
        .to_string();
    Some((pci_addr, vendor_id, device_name))
}
/// Retrieves the DRI device paths for a given PCI address.
/// Doubles as a way to verify the device is a GPU.
/// # Arguments
/// * `pci_addr` - A string slice that holds the PCI address.
/// # Returns
/// * `Option<(String, String)>` - A tuple containing the card path and render path if found.
fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
    // sysfs lists devices with a "0000:" domain prefix.
    let target_addr = format!("0000:{}", pci_addr);
    let device_entries = fs::read_dir(Path::new("/sys/bus/pci/devices")).ok()?;
    for entry in device_entries.flatten() {
        let device_path = entry.path();
        // Skip entries that do not belong to the requested PCI address.
        if !device_path.to_string_lossy().contains(&target_addr) {
            continue;
        }
        // A GPU exposes a 'drm' subdirectory with nodes such as
        // 'card0' and 'renderD128'.
        let drm_dir = device_path.join("drm");
        if !drm_dir.exists() {
            continue;
        }
        let mut card_path = String::new();
        let mut render_path = String::new();
        for drm_entry in fs::read_dir(drm_dir).ok()?.flatten() {
            let node_name = drm_entry.file_name();
            if let Some(name) = node_name.to_str() {
                if name.starts_with("card") {
                    card_path = format!("/dev/dri/{}", name);
                } else if name.starts_with("renderD") {
                    render_path = format!("/dev/dri/{}", name);
                }
                // Stop scanning once both node kinds were seen.
                if !card_path.is_empty() && !render_path.is_empty() {
                    break;
                }
            }
        }
        return Some((card_path, render_path));
    }
    None
}
/// Helper method to get GPUs from vector by vendor name (case-insensitive).
/// # Arguments
/// * `gpus` - A slice containing information about each GPU.
/// * `vendor` - A string slice that holds the vendor name.
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
// Takes `&[GPUInfo]` instead of `&Vec<GPUInfo>` (clippy::ptr_arg);
// existing `&Vec` call sites still compile via deref coercion.
pub fn get_gpus_by_vendor(gpus: &[GPUInfo], vendor: &str) -> Vec<GPUInfo> {
    // Lowercase the needle once; each GPU's vendor string is lowered per item.
    let vendor = vendor.to_lowercase();
    gpus.iter()
        .filter(|gpu| gpu.vendor_string().to_lowercase().contains(&vendor))
        .cloned()
        .collect()
}
/// Helper method to get GPUs from vector by device name substring (case-insensitive).
/// # Arguments
/// * `gpus` - A slice containing information about each GPU.
/// * `device_name` - A string slice that holds the device name substring.
/// # Returns
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
// Takes `&[GPUInfo]` instead of `&Vec<GPUInfo>` (clippy::ptr_arg);
// existing `&Vec` call sites still compile via deref coercion.
pub fn get_gpus_by_device_name(gpus: &[GPUInfo], device_name: &str) -> Vec<GPUInfo> {
    let device_name = device_name.to_lowercase();
    gpus.iter()
        .filter(|gpu| gpu.device_name.to_lowercase().contains(&device_name))
        .cloned()
        .collect()
}
/// Helper method to get a GPU from vector of GPUInfo by card path (either /dev/dri/cardX or /dev/dri/renderDX, case-insensitive).
/// # Arguments
/// * `gpus` - A slice containing information about each GPU.
/// * `card_path` - A string slice that holds the card path.
/// # Returns
/// * `Option<GPUInfo>` - A cloned GPUInfo struct if found.
// Takes `&[GPUInfo]` instead of `&Vec<GPUInfo>` (clippy::ptr_arg);
// existing `&Vec` call sites still compile via deref coercion.
pub fn get_gpu_by_card_path(gpus: &[GPUInfo], card_path: &str) -> Option<GPUInfo> {
    // Lowercase the query once instead of twice per loop iteration.
    let wanted = card_path.to_lowercase();
    gpus.iter()
        .find(|gpu| {
            gpu.card_path().to_lowercase() == wanted
                || gpu.render_path().to_lowercase() == wanted
        })
        .cloned()
}

View File

@@ -0,0 +1,60 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// A single named point in time on a latency-measurement timeline.
#[derive(Serialize, Deserialize, Debug)]
pub struct TimestampEntry {
    // Name of the processing stage this timestamp was taken at.
    pub stage: String,
    pub time: String, // ISO 8601 timestamp
}
/// Collects per-stage timestamps for one tracked sequence so end-to-end
/// latency can be computed (see `total_latency`).
#[derive(Serialize, Deserialize, Debug)]
pub struct LatencyTracker {
    // Identifier tying all timestamps to one tracked unit of work.
    pub sequence_id: String,
    // Stage timestamps in insertion order.
    pub timestamps: Vec<TimestampEntry>,
    pub metadata: Option<HashMap<String, String>>, // Optional metadata
}
impl LatencyTracker {
    /// Creates an empty tracker for the given sequence identifier.
    pub fn new(sequence_id: String) -> Self {
        Self {
            sequence_id,
            timestamps: Vec::new(),
            metadata: None,
        }
    }

    /// Identifier of the tracked sequence.
    pub fn sequence_id(&self) -> &str {
        &self.sequence_id
    }

    /// Records the current UTC time for `stage` with nanosecond precision
    /// in RFC 3339 form (e.g. 2024-01-01T00:00:00.658548387Z).
    pub fn add_timestamp(&mut self, stage: &str) {
        let stamp = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Nanos, true);
        self.timestamps.push(TimestampEntry {
            stage: stage.to_string(),
            time: stamp,
        });
    }

    /// Milliseconds between the earliest and latest recorded timestamps.
    /// Returns `None` with fewer than two entries or if any stored
    /// timestamp fails to parse.
    pub fn total_latency(&self) -> Option<i64> {
        if self.timestamps.len() < 2 {
            return None;
        }
        let mut parsed = Vec::with_capacity(self.timestamps.len());
        for entry in &self.timestamps {
            parsed.push(chrono::DateTime::parse_from_rfc3339(&entry.time).ok()?);
        }
        let earliest = parsed.iter().min()?;
        let latest = parsed.iter().max()?;
        Some((*latest - *earliest).num_milliseconds())
    }
}

526
packages/server/src/main.rs Normal file
View File

@@ -0,0 +1,526 @@
mod args;
mod enc_helper;
mod gpu;
mod room;
mod websocket;
mod latency;
mod messages;
use crate::args::encoding_args;
use gst::prelude::*;
use gst_app::AppSink;
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;
use futures_util::StreamExt;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{Mutex};
use crate::websocket::{NestriWebSocket};
// Handles gathering GPU information and selecting the most suitable GPU
fn handle_gpus(args: &args::Args) -> Option<gpu::GPUInfo> {
    println!("Gathering GPU information..");
    let gpus = gpu::get_gpus();
    if gpus.is_empty() {
        println!("No GPUs found");
        return None;
    }
    for gpu in &gpus {
        println!(
            "> [GPU] Vendor: '{}', Card Path: '{}', Render Path: '{}', Device Name: '{}'",
            gpu.vendor_string(),
            gpu.card_path(),
            gpu.render_path(),
            gpu.device_name()
        );
    }
    // Selection policy: an explicit card path wins outright; otherwise
    // filter by vendor and/or name and index into whatever remains.
    let selected = if !args.device.gpu_card_path.is_empty() {
        gpu::get_gpu_by_card_path(&gpus, &args.device.gpu_card_path)
    } else {
        let mut candidates = gpus.clone();
        if !args.device.gpu_vendor.is_empty() {
            candidates = gpu::get_gpus_by_vendor(&candidates, &args.device.gpu_vendor);
        }
        if !args.device.gpu_name.is_empty() {
            candidates = gpu::get_gpus_by_device_name(&candidates, &args.device.gpu_name);
        }
        // Index 0 and "first remaining GPU" coincide, so one lookup
        // covers both the indexed and the default case.
        candidates.get(args.device.gpu_index as usize).cloned()
    };
    match selected {
        Some(gpu) => {
            println!("Selected GPU: '{}'", gpu.device_name());
            Some(gpu)
        }
        None => {
            println!("No GPU found with the specified parameters: vendor='{}', name='{}', index='{}', card_path='{}'",
                args.device.gpu_vendor, args.device.gpu_name, args.device.gpu_index, args.device.gpu_card_path);
            None
        }
    }
}
// Handles picking video encoder
fn handle_encoder_video(args: &args::Args) -> Option<enc_helper::VideoEncoderInfo> {
    println!("Getting compatible video encoders..");
    let video_encoders = enc_helper::get_compatible_encoders();
    if video_encoders.is_empty() {
        println!("No compatible video encoders found");
        return None;
    }
    for encoder in &video_encoders {
        println!(
            "> [Video Encoder] Name: '{}', Codec: '{}', API: '{}', Type: '{}'",
            encoder.name,
            encoder.codec.to_str(),
            encoder.encoder_api.to_str(),
            encoder.encoder_type.to_str()
        );
    }
    // An explicitly named encoder takes priority; otherwise pick the best
    // candidate for the requested codec and encoder type.
    let picked = if !args.encoding.video.encoder.is_empty() {
        enc_helper::get_encoder_by_name(&video_encoders, &args.encoding.video.encoder)
    } else {
        enc_helper::get_best_compatible_encoder(
            &video_encoders,
            enc_helper::VideoCodec::from_str(&args.encoding.video.codec),
            enc_helper::EncoderType::from_str(&args.encoding.video.encoder_type),
        )
    };
    match picked {
        Some(encoder) => {
            println!("Selected video encoder: '{}'", encoder.name);
            Some(encoder)
        }
        None => {
            println!("No video encoder found with the specified parameters: name='{}', vcodec='{}', type='{}'",
                args.encoding.video.encoder, args.encoding.video.codec, args.encoding.video.encoder_type);
            None
        }
    }
}
// Handles picking preferred settings for video encoder
fn handle_encoder_video_settings(
    args: &args::Args,
    video_encoder: &enc_helper::VideoEncoderInfo,
) -> enc_helper::VideoEncoderInfo {
    // Start from the low-latency preset, then layer the requested
    // rate-control mode on top of it.
    let base = enc_helper::encoder_low_latency_params(&video_encoder);
    let optimized_encoder = match &args.encoding.video.rate_control {
        encoding_args::RateControl::CQP(cqp) => {
            enc_helper::encoder_cqp_params(&base, cqp.quality)
        }
        encoding_args::RateControl::VBR(vbr) => enc_helper::encoder_vbr_params(
            &base,
            vbr.target_bitrate as u32,
            vbr.max_bitrate as u32,
        ),
        encoding_args::RateControl::CBR(cbr) => {
            enc_helper::encoder_cbr_params(&base, cbr.target_bitrate as u32)
        }
    };
    println!(
        "Selected video encoder settings: '{}'",
        optimized_encoder.get_parameters_string()
    );
    optimized_encoder
}
// Handles picking audio encoder
// TODO: Expand enc_helper with audio types, for now just opus
fn handle_encoder_audio(args: &args::Args) -> String {
    let requested = &args.encoding.audio.encoder;
    // Fall back to GStreamer's Opus encoder when none was requested.
    let audio_encoder = if requested.is_empty() {
        String::from("opusenc")
    } else {
        requested.clone()
    };
    println!("Selected audio encoder: '{}'", audio_encoder);
    audio_encoder
}
/// Entry point: parses arguments, connects to the relay over WebSocket,
/// selects a GPU and encoders, assembles the GStreamer capture/encode
/// pipeline, and runs the room + pipeline tasks until EOS, error or Ctrl+C.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Parse command line arguments
    let args = args::Args::new();
    if args.app.verbose {
        args.debug_print();
    }
    // Install the ring-based rustls crypto provider process-wide; the TLS
    // stack needs one before any wss:// connection is attempted.
    rustls::crypto::ring::default_provider()
        .install_default()
        .expect("Failed to install ring crypto provider");
    // Begin connection attempt to the relay WebSocket endpoint
    // replace any http/https with ws/wss
    let replaced_relay_url
        = args.app.relay_url.replace("http://", "ws://").replace("https://", "wss://");
    let ws_url = format!(
        "{}/api/ws/{}",
        replaced_relay_url,
        args.app.room,
    );
    // Setup our websocket
    let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
    // The websocket object doubles as the global `log` sink: it is
    // installed as the boxed logger below.
    log::set_max_level(log::LevelFilter::Info);
    log::set_boxed_logger(Box::new(nestri_ws.clone())).unwrap();
    // NOTE(review): gst::init() failure is deliberately ignored here; a real
    // init problem will surface when the first element is built below.
    let _ = gst::init();
    // Handle GPU selection
    let gpu = handle_gpus(&args);
    if gpu.is_none() {
        log::error!("Failed to find a suitable GPU. Exiting..");
        return Err("Failed to find a suitable GPU. Exiting..".into());
    }
    let gpu = gpu.unwrap();
    // Handle video encoder selection
    let video_encoder_info = handle_encoder_video(&args);
    if video_encoder_info.is_none() {
        log::error!("Failed to find a suitable video encoder. Exiting..");
        return Err("Failed to find a suitable video encoder. Exiting..".into());
    }
    let mut video_encoder_info = video_encoder_info.unwrap();
    // Handle video encoder settings
    video_encoder_info = handle_encoder_video_settings(&args, &video_encoder_info);
    // Handle audio encoder selection
    let audio_encoder = handle_encoder_audio(&args);
    /*** ROOM SETUP ***/
    let room = Arc::new(Mutex::new(
        room::Room::new(nestri_ws.clone()).await?,
    ));
    /*** PIPELINE CREATION ***/
    /* Audio */
    // Audio Source Element - chosen by the configured capture backend,
    // defaulting to ALSA.
    let audio_source = match args.encoding.audio.capture_method {
        encoding_args::AudioCaptureMethod::PulseAudio => {
            gst::ElementFactory::make("pulsesrc").build()?
        }
        encoding_args::AudioCaptureMethod::PipeWire => {
            gst::ElementFactory::make("pipewiresrc").build()?
        }
        _ => gst::ElementFactory::make("alsasrc").build()?,
    };
    // Audio Converter Element
    let audio_converter = gst::ElementFactory::make("audioconvert").build()?;
    // Audio Rate Element
    let audio_rate = gst::ElementFactory::make("audiorate").build()?;
    // Required to fix gstreamer opus issue, where quality sounds off (due to wrong sample rate)
    let audio_capsfilter = gst::ElementFactory::make("capsfilter").build()?;
    let audio_caps = gst::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
    audio_capsfilter.set_property("caps", &audio_caps);
    // Audio Encoder Element - bitrate is configured in bits/s
    // (target_bitrate is in kbps, hence the * 1000).
    let audio_encoder = gst::ElementFactory::make(audio_encoder.as_str()).build()?;
    audio_encoder.set_property(
        "bitrate",
        &match &args.encoding.audio.rate_control {
            encoding_args::RateControl::CBR(cbr) => cbr.target_bitrate * 1000i32,
            encoding_args::RateControl::VBR(vbr) => vbr.target_bitrate * 1000i32,
            // CQP (or any other mode) falls back to a fixed 128 bits/s value.
            _ => 128i32,
        },
    );
    // Audio RTP Payloader Element
    let audio_rtp_payloader = gst::ElementFactory::make("rtpopuspay").build()?;
    /* Video */
    // Video Source Element - captures the nested Wayland display,
    // rendering on the selected GPU's render node.
    let video_source = gst::ElementFactory::make("waylanddisplaysrc").build()?;
    video_source.set_property("render-node", &gpu.render_path());
    // Caps Filter Element (resolution, fps)
    let caps_filter = gst::ElementFactory::make("capsfilter").build()?;
    let caps = gst::Caps::from_str(&format!(
        "video/x-raw,width={},height={},framerate={}/1,format=RGBx",
        args.app.resolution.0, args.app.resolution.1, args.app.framerate
    ))?;
    caps_filter.set_property("caps", &caps);
    // Video Tee Element - fans the raw feed out to the encoder branch and,
    // optionally, the on-host debug window.
    let video_tee = gst::ElementFactory::make("tee").build()?;
    // Video Converter Element
    let video_converter = gst::ElementFactory::make("videoconvert").build()?;
    // Video Encoder Element
    let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
    video_encoder_info.apply_parameters(&video_encoder, &args.app.verbose);
    // Required for AV1 - av1parse
    let av1_parse = gst::ElementFactory::make("av1parse").build()?;
    // Video RTP Payloader Element - e.g. "rtph264pay" / "rtpav1pay",
    // derived from the selected codec.
    let video_rtp_payloader = gst::ElementFactory::make(
        format!("rtp{}pay", video_encoder_info.codec.to_gst_str()).as_str(),
    )
    .build()?;
    /* Output */
    // Audio AppSink Element - hands RTP packets to the room task.
    let audio_appsink = gst::ElementFactory::make("appsink").build()?;
    audio_appsink.set_property("emit-signals", &true);
    // Downcast is safe: the "appsink" factory always produces an AppSink.
    let audio_appsink = audio_appsink.downcast_ref::<AppSink>().unwrap();
    // Video AppSink Element
    let video_appsink = gst::ElementFactory::make("appsink").build()?;
    video_appsink.set_property("emit-signals", &true);
    let video_appsink = video_appsink.downcast_ref::<AppSink>().unwrap();
    /* Debug */
    // Debug Feed Element - burns a clock into the frame for latency checks.
    let debug_latency = gst::ElementFactory::make("timeoverlay").build()?;
    debug_latency.set_property_from_str("halignment", &"right");
    debug_latency.set_property_from_str("valignment", &"bottom");
    // Debug Sink Element
    let debug_sink = gst::ElementFactory::make("ximagesink").build()?;
    // Debug video converter
    let debug_video_converter = gst::ElementFactory::make("videoconvert").build()?;
    // Queues with max 2ms latency
    // NOTE(review): max-size-time is 1_000_000 ns = 1 ms per queue, so the
    // "2ms" above presumably refers to the combined budget — confirm intent.
    let debug_queue = gst::ElementFactory::make("queue2").build()?;
    debug_queue.set_property("max-size-time", &1000000u64);
    let main_video_queue = gst::ElementFactory::make("queue2").build()?;
    main_video_queue.set_property("max-size-time", &1000000u64);
    let main_audio_queue = gst::ElementFactory::make("queue2").build()?;
    main_audio_queue.set_property("max-size-time", &1000000u64);
    // Create the pipeline
    let pipeline = gst::Pipeline::new();
    // Add elements to the pipeline (order here is irrelevant; data flow is
    // defined by the link_many calls below).
    pipeline.add_many(&[
        &video_appsink.upcast_ref(),
        &video_rtp_payloader,
        &video_encoder,
        &video_converter,
        &video_tee,
        &caps_filter,
        &video_source,
        &audio_appsink.upcast_ref(),
        &audio_rtp_payloader,
        &audio_encoder,
        &audio_capsfilter,
        &audio_rate,
        &audio_converter,
        &audio_source,
        &main_video_queue,
        &main_audio_queue,
    ])?;
    // Add debug elements if debug is enabled
    if args.app.debug_feed {
        pipeline.add_many(&[&debug_sink, &debug_queue, &debug_video_converter])?;
    }
    // Add debug latency element if debug latency is enabled
    if args.app.debug_latency {
        pipeline.add(&debug_latency)?;
    }
    // Add AV1 parse element if AV1 is selected
    if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
        pipeline.add(&av1_parse)?;
    }
    // Link main audio branch
    gst::Element::link_many(&[
        &audio_source,
        &audio_converter,
        &audio_rate,
        &audio_capsfilter,
        &audio_encoder,
        &audio_rtp_payloader,
        &main_audio_queue,
        &audio_appsink.upcast_ref(),
    ])?;
    // If debug latency, add time overlay before tee
    if args.app.debug_latency {
        gst::Element::link_many(&[&video_source, &caps_filter, &debug_latency, &video_tee])?;
    } else {
        gst::Element::link_many(&[&video_source, &caps_filter, &video_tee])?;
    }
    // Link debug branch if debug is enabled
    if args.app.debug_feed {
        gst::Element::link_many(&[
            &video_tee,
            &debug_video_converter,
            &debug_queue,
            &debug_sink,
        ])?;
    }
    // Link main video branch, if AV1, add av1_parse
    if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
        gst::Element::link_many(&[
            &video_tee,
            &video_converter,
            &video_encoder,
            &av1_parse,
            &video_rtp_payloader,
            &main_video_queue,
            &video_appsink.upcast_ref(),
        ])?;
    } else {
        gst::Element::link_many(&[
            &video_tee,
            &video_converter,
            &video_encoder,
            &video_rtp_payloader,
            &main_video_queue,
            &video_appsink.upcast_ref(),
        ])?;
    }
    // Optimize latency of pipeline
    video_source.set_property("do-timestamp", &true);
    audio_source.set_property("do-timestamp", &true);
    // NOTE(review): assumes gst::Pipeline exposes a settable "latency"
    // property in the GStreamer version in use — confirm, else this panics.
    pipeline.set_property("latency", &0u64);
    // Wrap the pipeline in Arc<Mutex> to safely share it
    let pipeline = Arc::new(Mutex::new(pipeline));
    // Run both pipeline and websocket tasks concurrently
    let result = tokio::try_join!(
        run_room(
            room.clone(),
            "audio/opus",
            video_encoder_info.codec.to_mime_str(),
            pipeline.clone(),
            Arc::new(Mutex::new(audio_appsink.stream())),
            Arc::new(Mutex::new(video_appsink.stream()))
        ),
        run_pipeline(pipeline.clone())
    );
    match result {
        Ok(_) => log::info!("All tasks completed successfully"),
        Err(e) => {
            log::error!("Error occurred in one of the tasks: {}", e);
            return Err("Error occurred in one of the tasks".into());
        }
    }
    Ok(())
}
/// Drives the room session: re-runs it after failures (with a back-off)
/// and exits on a clean finish or Ctrl+C.
async fn run_room(
    room: Arc<Mutex<room::Room>>,
    audio_codec: &str,
    video_codec: &str,
    pipeline: Arc<Mutex<gst::Pipeline>>,
    audio_stream: Arc<Mutex<AppSinkStream>>,
    video_stream: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
    loop {
        let mut active_room = room.lock().await;
        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                log::info!("Room interrupted via Ctrl+C");
                return Ok(());
            }
            outcome = active_room.run(
                audio_codec,
                video_codec,
                pipeline.clone(),
                audio_stream.clone(),
                video_stream.clone(),
            ) => {
                match outcome {
                    Ok(()) => return Ok(()),
                    Err(e) => {
                        log::error!("Room error: {}", e);
                        // Back off briefly before retrying the session
                        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                    }
                }
            }
        }
    }
}
/// Starts the GStreamer pipeline, waits for EOS/error/Ctrl+C, then shuts
/// the pipeline back down to Null.
async fn run_pipeline(
    pipeline: Arc<Mutex<gst::Pipeline>>,
) -> Result<(), Box<dyn Error>> {
    // Grab the bus under a short-lived lock so nothing awaits while
    // holding the pipeline mutex.
    let bus = {
        let locked = pipeline.lock().await;
        locked.bus().ok_or("Pipeline has no bus")?
    };
    // Start playback, again under a short-lived lock.
    {
        let locked = pipeline.lock().await;
        if let Err(e) = locked.set_state(gst::State::Playing) {
            log::error!("Failed to start pipeline: {}", e);
            return Err("Failed to start pipeline".into());
        }
    }
    // Wait for whichever comes first: Ctrl+C or a terminal bus message.
    tokio::select! {
        _ = tokio::signal::ctrl_c() => {
            log::info!("Pipeline interrupted via Ctrl+C");
        }
        result = listen_for_gst_messages(bus) => {
            match result {
                Ok(_) => log::info!("Pipeline finished with EOS"),
                Err(err) => log::error!("Pipeline error: {}", err),
            }
        }
    }
    // Tear the pipeline down before returning.
    {
        let locked = pipeline.lock().await;
        locked.set_state(gst::State::Null)?;
    }
    Ok(())
}
/// Drains the pipeline bus: returns Ok on EOS (or stream end) and Err
/// with a formatted description on the first error message.
async fn listen_for_gst_messages(bus: gst::Bus) -> Result<(), Box<dyn Error>> {
    let bus_stream = bus.stream();
    tokio::pin!(bus_stream);
    while let Some(msg) = bus_stream.next().await {
        match msg.view() {
            gst::MessageView::Error(err) => {
                let err_msg = format!(
                    "Error from {:?}: {:?}",
                    err.src().map(|s| s.path_string()),
                    err.error()
                );
                return Err(err_msg.into());
            }
            gst::MessageView::Eos(_) => {
                log::info!("Received EOS");
                break;
            }
            // All other message types are ignored.
            _ => (),
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,144 @@
use std::error::Error;
use std::io::{Read, Write};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::{FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::latency::LatencyTracker;
/// Common envelope flattened into every wire message; `payload_type`
/// selects the concrete message kind (e.g. "join", "ice", "sdp").
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageBase {
    pub payload_type: String,
}
/// Client input event carrying a JSON-encoded payload in `data`
/// (parsed downstream as an `InputMessage`), with an optional
/// latency tracker attached for measurement.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageInput {
    #[serde(flatten)]
    pub base: MessageBase,
    // JSON string describing the input event.
    pub data: String,
    pub latency: Option<LatencyTracker>,
}
/// A log line forwarded over the wire with its severity and timestamp.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageLog {
    #[serde(flatten)]
    pub base: MessageBase,
    // Log severity (format not enforced here).
    pub level: String,
    pub message: String,
    pub time: String,
}
/// Periodic host/pipeline metrics report.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageMetrics {
    #[serde(flatten)]
    pub base: MessageBase,
    pub usage_cpu: f64,
    pub usage_memory: f64,
    // NOTE(review): units for uptime/pipeline_latency are not established
    // here — confirm against the producer before relying on them.
    pub uptime: u64,
    pub pipeline_latency: f64,
}
/// Trickle-ICE candidate exchanged during WebRTC signaling.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageICE {
    #[serde(flatten)]
    pub base: MessageBase,
    pub candidate: RTCIceCandidateInit,
}
/// SDP offer/answer exchanged during WebRTC signaling.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageSDP {
    #[serde(flatten)]
    pub base: MessageBase,
    pub sdp: RTCSessionDescription,
}
/// Role of a peer joining a room, serialized as a raw `i32` on the wire
/// (see the `TryFrom<i32>`/`From<JoinerType>` impls for the conversion).
#[repr(i32)]
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
#[serde(try_from = "i32", into = "i32")]
pub enum JoinerType {
    // Streaming node (this server).
    JoinerNode = 0,
    // Viewing/controlling client.
    JoinerClient = 1,
}
impl TryFrom<i32> for JoinerType {
    type Error = &'static str;

    /// Maps a raw wire integer back to a `JoinerType`, rejecting
    /// out-of-range values.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match JoinerType::from_i32(value) {
            Some(joiner_type) => Ok(joiner_type),
            None => Err("Invalid value for JoinerType"),
        }
    }
}
impl From<JoinerType> for i32 {
    /// Converts to the wire representation. The enum is `#[repr(i32)]`
    /// and fieldless, so a plain cast is infallible — no `to_i32().unwrap()`
    /// panic path needed.
    fn from(joiner_type: JoinerType) -> Self {
        joiner_type as i32
    }
}
/// Request to join a room, declaring whether the sender is a node
/// or a client.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageJoin {
    #[serde(flatten)]
    pub base: MessageBase,
    pub joiner_type: JoinerType,
}
/// Relay's response to a join request, serialized as a raw `i32` on the
/// wire (see the `TryFrom<i32>`/`From<AnswerType>` impls).
#[repr(i32)]
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
#[serde(try_from = "i32", into = "i32")]
pub enum AnswerType {
    // Room exists but is offline.
    AnswerOffline = 0,
    // Room is already occupied by another node.
    AnswerInUse = 1,
    // Join accepted.
    AnswerOK = 2,
}
impl TryFrom<i32> for AnswerType {
    type Error = &'static str;

    /// Maps a raw wire integer back to an `AnswerType`, rejecting
    /// out-of-range values.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match AnswerType::from_i32(value) {
            Some(answer_type) => Ok(answer_type),
            None => Err("Invalid value for AnswerType"),
        }
    }
}
impl From<AnswerType> for i32 {
    /// Converts to the wire representation. The enum is `#[repr(i32)]`
    /// and fieldless, so a plain cast is infallible — no `to_i32().unwrap()`
    /// panic path needed.
    fn from(answer_type: AnswerType) -> Self {
        answer_type as i32
    }
}
/// Relay's reply to a `MessageJoin`.
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageAnswer {
    #[serde(flatten)]
    pub base: MessageBase,
    pub answer_type: AnswerType,
}
/// Serializes `message` to JSON and gzip-compresses the bytes for the wire.
pub fn encode_message<T: Serialize>(message: &T) -> Result<Vec<u8>, Box<dyn Error>> {
    // JSON bytes first, then run them through a gzip encoder.
    let json_bytes = serde_json::to_vec(message)?;
    let mut gz = GzEncoder::new(Vec::new(), Compression::default());
    gz.write_all(&json_bytes)?;
    Ok(gz.finish()?)
}
/// Gunzips `data` and parses just the shared `MessageBase` envelope,
/// enough to dispatch on `payload_type`.
pub fn decode_message(data: &[u8]) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
    let mut json = String::new();
    GzDecoder::new(data).read_to_string(&mut json)?;
    Ok(serde_json::from_str(&json)?)
}
/// Gunzips `data` and deserializes it into the caller-chosen message type.
pub fn decode_message_as<T: for<'de> Deserialize<'de>>(
    data: Vec<u8>,
) -> Result<T, Box<dyn Error + Send + Sync>> {
    let mut json = String::new();
    GzDecoder::new(data.as_slice()).read_to_string(&mut json)?;
    Ok(serde_json::from_str(&json)?)
}

540
packages/server/src/room.rs Normal file
View File

@@ -0,0 +1,540 @@
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::collections::{HashSet};
use std::error::Error;
use std::sync::Arc;
use futures_util::StreamExt;
use gst::prelude::ElementExtManual;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{oneshot, Mutex};
use tokio::sync::{mpsc};
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::ice_transport::ice_gathering_state::RTCIceGatheringState;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::TrackLocalWriter;
use crate::messages::*;
use crate::websocket::NestriWebSocket;
/// Input events received from the client over the "input" data channel,
/// JSON-tagged on a `"type"` field (e.g. `{"type":"mousemove","x":..,"y":..}`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum InputMessage {
    // Mouse movement; presumably relative deltas (contrast MouseMoveAbs).
    #[serde(rename = "mousemove")]
    MouseMove { x: i32, y: i32 },
    // Absolute mouse position.
    #[serde(rename = "mousemoveabs")]
    MouseMoveAbs { x: i32, y: i32 },
    // Scroll wheel deltas.
    #[serde(rename = "wheel")]
    Wheel { x: f64, y: f64 },
    // Mouse button press; `key` is the button code.
    #[serde(rename = "mousedown")]
    MouseDown { key: i32 },
    // Add other variants as needed
    // Mouse button release.
    #[serde(rename = "mouseup")]
    MouseUp { key: i32 },
    // Keyboard key press; `key` is the key code.
    #[serde(rename = "keydown")]
    KeyDown { key: i32 },
    // Keyboard key release.
    #[serde(rename = "keyup")]
    KeyUp { key: i32 },
}
/// One streaming session toward the relay: holds the signaling WebSocket
/// plus the WebRTC API object and configuration used to build peer
/// connections.
pub struct Room {
    nestri_ws: Arc<NestriWebSocket>,
    webrtc_api: webrtc::api::API,
    // ICE server configuration applied to each new peer connection.
    webrtc_config: RTCConfiguration,
}
impl Room {
pub async fn new(
nestri_ws: Arc<NestriWebSocket>,
) -> Result<Room, Box<dyn Error>> {
// Create media engine and register default codecs
let mut media_engine = MediaEngine::default();
media_engine.register_default_codecs()?;
// Registry
let mut registry = Registry::new();
registry = register_default_interceptors(registry, &mut media_engine)?;
// Create the API object with the MediaEngine
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
// Prepare the configuration
let config = RTCConfiguration {
ice_servers: vec![RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".to_owned()],
..Default::default()
}],
..Default::default()
};
Ok(Self {
nestri_ws,
webrtc_api: api,
webrtc_config: config,
})
}
pub async fn run(
&mut self,
audio_codec: &str,
video_codec: &str,
pipeline: Arc<Mutex<gst::Pipeline>>,
audio_sink: Arc<Mutex<AppSinkStream>>,
video_sink: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
let (tx, rx) = oneshot::channel();
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
self.nestri_ws
.register_callback("answer", {
let tx = tx.clone();
move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
log::info!("Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOffline => {
log::warn!("Room is offline, we shouldn't be receiving this");
}
AnswerType::AnswerInUse => {
log::error!("Room is in use by another node!");
}
AnswerType::AnswerOK => {
// Notify that we got an OK answer
if let Some(tx) = tx.lock().unwrap().take() {
if let Err(_) = tx.send(()) {
log::error!("Failed to send OK answer signal");
}
}
}
}
} else {
log::error!("Failed to decode answer");
}
}
})
.await;
// Send a request to join the room
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("Failed to encode join message");
return Err("Failed to encode join message".into());
}
// Wait for the signal indicating that we have received an OK answer
match rx.await {
Ok(()) => {
log::info!("Received OK answer, proceeding...");
}
Err(_) => {
log::error!("Oneshot channel closed unexpectedly");
return Err("Unexpected error while waiting for OK answer".into());
}
}
// Create a new RTCPeerConnection
let config = self.webrtc_config.clone();
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
// Create audio track
let audio_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: audio_codec.to_owned(),
..Default::default()
},
"audio".to_owned(),
"audio-nestri-server".to_owned(),
));
// Create video track
let video_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: video_codec.to_owned(),
..Default::default()
},
"video".to_owned(),
"video-nestri-server".to_owned(),
));
// Cancellation token to stop spawned tasks after peer connection is closed
let cancel_token = tokio_util::sync::CancellationToken::new();
// Add audio track to peer connection
let audio_sender = peer_connection.add_track(audio_track.clone()).await?;
let audio_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = audio_sender_token.cancelled() => {
break;
}
_ = audio_sender.read(&mut rtcp_buf) => {}
}
}
});
// Add video track to peer connection
let video_sender = peer_connection.add_track(video_track.clone()).await?;
let video_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = video_sender_token.cancelled() => {
break;
}
_ = video_sender.read(&mut rtcp_buf) => {}
}
}
});
// Create a datachannel with label 'input'
let data_channel_opts = Some(RTCDataChannelInit {
ordered: Some(false),
max_retransmits: Some(0),
..Default::default()
});
let data_channel
= peer_connection.create_data_channel("input", data_channel_opts).await?;
// PeerConnection state change tracker
let (pc_sndr, mut pc_recv) = mpsc::channel(1);
// Peer connection state change handler
peer_connection.on_peer_connection_state_change(Box::new(
move |s: RTCPeerConnectionState| {
let pc_sndr = pc_sndr.clone();
Box::pin(async move {
log::info!("PeerConnection State has changed: {s}");
if s == RTCPeerConnectionState::Failed
|| s == RTCPeerConnectionState::Disconnected
|| s == RTCPeerConnectionState::Closed
{
// Notify pc_state that the peer connection has closed
if let Err(e) = pc_sndr.send(s).await {
log::error!("Failed to send PeerConnection state: {}", e);
}
}
})
},
));
peer_connection.on_ice_gathering_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Gathering State has changed: {s}");
})
}));
peer_connection.on_ice_connection_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Connection State has changed: {s}");
})
}));
// Trickle ICE over WebSocket
let ws = self.nestri_ws.clone();
peer_connection.on_ice_candidate(Box::new(move |c| {
let nestri_ws = ws.clone();
Box::pin(async move {
if let Some(candidate) = c {
let candidate_json = candidate.to_json().unwrap();
let ice_msg = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
},
candidate: candidate_json,
};
if let Ok(encoded) = encode_message(&ice_msg) {
let _ = nestri_ws.send_message(encoded);
}
}
})
}));
// Temporary ICE candidate buffer until remote description is set
let ice_holder: Arc<Mutex<Vec<RTCIceCandidateInit>>> = Arc::new(Mutex::new(Vec::new()));
// Register set_response_callback for ICE candidate
let pc = peer_connection.clone();
let ice_clone = ice_holder.clone();
self.nestri_ws.register_callback("ice", move |data| {
match decode_message_as::<MessageICE>(data) {
Ok(message) => {
log::info!("Received ICE message");
let candidate = RTCIceCandidateInit::from(message.candidate);
let pc = pc.clone();
let ice_clone = ice_clone.clone();
tokio::spawn(async move {
// If remote description is not set, buffer ICE candidates
if pc.remote_description().await.is_none() {
let mut ice_holder = ice_clone.lock().await;
ice_holder.push(candidate);
} else {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
} else {
// Add any held ICE candidates
let mut ice_holder = ice_clone.lock().await;
for candidate in ice_holder.drain(..) {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
}
}
}
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
// Data channel message handler
data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
let pipeline = pipeline.clone();
let pressed_keys = pressed_keys.clone();
let pressed_buttons = pressed_buttons.clone();
Box::pin({
async move {
// We don't care about string messages for now
if !msg.is_string {
// Decode the message as an MessageInput (binary encoded gzip)
match decode_message_as::<MessageInput>(msg.data.to_vec()) {
Ok(message_input) => {
// Handle the input message
if let Ok(input_msg) = from_str::<InputMessage>(&message_input.data) {
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons).await
{
let _ = pipeline.lock().await.send_event(event);
}
}
}
Err(e) => {
log::error!("Failed to decode input message: {:?}", e);
}
}
}
}
})
}));
log::info!("Creating offer...");
// Create an offer to send to the browser
let offer = peer_connection.create_offer(None).await?;
log::info!("Setting local description...");
// Sets the LocalDescription, and starts our UDP listeners
peer_connection.set_local_description(offer).await?;
log::info!("Local description set...");
if let Some(local_description) = peer_connection.local_description().await {
// Wait until we have gathered all ICE candidates
while peer_connection.ice_gathering_state() != RTCIceGatheringState::Complete {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
// Register set_response_callback for SDP answer
let pc = peer_connection.clone();
self.nestri_ws.register_callback("sdp", move |data| {
match decode_message_as::<MessageSDP>(data) {
Ok(message) => {
log::info!("Received SDP message");
let sdp = message.sdp;
let pc = pc.clone();
tokio::spawn(async move {
if let Err(e) = pc.set_remote_description(sdp).await {
log::error!("Failed to set remote description: {}", e);
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
log::info!("Sending local description to remote...");
// Encode and send the local description via WebSocket
let sdp_msg = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
},
sdp: local_description,
};
let encoded = encode_message(&sdp_msg)?;
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("generate local_description failed!");
cancel_token.cancel();
return Err("generate local_description failed!".into());
};
// Send video and audio data
let audio_track = audio_track.clone();
tokio::spawn(async move {
let mut audio_sink = audio_sink.lock().await;
while let Some(sample) = audio_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = audio_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write audio track: {}", e);
}
}
}
}
}
});
let video_track = video_track.clone();
tokio::spawn(async move {
let mut video_sink = video_sink.lock().await;
while let Some(sample) = video_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = video_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write video track: {}", e);
}
}
}
}
}
});
// Block until closed or error
tokio::select! {
_ = pc_recv.recv() => {
log::info!("Peer connection closed with state: {:?}", peer_connection.connection_state());
}
}
cancel_token.cancel();
// Make double-sure to close the peer connection
if let Err(e) = peer_connection.close().await {
log::error!("Failed to close peer connection: {}", e);
}
Ok(())
}
}
/// Translates a decoded `InputMessage` into a GStreamer custom upstream
/// event for the pipeline, tracking currently-held keys/buttons so a
/// repeated "down" message (e.g. OS auto-repeat) is suppressed.
///
/// Returns `None` when the message should be dropped.
async fn handle_input_message(
    input_msg: InputMessage,
    pressed_keys: &Arc<Mutex<HashSet<i32>>>,
    pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
    // Wrap a structure into the custom upstream event the pipeline expects.
    let wrap = |s: gst::Structure| Some(gst::event::CustomUpstream::new(s));

    match input_msg {
        InputMessage::MouseMove { x, y } => wrap(
            gst::Structure::builder("MouseMoveRelative")
                .field("pointer_x", x as f64)
                .field("pointer_y", y as f64)
                .build(),
        ),
        InputMessage::MouseMoveAbs { x, y } => wrap(
            gst::Structure::builder("MouseMoveAbsolute")
                .field("pointer_x", x as f64)
                .field("pointer_y", y as f64)
                .build(),
        ),
        InputMessage::KeyDown { key } => {
            let mut held = pressed_keys.lock().await;
            // `insert` returns false when the key was already held; drop the
            // event in that case to prevent key lockup.
            if !held.insert(key) {
                return None;
            }
            wrap(
                gst::Structure::builder("KeyboardKey")
                    .field("key", key as u32)
                    .field("pressed", true)
                    .build(),
            )
        }
        InputMessage::KeyUp { key } => {
            // A release always clears the held state and emits the event.
            pressed_keys.lock().await.remove(&key);
            wrap(
                gst::Structure::builder("KeyboardKey")
                    .field("key", key as u32)
                    .field("pressed", false)
                    .build(),
            )
        }
        InputMessage::Wheel { x, y } => wrap(
            gst::Structure::builder("MouseAxis")
                .field("x", x as f64)
                .field("y", y as f64)
                .build(),
        ),
        InputMessage::MouseDown { key } => {
            let mut held = pressed_buttons.lock().await;
            // Same dedup logic as KeyDown, applied to mouse buttons.
            if !held.insert(key) {
                return None;
            }
            wrap(
                gst::Structure::builder("MouseButton")
                    .field("button", key as u32)
                    .field("pressed", true)
                    .build(),
            )
        }
        InputMessage::MouseUp { key } => {
            pressed_buttons.lock().await.remove(&key);
            wrap(
                gst::Structure::builder("MouseButton")
                    .field("button", key as u32)
                    .field("pressed", false)
                    .build(),
            )
        }
    }
}

View File

@@ -0,0 +1,174 @@
use futures_util::sink::SinkExt;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::StreamExt;
use log::{Level, Log, Metadata, Record};
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
/// Handler invoked with the raw (still-encoded) bytes of an incoming message.
type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>;

/// Read half of the (possibly TLS) WebSocket stream.
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// Write half of the (possibly TLS) WebSocket stream.
type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;

/// Shared handle to the relay WebSocket connection.
///
/// Cloning is cheap: all state lives behind `Arc`s, so every clone observes
/// the same connection, and a reconnect performed through one clone swaps the
/// stream halves seen by all of them.
#[derive(Clone)]
pub struct NestriWebSocket {
    // URL used for the initial connection and for every reconnect attempt.
    ws_url: String,
    // Read half; locked by the background read loop (and during reconnect).
    reader: Arc<Mutex<WSRead>>,
    // Write half; locked per send_message call.
    writer: Arc<Mutex<WSWrite>>,
    // payload_type -> handler, dispatched by the read loop.
    callbacks: Arc<RwLock<HashMap<String, Callback>>>,
}
impl NestriWebSocket {
    /// Connects to the relay at `ws_url` and spawns the background read loop.
    ///
    /// `do_connect` retries internally, so this resolves only once a
    /// connection has actually been established.
    pub async fn new(
        ws_url: String,
    ) -> Result<NestriWebSocket, Box<dyn Error>> {
        // Propagate instead of unwrap: `?` converts the Send + Sync error box
        // into `Box<dyn Error>` via the std From impl.
        let ws_stream = NestriWebSocket::do_connect(&ws_url).await?;

        // Split the stream so reads and writes can be locked independently.
        let (write, read) = ws_stream.split();

        let mut ws = NestriWebSocket {
            ws_url,
            reader: Arc::new(Mutex::new(read)),
            writer: Arc::new(Mutex::new(write)),
            callbacks: Arc::new(RwLock::new(HashMap::new())),
        };

        // Start dispatching incoming messages to registered callbacks.
        ws.spawn_read_loop();

        Ok(ws)
    }

    /// Connects to `ws_url`, retrying once per second until it succeeds.
    ///
    /// Never returns `Err` as written; the `Result` is kept so an early-exit
    /// (e.g. cancellation) path can be added without changing callers.
    async fn do_connect(
        ws_url: &str,
    ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Box<dyn Error + Send + Sync>> {
        loop {
            match connect_async(ws_url).await {
                Ok((ws_stream, _)) => {
                    return Ok(ws_stream);
                }
                Err(e) => {
                    // NOTE: eprintln! rather than log:: — the `Log` impl on
                    // NestriWebSocket forwards records through this very
                    // socket, which would recurse while it is down.
                    eprintln!("Failed to connect to WebSocket, retrying: {:?}", e);
                    sleep(Duration::from_secs(1)).await; // Wait before retrying
                }
            }
        }
    }

    /// Spawns the read loop: decodes each incoming data frame and dispatches
    /// it to the callback registered for its payload type; reconnects on
    /// error or disconnect.
    fn spawn_read_loop(&mut self) {
        let reader = self.reader.clone();
        let callbacks = self.callbacks.clone();
        let mut self_clone = self.clone();
        tokio::spawn(async move {
            loop {
                // The lock guard is dropped at the end of this statement, so
                // reconnect() below can re-lock the reader.
                let message = reader.lock().await.next().await;
                match message {
                    Some(Ok(message)) => {
                        // Control frames (ping/pong/close) carry no payload we
                        // can decode; skip them instead of emitting spurious
                        // decode-failure noise.
                        if !message.is_binary() && !message.is_text() {
                            continue;
                        }
                        let data = message.into_data();
                        let base_message = match decode_message(&data) {
                            Ok(base_message) => base_message,
                            Err(e) => {
                                eprintln!("Failed to decode message: {:?}", e);
                                continue;
                            }
                        };

                        // Hand the raw bytes to the registered handler, if any.
                        let callbacks_lock = callbacks.read().await;
                        if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
                            let data = data.clone();
                            callback(data);
                        }
                    }
                    Some(Err(e)) => {
                        eprintln!("Error receiving message: {:?}", e);
                        // reconnect() retries forever, so this cannot fail.
                        self_clone.reconnect().await.unwrap();
                    }
                    None => {
                        eprintln!("WebSocket connection closed, reconnecting...");
                        self_clone.reconnect().await.unwrap();
                    }
                }
            }
        });
    }

    /// Re-establishes the connection and swaps the new read/write halves into
    /// the shared state, so every clone of this handle sees the new stream.
    async fn reconnect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
        // Keep trying to reconnect until successful
        loop {
            match NestriWebSocket::do_connect(&self.ws_url).await {
                Ok(ws_stream) => {
                    let (write, read) = ws_stream.split();
                    *self.reader.lock().await = read;
                    *self.writer.lock().await = write;
                    return Ok(());
                }
                Err(e) => {
                    eprintln!("Failed to reconnect to WebSocket: {:?}", e);
                    sleep(Duration::from_secs(2)).await; // Wait before retrying
                }
            }
        }
    }

    /// Sends an already-encoded message as a binary WebSocket frame.
    pub async fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
        let mut writer_lock = self.writer.lock().await;
        writer_lock
            .send(Message::Binary(message))
            .await
            .map_err(|e| format!("Error sending message: {:?}", e).into())
    }

    /// Registers `callback` to receive the raw bytes of every incoming
    /// message whose payload type equals `response_type`. A later
    /// registration for the same type replaces the earlier one.
    pub async fn register_callback<F>(&self, response_type: &str, callback: F)
    where
        F: Fn(Vec<u8>) + Send + Sync + 'static,
    {
        let mut callbacks_lock = self.callbacks.write().await;
        callbacks_lock.insert(response_type.to_string(), Box::new(callback));
    }
}
impl Log for NestriWebSocket {
fn enabled(&self, metadata: &Metadata) -> bool {
// Filter logs by level
metadata.level() <= Level::Info
}
fn log(&self, record: &Record) {
if self.enabled(record.metadata()) {
let level = record.level().to_string();
let message = record.args().to_string();
let time = chrono::Local::now().to_rfc3339();
// Print to console as well
println!("{}: {}", level, message);
// Encode and send the log message
let log_message = MessageLog {
base: MessageBase {
payload_type: "log".to_string(),
},
level,
message,
time,
};
if let Ok(encoded_message) = encode_message(&log_message) {
let _ = self.send_message(encoded_message);
}
}
}
fn flush(&self) {
// No-op for this logger
}
}