mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
✨ feat: Add streaming support (#125)
This adds: - [x] Keyboard and mouse handling on the frontend - [x] Video and audio streaming from the backend to the frontend - [x] Input server that works with Websockets Update - 17/11 - [ ] Master docker container to run this - [ ] Steam runtime - [ ] Entrypoint.sh --------- Co-authored-by: Kristian Ollikainen <14197772+DatCaptainHorse@users.noreply.github.com> Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
205
packages/server/src/args.rs
Normal file
205
packages/server/src/args.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use clap::{Arg, Command};
|
||||
|
||||
pub mod app_args;
|
||||
pub mod device_args;
|
||||
pub mod encoding_args;
|
||||
|
||||
pub struct Args {
|
||||
pub app: app_args::AppArgs,
|
||||
pub device: device_args::DeviceArgs,
|
||||
pub encoding: encoding_args::EncodingArgs,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
pub fn new() -> Self {
|
||||
let matches = Command::new("nestri-server")
|
||||
.arg(
|
||||
Arg::new("verbose")
|
||||
.short('v')
|
||||
.long("verbose")
|
||||
.env("VERBOSE")
|
||||
.help("Enable verbose output")
|
||||
.default_value("false"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("debug-feed")
|
||||
.short('d')
|
||||
.long("debug-feed")
|
||||
.env("DEBUG_FEED")
|
||||
.help("Debug by showing a window on host")
|
||||
.default_value("false"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("debug-latency")
|
||||
.short('l')
|
||||
.long("debug-latency")
|
||||
.env("DEBUG_LATENCY")
|
||||
.help("Debug latency by showing time on feed")
|
||||
.default_value("false"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("relay-url")
|
||||
.short('u')
|
||||
.long("relay-url")
|
||||
.env("RELAY_URL")
|
||||
.help("Nestri relay URL")
|
||||
)
|
||||
.arg(
|
||||
Arg::new("resolution")
|
||||
.short('r')
|
||||
.long("resolution")
|
||||
.env("RESOLUTION")
|
||||
.help("Display/stream resolution in 'WxH' format")
|
||||
.default_value("1280x720"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("framerate")
|
||||
.short('f')
|
||||
.long("framerate")
|
||||
.env("FRAMERATE")
|
||||
.help("Display/stream framerate")
|
||||
.default_value("60"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("room")
|
||||
.long("room")
|
||||
.env("NESTRI_ROOM")
|
||||
.help("Nestri room name/identifier")
|
||||
)
|
||||
.arg(
|
||||
Arg::new("gpu-vendor")
|
||||
.short('g')
|
||||
.long("gpu-vendor")
|
||||
.env("GPU_VENDOR")
|
||||
.help("GPU to find by vendor (e.g. 'nvidia')")
|
||||
.required(false),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("gpu-name")
|
||||
.short('n')
|
||||
.long("gpu-name")
|
||||
.env("GPU_NAME")
|
||||
.help("GPU to find by name (e.g. 'rtx 3060')")
|
||||
.required(false),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("gpu-index")
|
||||
.short('i')
|
||||
.long("gpu-index")
|
||||
.env("GPU_INDEX")
|
||||
.help("GPU index, if multiple similar GPUs are present")
|
||||
.default_value("0"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("gpu-card-path")
|
||||
.long("gpu-card-path")
|
||||
.env("GPU_CARD_PATH")
|
||||
.help("Force a specific GPU by card/render path (e.g. '/dev/dri/card0')")
|
||||
.required(false)
|
||||
.conflicts_with_all(["gpu-vendor", "gpu-name", "gpu-index"]),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-codec")
|
||||
.short('c')
|
||||
.long("video-codec")
|
||||
.env("VIDEO_CODEC")
|
||||
.help("Preferred video codec ('h264', 'h265', 'av1')")
|
||||
.default_value("h264"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-encoder")
|
||||
.long("video-encoder")
|
||||
.env("VIDEO_ENCODER")
|
||||
.help("Override video encoder (e.g. 'vah264enc')")
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-rate-control")
|
||||
.long("video-rate-control")
|
||||
.env("VIDEO_RATE_CONTROL")
|
||||
.help("Rate control method ('cqp', 'vbr', 'cbr')")
|
||||
.default_value("vbr"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-cqp")
|
||||
.long("video-cqp")
|
||||
.env("VIDEO_CQP")
|
||||
.help("Constant Quantization Parameter (CQP) quality")
|
||||
.default_value("26"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-bitrate")
|
||||
.long("video-bitrate")
|
||||
.env("VIDEO_BITRATE")
|
||||
.help("Target bitrate in kbps")
|
||||
.default_value("6000"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-bitrate-max")
|
||||
.long("video-bitrate-max")
|
||||
.env("VIDEO_BITRATE_MAX")
|
||||
.help("Maximum bitrate in kbps")
|
||||
.default_value("8000"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("video-encoder-type")
|
||||
.long("video-encoder-type")
|
||||
.env("VIDEO_ENCODER_TYPE")
|
||||
.help("Encoder type ('hardware', 'software')")
|
||||
.default_value("hardware"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-capture-method")
|
||||
.long("audio-capture-method")
|
||||
.env("AUDIO_CAPTURE_METHOD")
|
||||
.help("Audio capture method ('pipewire', 'pulseaudio', 'alsa')")
|
||||
.default_value("pulseaudio"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-codec")
|
||||
.long("audio-codec")
|
||||
.env("AUDIO_CODEC")
|
||||
.help("Preferred audio codec ('opus', 'aac')")
|
||||
.default_value("opus"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-encoder")
|
||||
.long("audio-encoder")
|
||||
.env("AUDIO_ENCODER")
|
||||
.help("Override audio encoder (e.g. 'opusenc')")
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-rate-control")
|
||||
.long("audio-rate-control")
|
||||
.env("AUDIO_RATE_CONTROL")
|
||||
.help("Rate control method ('cqp', 'vbr', 'cbr')")
|
||||
.default_value("vbr"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-bitrate")
|
||||
.long("audio-bitrate")
|
||||
.env("AUDIO_BITRATE")
|
||||
.help("Target bitrate in kbps")
|
||||
.default_value("128"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("audio-bitrate-max")
|
||||
.long("audio-bitrate-max")
|
||||
.env("AUDIO_BITRATE_MAX")
|
||||
.help("Maximum bitrate in kbps")
|
||||
.default_value("192"),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
Self {
|
||||
app: app_args::AppArgs::from_matches(&matches),
|
||||
device: device_args::DeviceArgs::from_matches(&matches),
|
||||
encoding: encoding_args::EncodingArgs::from_matches(&matches),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
self.app.debug_print();
|
||||
self.device.debug_print();
|
||||
self.encoding.debug_print();
|
||||
}
|
||||
}
|
||||
59
packages/server/src/args/app_args.rs
Normal file
59
packages/server/src/args/app_args.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
pub struct AppArgs {
|
||||
/// Verbose output mode
|
||||
pub verbose: bool,
|
||||
/// Debug the pipeline by showing a window on host
|
||||
pub debug_feed: bool,
|
||||
/// Debug the latency by showing time in stream
|
||||
pub debug_latency: bool,
|
||||
|
||||
/// Virtual display resolution
|
||||
pub resolution: (u32, u32),
|
||||
/// Virtual display framerate
|
||||
pub framerate: u32,
|
||||
|
||||
/// Nestri relay url
|
||||
pub relay_url: String,
|
||||
/// Nestri room name/identifier
|
||||
pub room: String,
|
||||
}
|
||||
impl AppArgs {
|
||||
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
|
||||
Self {
|
||||
verbose: matches.get_one::<String>("verbose").unwrap() == "true"
|
||||
|| matches.get_one::<String>("verbose").unwrap() == "1",
|
||||
debug_feed: matches.get_one::<String>("debug-feed").unwrap() == "true"
|
||||
|| matches.get_one::<String>("debug-feed").unwrap() == "1",
|
||||
debug_latency: matches.get_one::<String>("debug-latency").unwrap() == "true"
|
||||
|| matches.get_one::<String>("debug-latency").unwrap() == "1",
|
||||
resolution: {
|
||||
let res = matches.get_one::<String>("resolution").unwrap().clone();
|
||||
let parts: Vec<&str> = res.split('x').collect();
|
||||
(
|
||||
parts[0].parse::<u32>().unwrap(),
|
||||
parts[1].parse::<u32>().unwrap(),
|
||||
)
|
||||
},
|
||||
framerate: matches
|
||||
.get_one::<String>("framerate")
|
||||
.unwrap()
|
||||
.parse::<u32>()
|
||||
.unwrap(),
|
||||
relay_url: matches.get_one::<String>("relay-url").unwrap().clone(),
|
||||
// Generate random room name if not provided
|
||||
room: matches.get_one::<String>("room")
|
||||
.unwrap_or(&rand::random::<u32>().to_string())
|
||||
.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
println!("AppArgs:");
|
||||
println!("> verbose: {}", self.verbose);
|
||||
println!("> debug_feed: {}", self.debug_feed);
|
||||
println!("> debug_latency: {}", self.debug_latency);
|
||||
println!("> resolution: {}x{}", self.resolution.0, self.resolution.1);
|
||||
println!("> framerate: {}", self.framerate);
|
||||
println!("> relay_url: {}", self.relay_url);
|
||||
println!("> room: {}", self.room);
|
||||
}
|
||||
}
|
||||
41
packages/server/src/args/device_args.rs
Normal file
41
packages/server/src/args/device_args.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
pub struct DeviceArgs {
|
||||
/// GPU vendor (e.g. "intel")
|
||||
pub gpu_vendor: String,
|
||||
/// GPU name (e.g. "a770")
|
||||
pub gpu_name: String,
|
||||
/// GPU index, if multiple same GPUs are present
|
||||
pub gpu_index: u32,
|
||||
/// GPU card/render path, sets card explicitly from such path
|
||||
pub gpu_card_path: String,
|
||||
}
|
||||
impl DeviceArgs {
|
||||
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
|
||||
Self {
|
||||
gpu_vendor: matches
|
||||
.get_one::<String>("gpu-vendor")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
gpu_name: matches
|
||||
.get_one::<String>("gpu-name")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
gpu_index: matches
|
||||
.get_one::<String>("gpu-index")
|
||||
.unwrap()
|
||||
.parse::<u32>()
|
||||
.unwrap(),
|
||||
gpu_card_path: matches
|
||||
.get_one::<String>("gpu-card-path")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
println!("DeviceArgs:");
|
||||
println!("> gpu_vendor: {}", self.gpu_vendor);
|
||||
println!("> gpu_name: {}", self.gpu_name);
|
||||
println!("> gpu_index: {}", self.gpu_index);
|
||||
println!("> gpu_card_path: {}", self.gpu_card_path);
|
||||
}
|
||||
}
|
||||
190
packages/server/src/args/encoding_args.rs
Normal file
190
packages/server/src/args/encoding_args.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct RateControlCQP {
|
||||
/// Constant Quantization Parameter (CQP) quality level
|
||||
pub quality: u32,
|
||||
}
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct RateControlVBR {
|
||||
/// Target bitrate in kbps
|
||||
pub target_bitrate: i32,
|
||||
/// Maximum bitrate in kbps
|
||||
pub max_bitrate: i32,
|
||||
}
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct RateControlCBR {
|
||||
/// Target bitrate in kbps
|
||||
pub target_bitrate: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum RateControl {
|
||||
/// Constant Quantization Parameter
|
||||
CQP(RateControlCQP),
|
||||
/// Variable Bitrate
|
||||
VBR(RateControlVBR),
|
||||
/// Constant Bitrate
|
||||
CBR(RateControlCBR),
|
||||
}
|
||||
|
||||
pub struct EncodingOptionsBase {
|
||||
/// Codec (e.g. "h264", "opus" etc.)
|
||||
pub codec: String,
|
||||
/// Overridable encoder (e.g. "vah264lpenc", "opusenc" etc.)
|
||||
pub encoder: String,
|
||||
/// Rate control method (e.g. "cqp", "vbr", "cbr")
|
||||
pub rate_control: RateControl,
|
||||
}
|
||||
impl EncodingOptionsBase {
|
||||
pub fn debug_print(&self) {
|
||||
println!("> Codec: {}", self.codec);
|
||||
println!("> Encoder: {}", self.encoder);
|
||||
match &self.rate_control {
|
||||
RateControl::CQP(cqp) => {
|
||||
println!("> Rate Control: CQP");
|
||||
println!("-> Quality: {}", cqp.quality);
|
||||
}
|
||||
RateControl::VBR(vbr) => {
|
||||
println!("> Rate Control: VBR");
|
||||
println!("-> Target Bitrate: {}", vbr.target_bitrate);
|
||||
println!("-> Max Bitrate: {}", vbr.max_bitrate);
|
||||
}
|
||||
RateControl::CBR(cbr) => {
|
||||
println!("> Rate Control: CBR");
|
||||
println!("-> Target Bitrate: {}", cbr.target_bitrate);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VideoEncodingOptions {
|
||||
pub base: EncodingOptionsBase,
|
||||
/// Encoder type (e.g. "hardware", "software")
|
||||
pub encoder_type: String,
|
||||
}
|
||||
impl VideoEncodingOptions {
|
||||
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
|
||||
Self {
|
||||
base: EncodingOptionsBase {
|
||||
codec: matches.get_one::<String>("video-codec").unwrap().clone(),
|
||||
encoder: matches
|
||||
.get_one::<String>("video-encoder")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
rate_control: match matches.get_one::<String>("video-rate-control").unwrap().as_str() {
|
||||
"cqp" => RateControl::CQP(RateControlCQP {
|
||||
quality: matches.get_one::<String>("video-cqp").unwrap().parse::<u32>().unwrap(),
|
||||
}),
|
||||
"cbr" => RateControl::CBR(RateControlCBR {
|
||||
target_bitrate: matches.get_one::<String>("video-bitrate").unwrap().parse::<i32>().unwrap(),
|
||||
}),
|
||||
"vbr" => RateControl::VBR(RateControlVBR {
|
||||
target_bitrate: matches.get_one::<String>("video-bitrate").unwrap().parse::<i32>().unwrap(),
|
||||
max_bitrate: matches.get_one::<String>("video-bitrate-max").unwrap().parse::<i32>().unwrap(),
|
||||
}),
|
||||
_ => panic!("Invalid rate control method for video"),
|
||||
},
|
||||
},
|
||||
encoder_type: matches.get_one::<String>("video-encoder-type").unwrap_or(&"hardware".to_string()).clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
println!("Video Encoding Options:");
|
||||
self.base.debug_print();
|
||||
println!("> Encoder Type: {}", self.encoder_type);
|
||||
}
|
||||
}
|
||||
impl Deref for VideoEncodingOptions {
|
||||
type Target = EncodingOptionsBase;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.base
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum AudioCaptureMethod {
|
||||
PulseAudio,
|
||||
PipeWire,
|
||||
ALSA,
|
||||
}
|
||||
impl AudioCaptureMethod {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
AudioCaptureMethod::PulseAudio => "pulseaudio",
|
||||
AudioCaptureMethod::PipeWire => "pipewire",
|
||||
AudioCaptureMethod::ALSA => "alsa",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AudioEncodingOptions {
|
||||
pub base: EncodingOptionsBase,
|
||||
pub capture_method: AudioCaptureMethod,
|
||||
}
|
||||
impl AudioEncodingOptions {
|
||||
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
|
||||
Self {
|
||||
base: EncodingOptionsBase {
|
||||
codec: matches.get_one::<String>("audio-codec").unwrap().clone(),
|
||||
encoder: matches
|
||||
.get_one::<String>("audio-encoder")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
rate_control: match matches.get_one::<String>("audio-rate-control").unwrap().as_str() {
|
||||
"cbr" => RateControl::CBR(RateControlCBR {
|
||||
target_bitrate: matches.get_one::<String>("audio-bitrate").unwrap().parse::<i32>().unwrap(),
|
||||
}),
|
||||
"vbr" => RateControl::VBR(RateControlVBR {
|
||||
target_bitrate: matches.get_one::<String>("audio-bitrate").unwrap().parse::<i32>().unwrap(),
|
||||
max_bitrate: matches.get_one::<String>("audio-bitrate-max").unwrap().parse::<i32>().unwrap(),
|
||||
}),
|
||||
_ => panic!("Invalid rate control method for audio"),
|
||||
},
|
||||
},
|
||||
capture_method: match matches.get_one::<String>("audio-capture-method").unwrap().as_str() {
|
||||
"pulseaudio" => AudioCaptureMethod::PulseAudio,
|
||||
"pipewire" => AudioCaptureMethod::PipeWire,
|
||||
"alsa" => AudioCaptureMethod::ALSA,
|
||||
// Default to PulseAudio
|
||||
_ => AudioCaptureMethod::PulseAudio,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
println!("Audio Encoding Options:");
|
||||
self.base.debug_print();
|
||||
println!("> Capture Method: {}", self.capture_method.as_str());
|
||||
}
|
||||
}
|
||||
impl Deref for AudioEncodingOptions {
|
||||
type Target = EncodingOptionsBase;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.base
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncodingArgs {
|
||||
/// Video encoder options
|
||||
pub video: VideoEncodingOptions,
|
||||
/// Audio encoder options
|
||||
pub audio: AudioEncodingOptions,
|
||||
}
|
||||
impl EncodingArgs {
|
||||
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
|
||||
Self {
|
||||
video: VideoEncodingOptions::from_matches(matches),
|
||||
audio: AudioEncodingOptions::from_matches(matches),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_print(&self) {
|
||||
println!("Encoding Arguments:");
|
||||
self.video.debug_print();
|
||||
self.audio.debug_print();
|
||||
}
|
||||
}
|
||||
598
packages/server/src/enc_helper.rs
Normal file
598
packages/server/src/enc_helper.rs
Normal file
@@ -0,0 +1,598 @@
|
||||
use gst::prelude::*;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum VideoCodec {
|
||||
H264,
|
||||
H265,
|
||||
AV1,
|
||||
UNKNOWN,
|
||||
}
|
||||
impl VideoCodec {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
VideoCodec::H264 => "H.264",
|
||||
VideoCodec::H265 => "H.265",
|
||||
VideoCodec::AV1 => "AV1",
|
||||
VideoCodec::UNKNOWN => "Unknown",
|
||||
}
|
||||
}
|
||||
|
||||
// unlike to_str, converts to gstreamer friendly codec name
|
||||
pub fn to_gst_str(&self) -> &'static str {
|
||||
match self {
|
||||
VideoCodec::H264 => "h264",
|
||||
VideoCodec::H265 => "h265",
|
||||
VideoCodec::AV1 => "av1",
|
||||
VideoCodec::UNKNOWN => "unknown",
|
||||
}
|
||||
}
|
||||
|
||||
// returns mime-type string
|
||||
pub fn to_mime_str(&self) -> &'static str {
|
||||
match self {
|
||||
VideoCodec::H264 => "video/H264",
|
||||
VideoCodec::H265 => "video/H265",
|
||||
VideoCodec::AV1 => "video/AV1",
|
||||
VideoCodec::UNKNOWN => "unknown",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"h264" => VideoCodec::H264,
|
||||
"h.264" => VideoCodec::H264,
|
||||
"avc" => VideoCodec::H264,
|
||||
"h265" => VideoCodec::H265,
|
||||
"h.265" => VideoCodec::H265,
|
||||
"hevc" => VideoCodec::H265,
|
||||
"hev1" => VideoCodec::H265,
|
||||
"av1" => VideoCodec::AV1,
|
||||
_ => VideoCodec::UNKNOWN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum EncoderAPI {
|
||||
QSV,
|
||||
VAAPI,
|
||||
NVENC,
|
||||
AMF,
|
||||
SOFTWARE,
|
||||
UNKNOWN,
|
||||
}
|
||||
impl EncoderAPI {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
EncoderAPI::QSV => "Intel QuickSync Video",
|
||||
EncoderAPI::VAAPI => "Video Acceleration API",
|
||||
EncoderAPI::NVENC => "NVIDIA NVENC",
|
||||
EncoderAPI::AMF => "AMD Media Framework",
|
||||
EncoderAPI::SOFTWARE => "Software",
|
||||
EncoderAPI::UNKNOWN => "Unknown",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum EncoderType {
|
||||
SOFTWARE,
|
||||
HARDWARE,
|
||||
UNKNOWN,
|
||||
}
|
||||
impl EncoderType {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
EncoderType::SOFTWARE => "Software",
|
||||
EncoderType::HARDWARE => "Hardware",
|
||||
EncoderType::UNKNOWN => "Unknown",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"software" => EncoderType::SOFTWARE,
|
||||
"hardware" => EncoderType::HARDWARE,
|
||||
_ => EncoderType::UNKNOWN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoEncoderInfo {
|
||||
pub name: String,
|
||||
pub codec: VideoCodec,
|
||||
pub encoder_type: EncoderType,
|
||||
pub encoder_api: EncoderAPI,
|
||||
pub parameters: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl VideoEncoderInfo {
|
||||
pub fn new(
|
||||
name: String,
|
||||
codec: VideoCodec,
|
||||
encoder_type: EncoderType,
|
||||
encoder_api: EncoderAPI,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
codec,
|
||||
encoder_type,
|
||||
encoder_api,
|
||||
parameters: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_parameters_string(&self) -> String {
|
||||
self.parameters
|
||||
.iter()
|
||||
.map(|(key, value)| format!("{}={}", key, value))
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
pub fn set_parameter(&mut self, key: &str, value: &str) {
|
||||
self.parameters.push((key.to_string(), value.to_string()));
|
||||
}
|
||||
|
||||
pub fn apply_parameters(&self, element: &gst::Element, verbose: &bool) {
|
||||
for (key, value) in &self.parameters {
|
||||
if element.has_property(key) {
|
||||
// If verbose, log property sets
|
||||
if *verbose {
|
||||
println!("Setting property {} to {}", key, value);
|
||||
}
|
||||
element.set_property_from_str(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts VA-API encoder name to low-power variant.
|
||||
/// # Arguments
|
||||
/// * `encoder` - The name of the VA-API encoder.
|
||||
/// # Returns
|
||||
/// * `&str` - The name of the low-power variant of the encoder.
|
||||
fn get_low_power_encoder(encoder: &String) -> String {
|
||||
if encoder.starts_with("va") && !encoder.ends_with("enc") && !encoder.ends_with("lpenc") {
|
||||
// Replace "enc" substring at end with "lpenc"
|
||||
let mut encoder = encoder.to_string();
|
||||
encoder.truncate(encoder.len() - 3);
|
||||
encoder.push_str("lpenc");
|
||||
encoder
|
||||
} else {
|
||||
encoder.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns best guess for encoder API based on the encoder name.
|
||||
/// # Arguments
|
||||
/// * `encoder` - The name of the encoder.
|
||||
/// # Returns
|
||||
/// * `EncoderAPI` - The best guess for the encoder API.
|
||||
fn get_encoder_api(encoder: &String, encoder_type: &EncoderType) -> EncoderAPI {
|
||||
if *encoder_type == EncoderType::HARDWARE {
|
||||
if encoder.starts_with("qsv") {
|
||||
EncoderAPI::QSV
|
||||
} else if encoder.starts_with("va") {
|
||||
EncoderAPI::VAAPI
|
||||
} else if encoder.starts_with("nv") {
|
||||
EncoderAPI::NVENC
|
||||
} else if encoder.starts_with("amf") {
|
||||
EncoderAPI::AMF
|
||||
} else {
|
||||
EncoderAPI::UNKNOWN
|
||||
}
|
||||
} else if *encoder_type == EncoderType::SOFTWARE {
|
||||
EncoderAPI::SOFTWARE
|
||||
} else {
|
||||
EncoderAPI::UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if system supports given encoder.
|
||||
/// # Returns
|
||||
/// * `bool` - True if encoder is supported, false otherwise.
|
||||
fn is_encoder_supported(encoder: &String) -> bool {
|
||||
gst::ElementFactory::find(encoder.as_str()).is_some()
|
||||
}
|
||||
|
||||
/// Helper to set CQP value of known encoder
|
||||
/// # Arguments
|
||||
/// * `encoder` - Information about the encoder.
|
||||
/// * `quality` - Constant quantization parameter (CQP) quality, recommended values are between 20-30.
|
||||
/// # Returns
|
||||
/// * `EncoderInfo` - Encoder with maybe updated parameters.
|
||||
pub fn encoder_cqp_params(encoder: &VideoEncoderInfo, quality: u32) -> VideoEncoderInfo {
|
||||
let mut encoder_optz = encoder.clone();
|
||||
|
||||
// Look for known keys by factory creation
|
||||
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Get properties of the encoder
|
||||
for prop in encoder.list_properties() {
|
||||
let prop_name = prop.name();
|
||||
|
||||
// Look for known keys
|
||||
if prop_name.to_lowercase().contains("qp")
|
||||
&& (prop_name.to_lowercase().contains("i") || prop_name.to_lowercase().contains("min"))
|
||||
{
|
||||
encoder_optz.set_parameter(prop_name, &quality.to_string());
|
||||
} else if prop_name.to_lowercase().contains("qp")
|
||||
&& (prop_name.to_lowercase().contains("p") || prop_name.to_lowercase().contains("max"))
|
||||
{
|
||||
encoder_optz.set_parameter(prop_name, &(quality + 2).to_string());
|
||||
}
|
||||
}
|
||||
|
||||
encoder_optz
|
||||
}
|
||||
|
||||
/// Helper to set VBR values of known encoder
|
||||
/// # Arguments
|
||||
/// * `encoder` - Information about the encoder.
|
||||
/// * `bitrate` - Target bitrate in bits per second.
|
||||
/// * `max_bitrate` - Maximum bitrate in bits per second.
|
||||
/// # Returns
|
||||
/// * `EncoderInfo` - Encoder with maybe updated parameters.
|
||||
pub fn encoder_vbr_params(encoder: &VideoEncoderInfo, bitrate: u32, max_bitrate: u32) -> VideoEncoderInfo {
|
||||
let mut encoder_optz = encoder.clone();
|
||||
|
||||
// Look for known keys by factory creation
|
||||
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Get properties of the encoder
|
||||
for prop in encoder.list_properties() {
|
||||
let prop_name = prop.name();
|
||||
|
||||
// Look for known keys
|
||||
if prop_name.to_lowercase().contains("bitrate")
|
||||
&& !prop_name.to_lowercase().contains("max")
|
||||
{
|
||||
encoder_optz.set_parameter(prop_name, &bitrate.to_string());
|
||||
} else if prop_name.to_lowercase().contains("bitrate")
|
||||
&& prop_name.to_lowercase().contains("max")
|
||||
{
|
||||
// If SVT-AV1, don't set max bitrate
|
||||
if encoder_optz.name == "svtav1enc" {
|
||||
continue;
|
||||
}
|
||||
encoder_optz.set_parameter(prop_name, &max_bitrate.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
encoder_optz
|
||||
}
|
||||
|
||||
/// Helper to set CBR value of known encoder
|
||||
/// # Arguments
|
||||
/// * `encoder` - Information about the encoder.
|
||||
/// * `bitrate` - Target bitrate in bits per second.
|
||||
/// # Returns
|
||||
/// * `EncoderInfo` - Encoder with maybe updated parameters.
|
||||
pub fn encoder_cbr_params(encoder: &VideoEncoderInfo, bitrate: u32) -> VideoEncoderInfo {
|
||||
let mut encoder_optz = encoder.clone();
|
||||
|
||||
// Look for known keys by factory creation
|
||||
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Get properties of the encoder
|
||||
for prop in encoder.list_properties() {
|
||||
let prop_name = prop.name();
|
||||
|
||||
// Look for known keys
|
||||
if prop_name.to_lowercase().contains("bitrate")
|
||||
&& !prop_name.to_lowercase().contains("max")
|
||||
{
|
||||
encoder_optz.set_parameter(prop_name, &bitrate.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
encoder_optz
|
||||
}
|
||||
|
||||
/// Helper to set GOP size of known encoder
|
||||
/// # Arguments
|
||||
/// * `encoder` - Information about the encoder.
|
||||
/// * `gop_size` - Group of pictures (GOP) size.
|
||||
/// # Returns
|
||||
/// * `EncoderInfo` - Encoder with maybe updated parameters.
|
||||
pub fn encoder_gop_params(encoder: &VideoEncoderInfo, gop_size: u32) -> VideoEncoderInfo {
|
||||
let mut encoder_optz = encoder.clone();
|
||||
|
||||
// Look for known keys by factory creation
|
||||
let encoder = gst::ElementFactory::make(encoder_optz.name.as_str())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Get properties of the encoder
|
||||
for prop in encoder.list_properties() {
|
||||
let prop_name = prop.name();
|
||||
|
||||
// Look for known keys
|
||||
if prop_name.to_lowercase().contains("gop")
|
||||
|| prop_name.to_lowercase().contains("int-max")
|
||||
|| prop_name.to_lowercase().contains("max-dist")
|
||||
|| prop_name.to_lowercase().contains("intra-period-length")
|
||||
{
|
||||
encoder_optz.set_parameter(prop_name, &gop_size.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
encoder_optz
|
||||
}
|
||||
|
||||
/// Sets parameters of known encoders for low latency operation.
|
||||
/// # Arguments
|
||||
/// * `encoder` - Information about the encoder.
|
||||
/// # Returns
|
||||
/// * `EncoderInfo` - Encoder with maybe updated parameters.
|
||||
pub fn encoder_low_latency_params(encoder: &VideoEncoderInfo) -> VideoEncoderInfo {
|
||||
let mut encoder_optz = encoder.clone();
|
||||
encoder_optz = encoder_gop_params(&encoder_optz, 30);
|
||||
match encoder_optz.encoder_api {
|
||||
EncoderAPI::QSV => {
|
||||
encoder_optz.set_parameter("low-latency", "true");
|
||||
encoder_optz.set_parameter("target-usage", "7");
|
||||
}
|
||||
EncoderAPI::VAAPI => {
|
||||
encoder_optz.set_parameter("target-usage", "7");
|
||||
}
|
||||
EncoderAPI::NVENC => {
|
||||
match encoder_optz.codec {
|
||||
// nvcudah264enc supports newer presets and tunes
|
||||
VideoCodec::H264 => {
|
||||
encoder_optz.set_parameter("multi-pass", "disabled");
|
||||
encoder_optz.set_parameter("preset", "p1");
|
||||
encoder_optz.set_parameter("tune", "ultra-low-latency");
|
||||
}
|
||||
// same goes for nvcudah265enc
|
||||
VideoCodec::H265 => {
|
||||
encoder_optz.set_parameter("multi-pass", "disabled");
|
||||
encoder_optz.set_parameter("preset", "p1");
|
||||
encoder_optz.set_parameter("tune", "ultra-low-latency");
|
||||
}
|
||||
// nvav1enc only supports older presets
|
||||
VideoCodec::AV1 => {
|
||||
encoder_optz.set_parameter("preset", "low-latency-hp");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
EncoderAPI::AMF => {
|
||||
encoder_optz.set_parameter("preset", "speed");
|
||||
match encoder_optz.codec {
|
||||
// Only H.264 supports "ultra-low-latency" usage
|
||||
VideoCodec::H264 => {
|
||||
encoder_optz.set_parameter("usage", "ultra-low-latency");
|
||||
}
|
||||
// Same goes for H.265
|
||||
VideoCodec::H265 => {
|
||||
encoder_optz.set_parameter("usage", "ultra-low-latency");
|
||||
}
|
||||
VideoCodec::AV1 => {
|
||||
encoder_optz.set_parameter("usage", "low-latency");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
EncoderAPI::SOFTWARE => {
|
||||
// Check encoder name for software encoders
|
||||
match encoder_optz.name.as_str() {
|
||||
"openh264enc" => {
|
||||
encoder_optz.set_parameter("complexity", "low");
|
||||
encoder_optz.set_parameter("usage-type", "screen");
|
||||
}
|
||||
"x264enc" => {
|
||||
encoder_optz.set_parameter("rc-lookahead", "0");
|
||||
encoder_optz.set_parameter("speed-preset", "ultrafast");
|
||||
encoder_optz.set_parameter("tune", "zerolatency");
|
||||
}
|
||||
"svtav1enc" => {
|
||||
encoder_optz.set_parameter("preset", "12");
|
||||
// Add ":pred-struct=1" only in CBR mode
|
||||
let params_string = format!(
|
||||
"lookahead=0{}",
|
||||
if encoder_optz.get_parameters_string().contains("cbr") {
|
||||
":pred-struct=1"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
);
|
||||
encoder_optz.set_parameter("parameters-string", params_string.as_str());
|
||||
}
|
||||
"av1enc" => {
|
||||
encoder_optz.set_parameter("usage-profile", "realtime");
|
||||
encoder_optz.set_parameter("cpu-used", "10");
|
||||
encoder_optz.set_parameter("lag-in-frames", "0");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
encoder_optz
|
||||
}
|
||||
|
||||
/// Returns all compatible encoders for the system.
|
||||
/// # Returns
|
||||
/// * `Vec<EncoderInfo>` - List of compatible encoders.
|
||||
pub fn get_compatible_encoders() -> Vec<VideoEncoderInfo> {
|
||||
let mut encoders: Vec<VideoEncoderInfo> = Vec::new();
|
||||
|
||||
let registry = gst::Registry::get();
|
||||
let plugins = registry.plugins();
|
||||
for plugin in plugins {
|
||||
let features = registry.features_by_plugin(plugin.plugin_name().as_str());
|
||||
for feature in features {
|
||||
let encoder = feature.name().to_string();
|
||||
let factory = gst::ElementFactory::find(encoder.as_str());
|
||||
if factory.is_some() {
|
||||
let factory = factory.unwrap();
|
||||
// Get klass metadata
|
||||
let klass = factory.metadata("klass");
|
||||
if klass.is_some() {
|
||||
// Make sure klass contains "Encoder/Video/..."
|
||||
let klass = klass.unwrap().to_string();
|
||||
if !klass.to_lowercase().contains("encoder/video") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If contains "/hardware" in klass, it's a hardware encoder
|
||||
let encoder_type = if klass.to_lowercase().contains("/hardware") {
|
||||
EncoderType::HARDWARE
|
||||
} else {
|
||||
EncoderType::SOFTWARE
|
||||
};
|
||||
|
||||
let api = get_encoder_api(&encoder, &encoder_type);
|
||||
if is_encoder_supported(&encoder) {
|
||||
// Match codec by looking for "264" or "av1" in encoder name
|
||||
let codec = if encoder.contains("264") {
|
||||
VideoCodec::H264
|
||||
} else if encoder.contains("265") {
|
||||
VideoCodec::H265
|
||||
} else if encoder.contains("av1") {
|
||||
VideoCodec::AV1
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let encoder_info = VideoEncoderInfo::new(encoder, codec, encoder_type, api);
|
||||
encoders.push(encoder_info);
|
||||
} else if api == EncoderAPI::VAAPI {
|
||||
// Try low-power variant of VA-API encoder
|
||||
let low_power_encoder = get_low_power_encoder(&encoder);
|
||||
if is_encoder_supported(&low_power_encoder) {
|
||||
let codec = if low_power_encoder.contains("264") {
|
||||
VideoCodec::H264
|
||||
} else if low_power_encoder.contains("265") {
|
||||
VideoCodec::H265
|
||||
} else if low_power_encoder.contains("av1") {
|
||||
VideoCodec::AV1
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let encoder_info =
|
||||
VideoEncoderInfo::new(low_power_encoder, codec, encoder_type, api);
|
||||
encoders.push(encoder_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
encoders
|
||||
}
|
||||
|
||||
/// Helper to return encoder from vector by name (case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `encoders` - A vector containing information about each encoder.
|
||||
/// * `name` - A string slice that holds the encoder name.
|
||||
/// # Returns
|
||||
/// * `Option<EncoderInfo>` - A reference to an EncoderInfo struct if found.
|
||||
pub fn get_encoder_by_name(
|
||||
encoders: &Vec<VideoEncoderInfo>,
|
||||
name: &str,
|
||||
) -> Option<VideoEncoderInfo> {
|
||||
let name = name.to_lowercase();
|
||||
encoders
|
||||
.iter()
|
||||
.find(|encoder| encoder.name.to_lowercase() == name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Helper to get encoders from vector by video codec.
|
||||
/// # Arguments
|
||||
/// * `encoders` - A vector containing information about each encoder.
|
||||
/// * `codec` - The codec of the encoder.
|
||||
/// # Returns
|
||||
/// * `Vec<EncoderInfo>` - A vector containing EncoderInfo structs if found.
|
||||
pub fn get_encoders_by_videocodec(
|
||||
encoders: &Vec<VideoEncoderInfo>,
|
||||
codec: &VideoCodec,
|
||||
) -> Vec<VideoEncoderInfo> {
|
||||
encoders
|
||||
.iter()
|
||||
.filter(|encoder| encoder.codec == *codec)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper to get encoders from vector by encoder type.
|
||||
/// # Arguments
|
||||
/// * `encoders` - A vector containing information about each encoder.
|
||||
/// * `encoder_type` - The type of the encoder.
|
||||
/// # Returns
|
||||
/// * `Vec<EncoderInfo>` - A vector containing EncoderInfo structs if found.
|
||||
pub fn get_encoders_by_type(
|
||||
encoders: &Vec<VideoEncoderInfo>,
|
||||
encoder_type: &EncoderType,
|
||||
) -> Vec<VideoEncoderInfo> {
|
||||
encoders
|
||||
.iter()
|
||||
.filter(|encoder| encoder.encoder_type == *encoder_type)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns best-case compatible encoder given desired codec and encoder type.
|
||||
/// # Arguments
|
||||
/// * `encoders` - List of encoders to pick from.
|
||||
/// * `codec` - Desired codec.
|
||||
/// * `encoder_type` - Desired encoder type.
|
||||
/// # Returns
|
||||
/// * `Option<EncoderInfo>` - Best-case compatible encoder.
|
||||
pub fn get_best_compatible_encoder(
|
||||
encoders: &Vec<VideoEncoderInfo>,
|
||||
codec: VideoCodec,
|
||||
encoder_type: EncoderType,
|
||||
) -> Option<VideoEncoderInfo> {
|
||||
let mut best_encoder: Option<VideoEncoderInfo> = None;
|
||||
let mut best_score: i32 = 0;
|
||||
|
||||
// Filter by codec and type first
|
||||
let encoders = get_encoders_by_videocodec(encoders, &codec);
|
||||
let encoders = get_encoders_by_type(&encoders, &encoder_type);
|
||||
|
||||
for encoder in encoders {
|
||||
// Local score
|
||||
let mut score = 0;
|
||||
|
||||
// API score
|
||||
score += match encoder.encoder_api {
|
||||
EncoderAPI::NVENC => 3,
|
||||
EncoderAPI::QSV => 3,
|
||||
EncoderAPI::AMF => 3,
|
||||
EncoderAPI::VAAPI => 2,
|
||||
EncoderAPI::SOFTWARE => 1,
|
||||
EncoderAPI::UNKNOWN => 0,
|
||||
};
|
||||
|
||||
// If software, score also based on name to get most compatible software encoder for low latency
|
||||
if encoder.encoder_type == EncoderType::SOFTWARE {
|
||||
score += match encoder.name.as_str() {
|
||||
"openh264enc" => 2,
|
||||
"x264enc" => 1,
|
||||
"svtav1enc" => 2,
|
||||
"av1enc" => 1,
|
||||
_ => 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Update best encoder based on score
|
||||
if score > best_score {
|
||||
best_encoder = Some(encoder.clone());
|
||||
best_score = score;
|
||||
}
|
||||
}
|
||||
|
||||
best_encoder
|
||||
}
|
||||
233
packages/server/src/gpu.rs
Normal file
233
packages/server/src/gpu.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use std::str;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum GPUVendor {
|
||||
UNKNOWN,
|
||||
INTEL,
|
||||
NVIDIA,
|
||||
AMD,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GPUInfo {
|
||||
vendor: GPUVendor,
|
||||
card_path: String,
|
||||
render_path: String,
|
||||
device_name: String,
|
||||
}
|
||||
|
||||
impl GPUInfo {
|
||||
pub fn vendor(&self) -> &GPUVendor {
|
||||
&self.vendor
|
||||
}
|
||||
|
||||
pub fn vendor_string(&self) -> &str {
|
||||
match self.vendor {
|
||||
GPUVendor::INTEL => "Intel",
|
||||
GPUVendor::NVIDIA => "NVIDIA",
|
||||
GPUVendor::AMD => "AMD",
|
||||
GPUVendor::UNKNOWN => "Unknown",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn card_path(&self) -> &str {
|
||||
&self.card_path
|
||||
}
|
||||
|
||||
pub fn render_path(&self) -> &str {
|
||||
&self.render_path
|
||||
}
|
||||
|
||||
pub fn device_name(&self) -> &str {
|
||||
&self.device_name
|
||||
}
|
||||
}
|
||||
|
||||
fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
|
||||
match vendor_id {
|
||||
"8086" => GPUVendor::INTEL, // Intel
|
||||
"10de" => GPUVendor::NVIDIA, // NVIDIA
|
||||
"1002" => GPUVendor::AMD, // AMD/ATI
|
||||
_ => GPUVendor::UNKNOWN,
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of GPUs available on the system.
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing information about each GPU.
|
||||
pub fn get_gpus() -> Vec<GPUInfo> {
|
||||
let mut gpus = Vec::new();
|
||||
|
||||
// Run lspci to get PCI devices related to GPUs
|
||||
let lspci_output = Command::new("lspci")
|
||||
.arg("-mmnn") // Get machine-readable output with IDs
|
||||
.output()
|
||||
.expect("Failed to execute lspci");
|
||||
|
||||
let output = str::from_utf8(&lspci_output.stdout).unwrap();
|
||||
|
||||
// Filter lines that mention VGA or 3D controller
|
||||
for line in output.lines() {
|
||||
if line.to_lowercase().contains("vga compatible controller")
|
||||
|| line.to_lowercase().contains("3d controller")
|
||||
|| line.to_lowercase().contains("display controller")
|
||||
{
|
||||
if let Some((pci_addr, vendor_id, device_name)) = parse_pci_device(line) {
|
||||
// Run udevadm to get the device path
|
||||
if let Some((card_path, render_path)) = get_dri_device_path(&pci_addr) {
|
||||
let vendor = get_gpu_vendor(&vendor_id);
|
||||
gpus.push(GPUInfo {
|
||||
vendor,
|
||||
card_path,
|
||||
render_path,
|
||||
device_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gpus
|
||||
}
|
||||
|
||||
/// Parses a line from the lspci output to extract PCI device information.
|
||||
/// # Arguments
|
||||
/// * `line` - A string slice that holds a line from the 'lspci -mmnn' output.
|
||||
/// # Returns
|
||||
/// * `Option<(String, String, String)>` - A tuple containing the PCI address, vendor ID, and device name if parsing is successful.
|
||||
fn parse_pci_device(line: &str) -> Option<(String, String, String)> {
|
||||
// Define a regex pattern to match PCI device lines
|
||||
let re = Regex::new(r#"(?P<pci_addr>[0-9a-fA-F]{1,2}:[0-9a-fA-F]{2}\.[0-9]) "[^"]+" "(?P<vendor_name>[^"]+)" "(?P<device_name>[^"]+)"#).unwrap();
|
||||
|
||||
// Collect all matched groups
|
||||
let parts: Vec<(String, String, String)> = re
|
||||
.captures_iter(line)
|
||||
.filter_map(|cap| {
|
||||
Some((
|
||||
cap["pci_addr"].to_string(),
|
||||
cap["vendor_name"].to_string(),
|
||||
cap["device_name"].to_string(),
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some((pci_addr, vendor_name, device_name)) = parts.first() {
|
||||
// Create a mutable copy of the device name to modify
|
||||
let mut device_name = device_name.clone();
|
||||
|
||||
// If more than 1 square-bracketed item is found, remove last one, otherwise remove all
|
||||
if let Some(start) = device_name.rfind('[') {
|
||||
if let Some(_) = device_name.rfind(']') {
|
||||
device_name = device_name[..start].trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract vendor ID from vendor name (e.g., "Intel Corporation [8086]" or "Advanced Micro Devices, Inc. [AMD/ATI] [1002]")
|
||||
let vendor_id = vendor_name
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.unwrap_or_default()
|
||||
.trim_matches(|c: char| !c.is_ascii_hexdigit())
|
||||
.to_string();
|
||||
|
||||
return Some((pci_addr.clone(), vendor_id, device_name));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Retrieves the DRI device paths for a given PCI address.
|
||||
/// Doubles as a way to verify the device is a GPU.
|
||||
/// # Arguments
|
||||
/// * `pci_addr` - A string slice that holds the PCI address.
|
||||
/// # Returns
|
||||
/// * `Option<(String, String)>` - A tuple containing the card path and render path if found.
|
||||
fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
|
||||
// Construct the base directory in /sys/bus/pci/devices to start the search
|
||||
let base_dir = Path::new("/sys/bus/pci/devices");
|
||||
|
||||
// Define the target PCI address with "0000:" prefix
|
||||
let target_addr = format!("0000:{}", pci_addr);
|
||||
|
||||
// Search for a matching directory that contains the target PCI address
|
||||
for entry in fs::read_dir(base_dir).ok()?.flatten() {
|
||||
let path = entry.path();
|
||||
|
||||
// Check if the path matches the target PCI address
|
||||
if path.to_string_lossy().contains(&target_addr) {
|
||||
// Look for any files under the 'drm' subdirectory, like 'card0' or 'renderD128'
|
||||
let drm_path = path.join("drm");
|
||||
if drm_path.exists() {
|
||||
let mut card_path = String::new();
|
||||
let mut render_path = String::new();
|
||||
for drm_entry in fs::read_dir(drm_path).ok()?.flatten() {
|
||||
let file_name = drm_entry.file_name();
|
||||
if let Some(name) = file_name.to_str() {
|
||||
if name.starts_with("card") {
|
||||
card_path = format!("/dev/dri/{}", name);
|
||||
}
|
||||
if name.starts_with("renderD") {
|
||||
render_path = format!("/dev/dri/{}", name);
|
||||
}
|
||||
// If both paths are found, break the loop
|
||||
if !card_path.is_empty() && !render_path.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Some((card_path, render_path));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Helper method to get GPUs from vector by vendor name (case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `vendor` - A string slice that holds the vendor name.
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
|
||||
pub fn get_gpus_by_vendor(gpus: &Vec<GPUInfo>, vendor: &str) -> Vec<GPUInfo> {
|
||||
let vendor = vendor.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.vendor_string().to_lowercase().contains(&vendor))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper method to get GPUs from vector by device name substring (case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `device_name` - A string slice that holds the device name substring.
|
||||
/// # Returns
|
||||
/// * `Vec<GPUInfo>` - A vector containing GPUInfo structs if found.
|
||||
pub fn get_gpus_by_device_name(gpus: &Vec<GPUInfo>, device_name: &str) -> Vec<GPUInfo> {
|
||||
let device_name = device_name.to_lowercase();
|
||||
gpus.iter()
|
||||
.filter(|gpu| gpu.device_name.to_lowercase().contains(&device_name))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper method to get a GPU from vector of GPUInfo by card path (either /dev/dri/cardX or /dev/dri/renderDX, case-insensitive).
|
||||
/// # Arguments
|
||||
/// * `gpus` - A vector containing information about each GPU.
|
||||
/// * `card_path` - A string slice that holds the card path.
|
||||
/// # Returns
|
||||
/// * `Option<GPUInfo>` - A reference to a GPUInfo struct if found.
|
||||
pub fn get_gpu_by_card_path(gpus: &Vec<GPUInfo>, card_path: &str) -> Option<GPUInfo> {
|
||||
for gpu in gpus {
|
||||
if gpu.card_path().to_lowercase() == card_path.to_lowercase()
|
||||
|| gpu.render_path().to_lowercase() == card_path.to_lowercase()
|
||||
{
|
||||
return Some(gpu.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
60
packages/server/src/latency.rs
Normal file
60
packages/server/src/latency.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct TimestampEntry {
|
||||
pub stage: String,
|
||||
pub time: String, // ISO 8601 timestamp
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LatencyTracker {
|
||||
pub sequence_id: String,
|
||||
pub timestamps: Vec<TimestampEntry>,
|
||||
pub metadata: Option<HashMap<String, String>>, // Optional metadata
|
||||
}
|
||||
|
||||
impl LatencyTracker {
|
||||
// Creates a new LatencyTracker
|
||||
pub fn new(sequence_id: String) -> Self {
|
||||
Self {
|
||||
sequence_id,
|
||||
timestamps: Vec::new(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the sequence ID
|
||||
pub fn sequence_id(&self) -> &str {
|
||||
&self.sequence_id
|
||||
}
|
||||
|
||||
// Adds a timestamp for a specific stage
|
||||
pub fn add_timestamp(&mut self, stage: &str) {
|
||||
// Ensure extremely precise UTC format (YYYY-MM-DDTHH:MM:SS.658548387Z)
|
||||
let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Nanos, true);
|
||||
self.timestamps.push(TimestampEntry {
|
||||
stage: stage.to_string(),
|
||||
time: now,
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate total latency (first to last timestamp)
|
||||
pub fn total_latency(&self) -> Option<i64> {
|
||||
if self.timestamps.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let parsed_times: Result<Vec<_>, _> = self
|
||||
.timestamps
|
||||
.iter()
|
||||
.map(|entry| chrono::DateTime::parse_from_rfc3339(&entry.time))
|
||||
.collect();
|
||||
if let Ok(parsed_times) = parsed_times {
|
||||
let min_time = parsed_times.iter().min().unwrap();
|
||||
let max_time = parsed_times.iter().max().unwrap();
|
||||
Some((*max_time - *min_time).num_milliseconds())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
526
packages/server/src/main.rs
Normal file
526
packages/server/src/main.rs
Normal file
@@ -0,0 +1,526 @@
|
||||
mod args;
|
||||
mod enc_helper;
|
||||
mod gpu;
|
||||
mod room;
|
||||
mod websocket;
|
||||
mod latency;
|
||||
mod messages;
|
||||
|
||||
use crate::args::encoding_args;
|
||||
use gst::prelude::*;
|
||||
use gst_app::AppSink;
|
||||
use std::error::Error;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use futures_util::StreamExt;
|
||||
use gst_app::app_sink::AppSinkStream;
|
||||
use tokio::sync::{Mutex};
|
||||
use crate::websocket::{NestriWebSocket};
|
||||
|
||||
// Handles gathering GPU information and selecting the most suitable GPU
|
||||
fn handle_gpus(args: &args::Args) -> Option<gpu::GPUInfo> {
|
||||
println!("Gathering GPU information..");
|
||||
let gpus = gpu::get_gpus();
|
||||
if gpus.is_empty() {
|
||||
println!("No GPUs found");
|
||||
return None;
|
||||
}
|
||||
for gpu in &gpus {
|
||||
println!(
|
||||
"> [GPU] Vendor: '{}', Card Path: '{}', Render Path: '{}', Device Name: '{}'",
|
||||
gpu.vendor_string(),
|
||||
gpu.card_path(),
|
||||
gpu.render_path(),
|
||||
gpu.device_name()
|
||||
);
|
||||
}
|
||||
|
||||
// Based on available arguments, pick a GPU
|
||||
let gpu;
|
||||
if !args.device.gpu_card_path.is_empty() {
|
||||
gpu = gpu::get_gpu_by_card_path(&gpus, &args.device.gpu_card_path);
|
||||
} else {
|
||||
// Run all filters that are not empty
|
||||
let mut filtered_gpus = gpus.clone();
|
||||
if !args.device.gpu_vendor.is_empty() {
|
||||
filtered_gpus = gpu::get_gpus_by_vendor(&filtered_gpus, &args.device.gpu_vendor);
|
||||
}
|
||||
if !args.device.gpu_name.is_empty() {
|
||||
filtered_gpus = gpu::get_gpus_by_device_name(&filtered_gpus, &args.device.gpu_name);
|
||||
}
|
||||
if args.device.gpu_index != 0 {
|
||||
// get single GPU by index
|
||||
gpu = filtered_gpus.get(args.device.gpu_index as usize).cloned();
|
||||
} else {
|
||||
// get first GPU
|
||||
gpu = filtered_gpus.get(0).cloned();
|
||||
}
|
||||
}
|
||||
if gpu.is_none() {
|
||||
println!("No GPU found with the specified parameters: vendor='{}', name='{}', index='{}', card_path='{}'",
|
||||
args.device.gpu_vendor, args.device.gpu_name, args.device.gpu_index, args.device.gpu_card_path);
|
||||
return None;
|
||||
}
|
||||
let gpu = gpu.unwrap();
|
||||
println!("Selected GPU: '{}'", gpu.device_name());
|
||||
Some(gpu)
|
||||
}
|
||||
|
||||
// Handles picking video encoder
|
||||
fn handle_encoder_video(args: &args::Args) -> Option<enc_helper::VideoEncoderInfo> {
|
||||
println!("Getting compatible video encoders..");
|
||||
let video_encoders = enc_helper::get_compatible_encoders();
|
||||
if video_encoders.is_empty() {
|
||||
println!("No compatible video encoders found");
|
||||
return None;
|
||||
}
|
||||
for encoder in &video_encoders {
|
||||
println!(
|
||||
"> [Video Encoder] Name: '{}', Codec: '{}', API: '{}', Type: '{}'",
|
||||
encoder.name,
|
||||
encoder.codec.to_str(),
|
||||
encoder.encoder_api.to_str(),
|
||||
encoder.encoder_type.to_str()
|
||||
);
|
||||
}
|
||||
// Pick most suitable video encoder based on given arguments
|
||||
let video_encoder;
|
||||
if !args.encoding.video.encoder.is_empty() {
|
||||
video_encoder =
|
||||
enc_helper::get_encoder_by_name(&video_encoders, &args.encoding.video.encoder);
|
||||
} else {
|
||||
video_encoder = enc_helper::get_best_compatible_encoder(
|
||||
&video_encoders,
|
||||
enc_helper::VideoCodec::from_str(&args.encoding.video.codec),
|
||||
enc_helper::EncoderType::from_str(&args.encoding.video.encoder_type),
|
||||
);
|
||||
}
|
||||
if video_encoder.is_none() {
|
||||
println!("No video encoder found with the specified parameters: name='{}', vcodec='{}', type='{}'",
|
||||
args.encoding.video.encoder, args.encoding.video.codec, args.encoding.video.encoder_type);
|
||||
return None;
|
||||
}
|
||||
let video_encoder = video_encoder.unwrap();
|
||||
println!("Selected video encoder: '{}'", video_encoder.name);
|
||||
Some(video_encoder)
|
||||
}
|
||||
|
||||
// Handles picking preferred settings for video encoder
|
||||
fn handle_encoder_video_settings(
|
||||
args: &args::Args,
|
||||
video_encoder: &enc_helper::VideoEncoderInfo,
|
||||
) -> enc_helper::VideoEncoderInfo {
|
||||
let mut optimized_encoder = enc_helper::encoder_low_latency_params(&video_encoder);
|
||||
// Handle rate-control method
|
||||
match &args.encoding.video.rate_control {
|
||||
encoding_args::RateControl::CQP(cqp) => {
|
||||
optimized_encoder = enc_helper::encoder_cqp_params(&optimized_encoder, cqp.quality);
|
||||
}
|
||||
encoding_args::RateControl::VBR(vbr) => {
|
||||
optimized_encoder = enc_helper::encoder_vbr_params(
|
||||
&optimized_encoder,
|
||||
vbr.target_bitrate as u32,
|
||||
vbr.max_bitrate as u32,
|
||||
);
|
||||
}
|
||||
encoding_args::RateControl::CBR(cbr) => {
|
||||
optimized_encoder =
|
||||
enc_helper::encoder_cbr_params(&optimized_encoder, cbr.target_bitrate as u32);
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"Selected video encoder settings: '{}'",
|
||||
optimized_encoder.get_parameters_string()
|
||||
);
|
||||
optimized_encoder
|
||||
}
|
||||
|
||||
// Handles picking audio encoder
|
||||
// TODO: Expand enc_helper with audio types, for now just opus
|
||||
fn handle_encoder_audio(args: &args::Args) -> String {
|
||||
let audio_encoder = if args.encoding.audio.encoder.is_empty() {
|
||||
"opusenc".to_string()
|
||||
} else {
|
||||
args.encoding.audio.encoder.clone()
|
||||
};
|
||||
println!("Selected audio encoder: '{}'", audio_encoder);
|
||||
audio_encoder
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
// Parse command line arguments
|
||||
let args = args::Args::new();
|
||||
if args.app.verbose {
|
||||
args.debug_print();
|
||||
}
|
||||
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install ring crypto provider");
|
||||
|
||||
// Begin connection attempt to the relay WebSocket endpoint
|
||||
// replace any http/https with ws/wss
|
||||
let replaced_relay_url
|
||||
= args.app.relay_url.replace("http://", "ws://").replace("https://", "wss://");
|
||||
let ws_url = format!(
|
||||
"{}/api/ws/{}",
|
||||
replaced_relay_url,
|
||||
args.app.room,
|
||||
);
|
||||
|
||||
// Setup our websocket
|
||||
let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
log::set_boxed_logger(Box::new(nestri_ws.clone())).unwrap();
|
||||
|
||||
let _ = gst::init();
|
||||
|
||||
// Handle GPU selection
|
||||
let gpu = handle_gpus(&args);
|
||||
if gpu.is_none() {
|
||||
log::error!("Failed to find a suitable GPU. Exiting..");
|
||||
return Err("Failed to find a suitable GPU. Exiting..".into());
|
||||
}
|
||||
let gpu = gpu.unwrap();
|
||||
|
||||
// Handle video encoder selection
|
||||
let video_encoder_info = handle_encoder_video(&args);
|
||||
if video_encoder_info.is_none() {
|
||||
log::error!("Failed to find a suitable video encoder. Exiting..");
|
||||
return Err("Failed to find a suitable video encoder. Exiting..".into());
|
||||
}
|
||||
let mut video_encoder_info = video_encoder_info.unwrap();
|
||||
// Handle video encoder settings
|
||||
video_encoder_info = handle_encoder_video_settings(&args, &video_encoder_info);
|
||||
|
||||
// Handle audio encoder selection
|
||||
let audio_encoder = handle_encoder_audio(&args);
|
||||
|
||||
/*** ROOM SETUP ***/
|
||||
let room = Arc::new(Mutex::new(
|
||||
room::Room::new(nestri_ws.clone()).await?,
|
||||
));
|
||||
|
||||
/*** PIPELINE CREATION ***/
|
||||
/* Audio */
|
||||
// Audio Source Element
|
||||
let audio_source = match args.encoding.audio.capture_method {
|
||||
encoding_args::AudioCaptureMethod::PulseAudio => {
|
||||
gst::ElementFactory::make("pulsesrc").build()?
|
||||
}
|
||||
encoding_args::AudioCaptureMethod::PipeWire => {
|
||||
gst::ElementFactory::make("pipewiresrc").build()?
|
||||
}
|
||||
_ => gst::ElementFactory::make("alsasrc").build()?,
|
||||
};
|
||||
|
||||
// Audio Converter Element
|
||||
let audio_converter = gst::ElementFactory::make("audioconvert").build()?;
|
||||
|
||||
// Audio Rate Element
|
||||
let audio_rate = gst::ElementFactory::make("audiorate").build()?;
|
||||
|
||||
// Required to fix gstreamer opus issue, where quality sounds off (due to wrong sample rate)
|
||||
let audio_capsfilter = gst::ElementFactory::make("capsfilter").build()?;
|
||||
let audio_caps = gst::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
|
||||
audio_capsfilter.set_property("caps", &audio_caps);
|
||||
|
||||
// Audio Encoder Element
|
||||
let audio_encoder = gst::ElementFactory::make(audio_encoder.as_str()).build()?;
|
||||
audio_encoder.set_property(
|
||||
"bitrate",
|
||||
&match &args.encoding.audio.rate_control {
|
||||
encoding_args::RateControl::CBR(cbr) => cbr.target_bitrate * 1000i32,
|
||||
encoding_args::RateControl::VBR(vbr) => vbr.target_bitrate * 1000i32,
|
||||
_ => 128i32,
|
||||
},
|
||||
);
|
||||
|
||||
// Audio RTP Payloader Element
|
||||
let audio_rtp_payloader = gst::ElementFactory::make("rtpopuspay").build()?;
|
||||
|
||||
/* Video */
|
||||
// Video Source Element
|
||||
let video_source = gst::ElementFactory::make("waylanddisplaysrc").build()?;
|
||||
video_source.set_property("render-node", &gpu.render_path());
|
||||
|
||||
// Caps Filter Element (resolution, fps)
|
||||
let caps_filter = gst::ElementFactory::make("capsfilter").build()?;
|
||||
let caps = gst::Caps::from_str(&format!(
|
||||
"video/x-raw,width={},height={},framerate={}/1,format=RGBx",
|
||||
args.app.resolution.0, args.app.resolution.1, args.app.framerate
|
||||
))?;
|
||||
caps_filter.set_property("caps", &caps);
|
||||
|
||||
// Video Tee Element
|
||||
let video_tee = gst::ElementFactory::make("tee").build()?;
|
||||
|
||||
// Video Converter Element
|
||||
let video_converter = gst::ElementFactory::make("videoconvert").build()?;
|
||||
|
||||
// Video Encoder Element
|
||||
let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
|
||||
video_encoder_info.apply_parameters(&video_encoder, &args.app.verbose);
|
||||
|
||||
// Required for AV1 - av1parse
|
||||
let av1_parse = gst::ElementFactory::make("av1parse").build()?;
|
||||
|
||||
// Video RTP Payloader Element
|
||||
let video_rtp_payloader = gst::ElementFactory::make(
|
||||
format!("rtp{}pay", video_encoder_info.codec.to_gst_str()).as_str(),
|
||||
)
|
||||
.build()?;
|
||||
|
||||
/* Output */
|
||||
// Audio AppSink Element
|
||||
let audio_appsink = gst::ElementFactory::make("appsink").build()?;
|
||||
audio_appsink.set_property("emit-signals", &true);
|
||||
let audio_appsink = audio_appsink.downcast_ref::<AppSink>().unwrap();
|
||||
|
||||
// Video AppSink Element
|
||||
let video_appsink = gst::ElementFactory::make("appsink").build()?;
|
||||
video_appsink.set_property("emit-signals", &true);
|
||||
let video_appsink = video_appsink.downcast_ref::<AppSink>().unwrap();
|
||||
|
||||
/* Debug */
|
||||
// Debug Feed Element
|
||||
let debug_latency = gst::ElementFactory::make("timeoverlay").build()?;
|
||||
debug_latency.set_property_from_str("halignment", &"right");
|
||||
debug_latency.set_property_from_str("valignment", &"bottom");
|
||||
|
||||
// Debug Sink Element
|
||||
let debug_sink = gst::ElementFactory::make("ximagesink").build()?;
|
||||
|
||||
// Debug video converter
|
||||
let debug_video_converter = gst::ElementFactory::make("videoconvert").build()?;
|
||||
|
||||
// Queues with max 2ms latency
|
||||
let debug_queue = gst::ElementFactory::make("queue2").build()?;
|
||||
debug_queue.set_property("max-size-time", &1000000u64);
|
||||
let main_video_queue = gst::ElementFactory::make("queue2").build()?;
|
||||
main_video_queue.set_property("max-size-time", &1000000u64);
|
||||
let main_audio_queue = gst::ElementFactory::make("queue2").build()?;
|
||||
main_audio_queue.set_property("max-size-time", &1000000u64);
|
||||
|
||||
// Create the pipeline
|
||||
let pipeline = gst::Pipeline::new();
|
||||
|
||||
// Add elements to the pipeline
|
||||
pipeline.add_many(&[
|
||||
&video_appsink.upcast_ref(),
|
||||
&video_rtp_payloader,
|
||||
&video_encoder,
|
||||
&video_converter,
|
||||
&video_tee,
|
||||
&caps_filter,
|
||||
&video_source,
|
||||
&audio_appsink.upcast_ref(),
|
||||
&audio_rtp_payloader,
|
||||
&audio_encoder,
|
||||
&audio_capsfilter,
|
||||
&audio_rate,
|
||||
&audio_converter,
|
||||
&audio_source,
|
||||
&main_video_queue,
|
||||
&main_audio_queue,
|
||||
])?;
|
||||
|
||||
// Add debug elements if debug is enabled
|
||||
if args.app.debug_feed {
|
||||
pipeline.add_many(&[&debug_sink, &debug_queue, &debug_video_converter])?;
|
||||
}
|
||||
|
||||
// Add debug latency element if debug latency is enabled
|
||||
if args.app.debug_latency {
|
||||
pipeline.add(&debug_latency)?;
|
||||
}
|
||||
|
||||
// Add AV1 parse element if AV1 is selected
|
||||
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
|
||||
pipeline.add(&av1_parse)?;
|
||||
}
|
||||
|
||||
// Link main audio branch
|
||||
gst::Element::link_many(&[
|
||||
&audio_source,
|
||||
&audio_converter,
|
||||
&audio_rate,
|
||||
&audio_capsfilter,
|
||||
&audio_encoder,
|
||||
&audio_rtp_payloader,
|
||||
&main_audio_queue,
|
||||
&audio_appsink.upcast_ref(),
|
||||
])?;
|
||||
|
||||
// If debug latency, add time overlay before tee
|
||||
if args.app.debug_latency {
|
||||
gst::Element::link_many(&[&video_source, &caps_filter, &debug_latency, &video_tee])?;
|
||||
} else {
|
||||
gst::Element::link_many(&[&video_source, &caps_filter, &video_tee])?;
|
||||
}
|
||||
|
||||
// Link debug branch if debug is enabled
|
||||
if args.app.debug_feed {
|
||||
gst::Element::link_many(&[
|
||||
&video_tee,
|
||||
&debug_video_converter,
|
||||
&debug_queue,
|
||||
&debug_sink,
|
||||
])?;
|
||||
}
|
||||
|
||||
// Link main video branch, if AV1, add av1_parse
|
||||
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
|
||||
gst::Element::link_many(&[
|
||||
&video_tee,
|
||||
&video_converter,
|
||||
&video_encoder,
|
||||
&av1_parse,
|
||||
&video_rtp_payloader,
|
||||
&main_video_queue,
|
||||
&video_appsink.upcast_ref(),
|
||||
])?;
|
||||
} else {
|
||||
gst::Element::link_many(&[
|
||||
&video_tee,
|
||||
&video_converter,
|
||||
&video_encoder,
|
||||
&video_rtp_payloader,
|
||||
&main_video_queue,
|
||||
&video_appsink.upcast_ref(),
|
||||
])?;
|
||||
}
|
||||
|
||||
// Optimize latency of pipeline
|
||||
video_source.set_property("do-timestamp", &true);
|
||||
audio_source.set_property("do-timestamp", &true);
|
||||
pipeline.set_property("latency", &0u64);
|
||||
|
||||
// Wrap the pipeline in Arc<Mutex> to safely share it
|
||||
let pipeline = Arc::new(Mutex::new(pipeline));
|
||||
|
||||
// Run both pipeline and websocket tasks concurrently
|
||||
let result = tokio::try_join!(
|
||||
run_room(
|
||||
room.clone(),
|
||||
"audio/opus",
|
||||
video_encoder_info.codec.to_mime_str(),
|
||||
pipeline.clone(),
|
||||
Arc::new(Mutex::new(audio_appsink.stream())),
|
||||
Arc::new(Mutex::new(video_appsink.stream()))
|
||||
),
|
||||
run_pipeline(pipeline.clone())
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(_) => log::info!("All tasks completed successfully"),
|
||||
Err(e) => {
|
||||
log::error!("Error occurred in one of the tasks: {}", e);
|
||||
return Err("Error occurred in one of the tasks".into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_room(
|
||||
room: Arc<Mutex<room::Room>>,
|
||||
audio_codec: &str,
|
||||
video_codec: &str,
|
||||
pipeline: Arc<Mutex<gst::Pipeline>>,
|
||||
audio_stream: Arc<Mutex<AppSinkStream>>,
|
||||
video_stream: Arc<Mutex<AppSinkStream>>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
// Run loop, with recovery on error
|
||||
loop {
|
||||
let mut room = room.lock().await;
|
||||
tokio::select! {
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
log::info!("Room interrupted via Ctrl+C");
|
||||
return Ok(());
|
||||
}
|
||||
result = room.run(
|
||||
audio_codec,
|
||||
video_codec,
|
||||
pipeline.clone(),
|
||||
audio_stream.clone(),
|
||||
video_stream.clone(),
|
||||
) => {
|
||||
if let Err(e) = result {
|
||||
log::error!("Room error: {}", e);
|
||||
// Sleep for a while before retrying
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_pipeline(
|
||||
pipeline: Arc<Mutex<gst::Pipeline>>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
// Take ownership of the bus without holding the lock
|
||||
let bus = {
|
||||
let pipeline = pipeline.lock().await;
|
||||
pipeline.bus().ok_or("Pipeline has no bus")?
|
||||
};
|
||||
|
||||
{
|
||||
// Temporarily lock the pipeline to change state
|
||||
let pipeline = pipeline.lock().await;
|
||||
if let Err(e) = pipeline.set_state(gst::State::Playing) {
|
||||
log::error!("Failed to start pipeline: {}", e);
|
||||
return Err("Failed to start pipeline".into());
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for EOS or error (don't lock the pipeline indefinitely)
|
||||
tokio::select! {
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
log::info!("Pipeline interrupted via Ctrl+C");
|
||||
}
|
||||
result = listen_for_gst_messages(bus) => {
|
||||
match result {
|
||||
Ok(_) => log::info!("Pipeline finished with EOS"),
|
||||
Err(err) => log::error!("Pipeline error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Temporarily lock the pipeline to reset state
|
||||
let pipeline = pipeline.lock().await;
|
||||
pipeline.set_state(gst::State::Null)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen_for_gst_messages(bus: gst::Bus) -> Result<(), Box<dyn Error>> {
|
||||
let bus_stream = bus.stream();
|
||||
|
||||
tokio::pin!(bus_stream);
|
||||
|
||||
while let Some(msg) = bus_stream.next().await {
|
||||
match msg.view() {
|
||||
gst::MessageView::Eos(_) => {
|
||||
log::info!("Received EOS");
|
||||
break;
|
||||
}
|
||||
gst::MessageView::Error(err) => {
|
||||
let err_msg = format!(
|
||||
"Error from {:?}: {:?}",
|
||||
err.src().map(|s| s.path_string()),
|
||||
err.error()
|
||||
);
|
||||
return Err(err_msg.into());
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
144
packages/server/src/messages.rs
Normal file
144
packages/server/src/messages.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use std::error::Error;
|
||||
use std::io::{Read, Write};
|
||||
use flate2::Compression;
|
||||
use flate2::read::GzDecoder;
|
||||
use flate2::write::GzEncoder;
|
||||
use num_derive::{FromPrimitive, ToPrimitive};
|
||||
use num_traits::{FromPrimitive, ToPrimitive};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
use crate::latency::LatencyTracker;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageBase {
|
||||
pub payload_type: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageInput {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub data: String,
|
||||
pub latency: Option<LatencyTracker>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageLog {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub level: String,
|
||||
pub message: String,
|
||||
pub time: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageMetrics {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub usage_cpu: f64,
|
||||
pub usage_memory: f64,
|
||||
pub uptime: u64,
|
||||
pub pipeline_latency: f64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageICE {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub candidate: RTCIceCandidateInit,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageSDP {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub sdp: RTCSessionDescription,
|
||||
}
|
||||
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "i32", into = "i32")]
|
||||
pub enum JoinerType {
|
||||
JoinerNode = 0,
|
||||
JoinerClient = 1,
|
||||
}
|
||||
impl TryFrom<i32> for JoinerType {
|
||||
type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
JoinerType::from_i32(value).ok_or("Invalid value for JoinerType")
|
||||
}
|
||||
}
|
||||
impl From<JoinerType> for i32 {
|
||||
fn from(joiner_type: JoinerType) -> Self {
|
||||
joiner_type.to_i32().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageJoin {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub joiner_type: JoinerType,
|
||||
}
|
||||
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "i32", into = "i32")]
|
||||
pub enum AnswerType {
|
||||
AnswerOffline = 0,
|
||||
AnswerInUse = 1,
|
||||
AnswerOK = 2,
|
||||
}
|
||||
impl TryFrom<i32> for AnswerType {
|
||||
type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
AnswerType::from_i32(value).ok_or("Invalid value for AnswerType")
|
||||
}
|
||||
}
|
||||
impl From<AnswerType> for i32 {
|
||||
fn from(answer_type: AnswerType) -> Self {
|
||||
answer_type.to_i32().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageAnswer {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub answer_type: AnswerType,
|
||||
}
|
||||
|
||||
pub fn encode_message<T: Serialize>(message: &T) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
// Serialize the message to JSON
|
||||
let json = serde_json::to_string(message)?;
|
||||
|
||||
// Compress the JSON using gzip
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(json.as_bytes())?;
|
||||
let compressed_data = encoder.finish()?;
|
||||
|
||||
Ok(compressed_data)
|
||||
}
|
||||
|
||||
pub fn decode_message(data: &[u8]) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
|
||||
let mut decoder = GzDecoder::new(data);
|
||||
let mut decompressed_data = String::new();
|
||||
decoder.read_to_string(&mut decompressed_data)?;
|
||||
|
||||
let base_message: MessageBase = serde_json::from_str(&decompressed_data)?;
|
||||
Ok(base_message)
|
||||
}
|
||||
|
||||
pub fn decode_message_as<T: for<'de> Deserialize<'de>>(
|
||||
data: Vec<u8>,
|
||||
) -> Result<T, Box<dyn Error + Send + Sync>> {
|
||||
let mut decoder = GzDecoder::new(data.as_slice());
|
||||
let mut decompressed_data = String::new();
|
||||
decoder.read_to_string(&mut decompressed_data)?;
|
||||
|
||||
let message: T = serde_json::from_str(&decompressed_data)?;
|
||||
Ok(message)
|
||||
}
|
||||
540
packages/server/src/room.rs
Normal file
540
packages/server/src/room.rs
Normal file
@@ -0,0 +1,540 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use std::collections::{HashSet};
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use futures_util::StreamExt;
|
||||
use gst::prelude::ElementExtManual;
|
||||
use gst_app::app_sink::AppSinkStream;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tokio::sync::{mpsc};
|
||||
use webrtc::api::interceptor_registry::register_default_interceptors;
|
||||
use webrtc::api::media_engine::MediaEngine;
|
||||
use webrtc::api::APIBuilder;
|
||||
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
|
||||
use webrtc::data_channel::data_channel_message::DataChannelMessage;
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
use webrtc::ice_transport::ice_gathering_state::RTCIceGatheringState;
|
||||
use webrtc::ice_transport::ice_server::RTCIceServer;
|
||||
use webrtc::interceptor::registry::Registry;
|
||||
use webrtc::peer_connection::configuration::RTCConfiguration;
|
||||
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
|
||||
use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
|
||||
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
|
||||
use webrtc::track::track_local::TrackLocalWriter;
|
||||
use crate::messages::*;
|
||||
use crate::websocket::NestriWebSocket;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(tag = "type")]
|
||||
enum InputMessage {
|
||||
#[serde(rename = "mousemove")]
|
||||
MouseMove { x: i32, y: i32 },
|
||||
|
||||
#[serde(rename = "mousemoveabs")]
|
||||
MouseMoveAbs { x: i32, y: i32 },
|
||||
|
||||
#[serde(rename = "wheel")]
|
||||
Wheel { x: f64, y: f64 },
|
||||
|
||||
#[serde(rename = "mousedown")]
|
||||
MouseDown { key: i32 },
|
||||
// Add other variants as needed
|
||||
#[serde(rename = "mouseup")]
|
||||
MouseUp { key: i32 },
|
||||
|
||||
#[serde(rename = "keydown")]
|
||||
KeyDown { key: i32 },
|
||||
|
||||
#[serde(rename = "keyup")]
|
||||
KeyUp { key: i32 },
|
||||
}
|
||||
|
||||
pub struct Room {
|
||||
nestri_ws: Arc<NestriWebSocket>,
|
||||
webrtc_api: webrtc::api::API,
|
||||
webrtc_config: RTCConfiguration,
|
||||
}
|
||||
|
||||
impl Room {
|
||||
pub async fn new(
|
||||
nestri_ws: Arc<NestriWebSocket>,
|
||||
) -> Result<Room, Box<dyn Error>> {
|
||||
// Create media engine and register default codecs
|
||||
let mut media_engine = MediaEngine::default();
|
||||
media_engine.register_default_codecs()?;
|
||||
|
||||
// Registry
|
||||
let mut registry = Registry::new();
|
||||
registry = register_default_interceptors(registry, &mut media_engine)?;
|
||||
|
||||
// Create the API object with the MediaEngine
|
||||
let api = APIBuilder::new()
|
||||
.with_media_engine(media_engine)
|
||||
.with_interceptor_registry(registry)
|
||||
.build();
|
||||
|
||||
// Prepare the configuration
|
||||
let config = RTCConfiguration {
|
||||
ice_servers: vec![RTCIceServer {
|
||||
urls: vec!["stun:stun.l.google.com:19302".to_owned()],
|
||||
..Default::default()
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
nestri_ws,
|
||||
webrtc_api: api,
|
||||
webrtc_config: config,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
audio_codec: &str,
|
||||
video_codec: &str,
|
||||
pipeline: Arc<Mutex<gst::Pipeline>>,
|
||||
audio_sink: Arc<Mutex<AppSinkStream>>,
|
||||
video_sink: Arc<Mutex<AppSinkStream>>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
|
||||
self.nestri_ws
|
||||
.register_callback("answer", {
|
||||
let tx = tx.clone();
|
||||
move |data| {
|
||||
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
|
||||
log::info!("Received answer: {:?}", answer);
|
||||
match answer.answer_type {
|
||||
AnswerType::AnswerOffline => {
|
||||
log::warn!("Room is offline, we shouldn't be receiving this");
|
||||
}
|
||||
AnswerType::AnswerInUse => {
|
||||
log::error!("Room is in use by another node!");
|
||||
}
|
||||
AnswerType::AnswerOK => {
|
||||
// Notify that we got an OK answer
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
if let Err(_) = tx.send(()) {
|
||||
log::error!("Failed to send OK answer signal");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("Failed to decode answer");
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
// Send a request to join the room
|
||||
let join_msg = MessageJoin {
|
||||
base: MessageBase {
|
||||
payload_type: "join".to_string(),
|
||||
},
|
||||
joiner_type: JoinerType::JoinerNode,
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&join_msg) {
|
||||
self.nestri_ws.send_message(encoded).await?;
|
||||
} else {
|
||||
log::error!("Failed to encode join message");
|
||||
return Err("Failed to encode join message".into());
|
||||
}
|
||||
|
||||
// Wait for the signal indicating that we have received an OK answer
|
||||
match rx.await {
|
||||
Ok(()) => {
|
||||
log::info!("Received OK answer, proceeding...");
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Oneshot channel closed unexpectedly");
|
||||
return Err("Unexpected error while waiting for OK answer".into());
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new RTCPeerConnection
|
||||
let config = self.webrtc_config.clone();
|
||||
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
|
||||
|
||||
// Create audio track
|
||||
let audio_track = Arc::new(TrackLocalStaticRTP::new(
|
||||
RTCRtpCodecCapability {
|
||||
mime_type: audio_codec.to_owned(),
|
||||
..Default::default()
|
||||
},
|
||||
"audio".to_owned(),
|
||||
"audio-nestri-server".to_owned(),
|
||||
));
|
||||
|
||||
// Create video track
|
||||
let video_track = Arc::new(TrackLocalStaticRTP::new(
|
||||
RTCRtpCodecCapability {
|
||||
mime_type: video_codec.to_owned(),
|
||||
..Default::default()
|
||||
},
|
||||
"video".to_owned(),
|
||||
"video-nestri-server".to_owned(),
|
||||
));
|
||||
|
||||
// Cancellation token to stop spawned tasks after peer connection is closed
|
||||
let cancel_token = tokio_util::sync::CancellationToken::new();
|
||||
|
||||
// Add audio track to peer connection
|
||||
let audio_sender = peer_connection.add_track(audio_track.clone()).await?;
|
||||
let audio_sender_token = cancel_token.child_token();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let mut rtcp_buf = vec![0u8; 1500];
|
||||
tokio::select! {
|
||||
_ = audio_sender_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = audio_sender.read(&mut rtcp_buf) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add video track to peer connection
|
||||
let video_sender = peer_connection.add_track(video_track.clone()).await?;
|
||||
let video_sender_token = cancel_token.child_token();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let mut rtcp_buf = vec![0u8; 1500];
|
||||
tokio::select! {
|
||||
_ = video_sender_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = video_sender.read(&mut rtcp_buf) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Create a datachannel with label 'input'
|
||||
let data_channel_opts = Some(RTCDataChannelInit {
|
||||
ordered: Some(false),
|
||||
max_retransmits: Some(0),
|
||||
..Default::default()
|
||||
});
|
||||
let data_channel
|
||||
= peer_connection.create_data_channel("input", data_channel_opts).await?;
|
||||
|
||||
// PeerConnection state change tracker
|
||||
let (pc_sndr, mut pc_recv) = mpsc::channel(1);
|
||||
|
||||
// Peer connection state change handler
|
||||
peer_connection.on_peer_connection_state_change(Box::new(
|
||||
move |s: RTCPeerConnectionState| {
|
||||
let pc_sndr = pc_sndr.clone();
|
||||
Box::pin(async move {
|
||||
log::info!("PeerConnection State has changed: {s}");
|
||||
|
||||
if s == RTCPeerConnectionState::Failed
|
||||
|| s == RTCPeerConnectionState::Disconnected
|
||||
|| s == RTCPeerConnectionState::Closed
|
||||
{
|
||||
// Notify pc_state that the peer connection has closed
|
||||
if let Err(e) = pc_sndr.send(s).await {
|
||||
log::error!("Failed to send PeerConnection state: {}", e);
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
));
|
||||
|
||||
peer_connection.on_ice_gathering_state_change(Box::new(move |s| {
|
||||
Box::pin(async move {
|
||||
log::info!("ICE Gathering State has changed: {s}");
|
||||
})
|
||||
}));
|
||||
|
||||
peer_connection.on_ice_connection_state_change(Box::new(move |s| {
|
||||
Box::pin(async move {
|
||||
log::info!("ICE Connection State has changed: {s}");
|
||||
})
|
||||
}));
|
||||
|
||||
// Trickle ICE over WebSocket
|
||||
let ws = self.nestri_ws.clone();
|
||||
peer_connection.on_ice_candidate(Box::new(move |c| {
|
||||
let nestri_ws = ws.clone();
|
||||
Box::pin(async move {
|
||||
if let Some(candidate) = c {
|
||||
let candidate_json = candidate.to_json().unwrap();
|
||||
let ice_msg = MessageICE {
|
||||
base: MessageBase {
|
||||
payload_type: "ice".to_string(),
|
||||
},
|
||||
candidate: candidate_json,
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&ice_msg) {
|
||||
let _ = nestri_ws.send_message(encoded);
|
||||
}
|
||||
}
|
||||
})
|
||||
}));
|
||||
|
||||
// Temporary ICE candidate buffer until remote description is set
|
||||
let ice_holder: Arc<Mutex<Vec<RTCIceCandidateInit>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
// Register set_response_callback for ICE candidate
|
||||
let pc = peer_connection.clone();
|
||||
let ice_clone = ice_holder.clone();
|
||||
self.nestri_ws.register_callback("ice", move |data| {
|
||||
match decode_message_as::<MessageICE>(data) {
|
||||
Ok(message) => {
|
||||
log::info!("Received ICE message");
|
||||
let candidate = RTCIceCandidateInit::from(message.candidate);
|
||||
let pc = pc.clone();
|
||||
let ice_clone = ice_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
// If remote description is not set, buffer ICE candidates
|
||||
if pc.remote_description().await.is_none() {
|
||||
let mut ice_holder = ice_clone.lock().await;
|
||||
ice_holder.push(candidate);
|
||||
} else {
|
||||
if let Err(e) = pc.add_ice_candidate(candidate).await {
|
||||
log::error!("Failed to add ICE candidate: {}", e);
|
||||
} else {
|
||||
// Add any held ICE candidates
|
||||
let mut ice_holder = ice_clone.lock().await;
|
||||
for candidate in ice_holder.drain(..) {
|
||||
if let Err(e) = pc.add_ice_candidate(candidate).await {
|
||||
log::error!("Failed to add ICE candidate: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
|
||||
}
|
||||
}).await;
|
||||
|
||||
// A shared state to track currently pressed keys
|
||||
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
|
||||
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
// Data channel message handler
|
||||
data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
|
||||
let pipeline = pipeline.clone();
|
||||
let pressed_keys = pressed_keys.clone();
|
||||
let pressed_buttons = pressed_buttons.clone();
|
||||
Box::pin({
|
||||
async move {
|
||||
// We don't care about string messages for now
|
||||
if !msg.is_string {
|
||||
// Decode the message as an MessageInput (binary encoded gzip)
|
||||
match decode_message_as::<MessageInput>(msg.data.to_vec()) {
|
||||
Ok(message_input) => {
|
||||
// Handle the input message
|
||||
if let Ok(input_msg) = from_str::<InputMessage>(&message_input.data) {
|
||||
if let Some(event) =
|
||||
handle_input_message(input_msg, &pressed_keys, &pressed_buttons).await
|
||||
{
|
||||
let _ = pipeline.lock().await.send_event(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to decode input message: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}));
|
||||
|
||||
log::info!("Creating offer...");
|
||||
|
||||
// Create an offer to send to the browser
|
||||
let offer = peer_connection.create_offer(None).await?;
|
||||
|
||||
log::info!("Setting local description...");
|
||||
|
||||
// Sets the LocalDescription, and starts our UDP listeners
|
||||
peer_connection.set_local_description(offer).await?;
|
||||
|
||||
log::info!("Local description set...");
|
||||
|
||||
if let Some(local_description) = peer_connection.local_description().await {
|
||||
// Wait until we have gathered all ICE candidates
|
||||
while peer_connection.ice_gathering_state() != RTCIceGatheringState::Complete {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
// Register set_response_callback for SDP answer
|
||||
let pc = peer_connection.clone();
|
||||
self.nestri_ws.register_callback("sdp", move |data| {
|
||||
match decode_message_as::<MessageSDP>(data) {
|
||||
Ok(message) => {
|
||||
log::info!("Received SDP message");
|
||||
let sdp = message.sdp;
|
||||
let pc = pc.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = pc.set_remote_description(sdp).await {
|
||||
log::error!("Failed to set remote description: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
|
||||
}
|
||||
}).await;
|
||||
|
||||
log::info!("Sending local description to remote...");
|
||||
// Encode and send the local description via WebSocket
|
||||
let sdp_msg = MessageSDP {
|
||||
base: MessageBase {
|
||||
payload_type: "sdp".to_string(),
|
||||
},
|
||||
sdp: local_description,
|
||||
};
|
||||
let encoded = encode_message(&sdp_msg)?;
|
||||
self.nestri_ws.send_message(encoded).await?;
|
||||
} else {
|
||||
log::error!("generate local_description failed!");
|
||||
cancel_token.cancel();
|
||||
return Err("generate local_description failed!".into());
|
||||
};
|
||||
|
||||
// Send video and audio data
|
||||
let audio_track = audio_track.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut audio_sink = audio_sink.lock().await;
|
||||
while let Some(sample) = audio_sink.next().await {
|
||||
if let Some(buffer) = sample.buffer() {
|
||||
if let Ok(map) = buffer.map_readable() {
|
||||
if let Err(e) = audio_track.write(map.as_slice()).await {
|
||||
if webrtc::Error::ErrClosedPipe == e {
|
||||
break;
|
||||
} else {
|
||||
log::error!("Failed to write audio track: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let video_track = video_track.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut video_sink = video_sink.lock().await;
|
||||
while let Some(sample) = video_sink.next().await {
|
||||
if let Some(buffer) = sample.buffer() {
|
||||
if let Ok(map) = buffer.map_readable() {
|
||||
if let Err(e) = video_track.write(map.as_slice()).await {
|
||||
if webrtc::Error::ErrClosedPipe == e {
|
||||
break;
|
||||
} else {
|
||||
log::error!("Failed to write video track: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Block until closed or error
|
||||
tokio::select! {
|
||||
_ = pc_recv.recv() => {
|
||||
log::info!("Peer connection closed with state: {:?}", peer_connection.connection_state());
|
||||
}
|
||||
}
|
||||
|
||||
cancel_token.cancel();
|
||||
|
||||
// Make double-sure to close the peer connection
|
||||
if let Err(e) = peer_connection.close().await {
|
||||
log::error!("Failed to close peer connection: {}", e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_input_message(
|
||||
input_msg: InputMessage,
|
||||
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
|
||||
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
|
||||
) -> Option<gst::Event> {
|
||||
match input_msg {
|
||||
InputMessage::MouseMove { x, y } => {
|
||||
let structure = gst::Structure::builder("MouseMoveRelative")
|
||||
.field("pointer_x", x as f64)
|
||||
.field("pointer_y", y as f64)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::MouseMoveAbs { x, y } => {
|
||||
let structure = gst::Structure::builder("MouseMoveAbsolute")
|
||||
.field("pointer_x", x as f64)
|
||||
.field("pointer_y", y as f64)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::KeyDown { key } => {
|
||||
let mut keys = pressed_keys.lock().await;
|
||||
// If the key is already pressed, return to prevent key lockup
|
||||
if keys.contains(&key) {
|
||||
return None;
|
||||
}
|
||||
keys.insert(key);
|
||||
|
||||
let structure = gst::Structure::builder("KeyboardKey")
|
||||
.field("key", key as u32)
|
||||
.field("pressed", true)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::KeyUp { key } => {
|
||||
let mut keys = pressed_keys.lock().await;
|
||||
// Remove the key from the pressed state when released
|
||||
keys.remove(&key);
|
||||
|
||||
let structure = gst::Structure::builder("KeyboardKey")
|
||||
.field("key", key as u32)
|
||||
.field("pressed", false)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::Wheel { x, y } => {
|
||||
let structure = gst::Structure::builder("MouseAxis")
|
||||
.field("x", x as f64)
|
||||
.field("y", y as f64)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::MouseDown { key } => {
|
||||
let mut buttons = pressed_buttons.lock().await;
|
||||
// If the button is already pressed, return to prevent button lockup
|
||||
if buttons.contains(&key) {
|
||||
return None;
|
||||
}
|
||||
buttons.insert(key);
|
||||
|
||||
let structure = gst::Structure::builder("MouseButton")
|
||||
.field("button", key as u32)
|
||||
.field("pressed", true)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
InputMessage::MouseUp { key } => {
|
||||
let mut buttons = pressed_buttons.lock().await;
|
||||
// Remove the button from the pressed state when released
|
||||
buttons.remove(&key);
|
||||
|
||||
let structure = gst::Structure::builder("MouseButton")
|
||||
.field("button", key as u32)
|
||||
.field("pressed", false)
|
||||
.build();
|
||||
|
||||
Some(gst::event::CustomUpstream::new(structure))
|
||||
}
|
||||
}
|
||||
}
|
||||
174
packages/server/src/websocket.rs
Normal file
174
packages/server/src/websocket.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use futures_util::sink::SinkExt;
|
||||
use futures_util::stream::{SplitSink, SplitStream};
|
||||
use futures_util::StreamExt;
|
||||
use log::{Level, Log, Metadata, Record};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::sleep;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
|
||||
|
||||
type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>;
|
||||
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
||||
type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NestriWebSocket {
|
||||
ws_url: String,
|
||||
reader: Arc<Mutex<WSRead>>,
|
||||
writer: Arc<Mutex<WSWrite>>,
|
||||
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
|
||||
}
|
||||
impl NestriWebSocket {
|
||||
pub async fn new(
|
||||
ws_url: String,
|
||||
) -> Result<NestriWebSocket, Box<dyn Error>> {
|
||||
// Attempt to connect to the WebSocket
|
||||
let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap();
|
||||
|
||||
// If the connection is successful, split the stream
|
||||
let (write, read) = ws_stream.split();
|
||||
let mut ws = NestriWebSocket {
|
||||
ws_url,
|
||||
reader: Arc::new(Mutex::new(read)),
|
||||
writer: Arc::new(Mutex::new(write)),
|
||||
callbacks: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
|
||||
// Spawn the read loop
|
||||
ws.spawn_read_loop();
|
||||
|
||||
Ok(ws)
|
||||
}
|
||||
|
||||
async fn do_connect(
|
||||
ws_url: &str,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Box<dyn Error + Send + Sync>> {
|
||||
loop {
|
||||
match connect_async(ws_url).await {
|
||||
Ok((ws_stream, _)) => {
|
||||
return Ok(ws_stream);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to connect to WebSocket, retrying: {:?}", e);
|
||||
sleep(Duration::from_secs(1)).await; // Wait before retrying
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handles message -> callback calls and reconnects on error/disconnect
|
||||
fn spawn_read_loop(&mut self) {
|
||||
let reader = self.reader.clone();
|
||||
let callbacks = self.callbacks.clone();
|
||||
|
||||
let mut self_clone = self.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let message = reader.lock().await.next().await;
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
let data = message.into_data();
|
||||
let base_message = match decode_message(&data) {
|
||||
Ok(base_message) => base_message,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to decode message: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let callbacks_lock = callbacks.read().await;
|
||||
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
|
||||
let data = data.clone();
|
||||
callback(data);
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
eprintln!("Error receiving message: {:?}", e);
|
||||
self_clone.reconnect().await.unwrap();
|
||||
}
|
||||
None => {
|
||||
eprintln!("WebSocket connection closed, reconnecting...");
|
||||
self_clone.reconnect().await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn reconnect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
// Keep trying to reconnect until successful
|
||||
loop {
|
||||
match NestriWebSocket::do_connect(&self.ws_url).await {
|
||||
Ok(ws_stream) => {
|
||||
let (write, read) = ws_stream.split();
|
||||
*self.reader.lock().await = read;
|
||||
*self.writer.lock().await = write;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to reconnect to WebSocket: {:?}", e);
|
||||
sleep(Duration::from_secs(2)).await; // Wait before retrying
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
|
||||
let mut writer_lock = self.writer.lock().await;
|
||||
writer_lock
|
||||
.send(Message::Binary(message))
|
||||
.await
|
||||
.map_err(|e| format!("Error sending message: {:?}", e).into())
|
||||
}
|
||||
|
||||
pub async fn register_callback<F>(&self, response_type: &str, callback: F)
|
||||
where
|
||||
F: Fn(Vec<u8>) + Send + Sync + 'static,
|
||||
{
|
||||
let mut callbacks_lock = self.callbacks.write().await;
|
||||
callbacks_lock.insert(response_type.to_string(), Box::new(callback));
|
||||
}
|
||||
}
|
||||
impl Log for NestriWebSocket {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
// Filter logs by level
|
||||
metadata.level() <= Level::Info
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if self.enabled(record.metadata()) {
|
||||
let level = record.level().to_string();
|
||||
let message = record.args().to_string();
|
||||
let time = chrono::Local::now().to_rfc3339();
|
||||
|
||||
// Print to console as well
|
||||
println!("{}: {}", level, message);
|
||||
|
||||
// Encode and send the log message
|
||||
let log_message = MessageLog {
|
||||
base: MessageBase {
|
||||
payload_type: "log".to_string(),
|
||||
},
|
||||
level,
|
||||
message,
|
||||
time,
|
||||
};
|
||||
if let Ok(encoded_message) = encode_message(&log_message) {
|
||||
let _ = self.send_message(encoded_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
// No-op for this logger
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user