feat: Controller support, performance enchancements, multi-stage images, fixes (#304)

## Description
Oops.. another massive PR 🥲 

This PR contains multiple improvements and changes.

Firstly, thanks gst-wayland-display's PR
[here](https://github.com/games-on-whales/gst-wayland-display/pull/20).
NVIDIA path is now way more efficient than before.

Secondly, adding controller support was a massive hurdle, requiring me
to start another project
[vimputti](https://github.com/DatCaptainHorse/vimputti) - which allows
simple virtual controller inputs in isolated containers. Well, it's not
simple, it includes LD_PRELOAD shims and other craziness, but the
library API is simple to use..

Thirdly, split runner image into 3 separate stages, base + build +
runtime, should help keep things in check in future, also added GitHub
Actions CI builds for v2 to v4 builds (hopefully they pass..).

Fourth, replaced the runner's runtime Steam patching with better and
simpler bubblewrap patch, massive thanks to `games-on-whales` to
figuring it out better!

Fifth, relay for once needed some changes, the new changes are still
mostly WIP, but I'll deal with them next time I have energy.. I'm spent
now. Needed to include these changes as relay needed a minor change to
allow rumble events to flow back to client peer.

Sixth.. tons of package updates, minor code improvements and the usual. 

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* End-to-end gamepad/controller support (attach/detach, buttons, sticks,
triggers, rumble) with client/server integration and virtual controller
plumbing.
  * Optional Prometheus metrics endpoint and WebTransport support.
  * Background vimputti manager process added for controller handling.

* **Improvements**
  * Multi-variant container image builds and streamlined runtime images.
  * Zero-copy video pipeline and encoder improvements for lower latency.
  * Updated Steam compat mapping and dependency/toolchain refreshes.

* **Bug Fixes**
* More robust GPU detection, input/fullscreen lifecycle,
startup/entrypoint, and container runtime fixes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Kristian Ollikainen
2025-10-20 11:20:05 +03:00
committed by GitHub
parent a3ee9aadd9
commit c62a22b552
62 changed files with 4203 additions and 2278 deletions

View File

@@ -58,6 +58,14 @@ impl Args {
.env("NESTRI_ROOM")
.help("Nestri room name/identifier"),
)
.arg(
Arg::new("vimputti-path")
.long("vimputti-path")
.env("VIMPUTTI_PATH")
.help("Path to vimputti socket")
.value_parser(NonEmptyStringValueParser::new())
.default_value("/tmp/vimputti-0"),
)
.arg(
Arg::new("gpu-vendor")
.short('g')
@@ -204,10 +212,10 @@ impl Args {
.default_value("192"),
)
.arg(
Arg::new("dma-buf")
.long("dma-buf")
.env("DMA_BUF")
.help("Use DMA-BUF for pipeline")
Arg::new("zero-copy")
.long("zero-copy")
.env("ZERO_COPY")
.help("Use zero-copy pipeline")
.value_parser(BoolishValueParser::new())
.default_value("false"),
)

View File

@@ -12,9 +12,12 @@ pub struct AppArgs {
/// Nestri room name/identifier
pub room: String,
/// Experimental DMA-BUF support
/// vimputti socket path
pub vimputti_path: Option<String>,
/// Experimental zero-copy pipeline support
/// TODO: Move to video encoding flags
pub dma_buf: bool,
pub zero_copy: bool,
}
impl AppArgs {
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
@@ -45,7 +48,13 @@ impl AppArgs {
.get_one::<String>("room")
.unwrap_or(&rand::random::<u32>().to_string())
.clone(),
dma_buf: matches.get_one::<bool>("dma-buf").unwrap_or(&false).clone(),
vimputti_path: matches
.get_one::<String>("vimputti-path")
.map(|s| s.clone()),
zero_copy: matches
.get_one::<bool>("zero-copy")
.unwrap_or(&false)
.clone(),
}
}
@@ -60,6 +69,10 @@ impl AppArgs {
tracing::info!("> framerate: {}", self.framerate);
tracing::info!("> relay_url: '{}'", self.relay_url);
tracing::info!("> room: '{}'", self.room);
tracing::info!("> dma_buf: {}", self.dma_buf);
tracing::info!(
"> vimputti_path: '{}'",
self.vimputti_path.as_ref().map_or("None", |s| s.as_str())
);
tracing::info!("> zero_copy: {}", self.zero_copy);
}
}

View File

@@ -11,18 +11,10 @@ pub struct DeviceArgs {
impl DeviceArgs {
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
Self {
gpu_vendor: matches
.get_one::<String>("gpu-vendor")
.cloned(),
gpu_name: matches
.get_one::<String>("gpu-name")
.cloned(),
gpu_index: matches
.get_one::<u32>("gpu-index")
.cloned(),
gpu_card_path: matches
.get_one::<String>("gpu-card-path")
.cloned(),
gpu_vendor: matches.get_one::<String>("gpu-vendor").cloned(),
gpu_name: matches.get_one::<String>("gpu-name").cloned(),
gpu_index: matches.get_one::<u32>("gpu-index").cloned(),
gpu_card_path: matches.get_one::<String>("gpu-card-path").cloned(),
}
}

View File

@@ -276,8 +276,8 @@ pub fn encoder_low_latency_params(
_rate_control: &RateControl,
framerate: u32,
) -> VideoEncoderInfo {
// 2 second GOP size, maybe lower to 1 second for fast recovery, if needed?
let mut encoder_optz = encoder_gop_params(encoder, framerate * 2);
// 1 second keyframe interval for fast recovery, is this too taxing?
let mut encoder_optz = encoder_gop_params(encoder, framerate);
match encoder_optz.encoder_api {
EncoderAPI::QSV => {
@@ -291,6 +291,7 @@ pub fn encoder_low_latency_params(
encoder_optz.set_parameter("multi-pass", "disabled");
encoder_optz.set_parameter("preset", "p1");
encoder_optz.set_parameter("tune", "ultra-low-latency");
encoder_optz.set_parameter("zerolatency", "true");
}
EncoderAPI::AMF => {
encoder_optz.set_parameter("preset", "speed");
@@ -400,11 +401,21 @@ pub fn get_compatible_encoders(gpus: &Vec<GPUInfo>) -> Vec<VideoEncoderInfo> {
}
None
} else if element.has_property("cuda-device-id") {
let device_id =
match element.property_value("cuda-device-id").get::<i32>() {
Ok(v) if v >= 0 => Some(v as usize),
_ => None,
};
let device_id = match element
.property_value("cuda-device-id")
.get::<i32>()
{
Ok(v) if v >= 0 => Some(v as usize),
_ => {
// If only one NVIDIA GPU, default to 0
// fixes "Type: 'Hardware', Device: 'CPU'" issue
if get_gpus_by_vendor(&gpus, GPUVendor::NVIDIA).len() == 1 {
Some(0)
} else {
None
}
}
};
// We'll just treat cuda-device-id as an index
device_id.and_then(|id| {
@@ -574,7 +585,7 @@ pub fn get_best_working_encoder(
encoders: &Vec<VideoEncoderInfo>,
codec: &Codec,
encoder_type: &EncoderType,
dma_buf: bool,
zero_copy: bool,
) -> Result<VideoEncoderInfo, Box<dyn Error>> {
let mut candidates = get_encoders_by_videocodec(
encoders,
@@ -590,7 +601,7 @@ pub fn get_best_working_encoder(
while !candidates.is_empty() {
let best = get_best_compatible_encoder(&candidates, codec, encoder_type)?;
tracing::info!("Testing encoder: {}", best.name,);
if test_encoder(&best, dma_buf).is_ok() {
if test_encoder(&best, zero_copy).is_ok() {
return Ok(best);
} else {
// Remove this encoder and try next best
@@ -602,7 +613,7 @@ pub fn get_best_working_encoder(
}
/// Test if a pipeline with the given encoder can be created and set to Playing
pub fn test_encoder(encoder: &VideoEncoderInfo, dma_buf: bool) -> Result<(), Box<dyn Error>> {
pub fn test_encoder(encoder: &VideoEncoderInfo, zero_copy: bool) -> Result<(), Box<dyn Error>> {
let src = gstreamer::ElementFactory::make("waylanddisplaysrc").build()?;
if let Some(gpu_info) = &encoder.gpu_info {
src.set_property_from_str("render-node", gpu_info.render_path());
@@ -610,12 +621,16 @@ pub fn test_encoder(encoder: &VideoEncoderInfo, dma_buf: bool) -> Result<(), Box
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let caps = gstreamer::Caps::from_str(&format!(
"{},width=1280,height=720,framerate=30/1{}",
if dma_buf {
"video/x-raw(memory:DMABuf)"
if zero_copy {
if encoder.encoder_api == EncoderAPI::NVENC {
"video/x-raw(memory:CUDAMemory)"
} else {
"video/x-raw(memory:DMABuf)"
}
} else {
"video/x-raw"
},
if dma_buf { "" } else { ",format=RGBx" }
if zero_copy { "" } else { ",format=RGBx" }
))?;
caps_filter.set_property("caps", &caps);
@@ -627,66 +642,47 @@ pub fn test_encoder(encoder: &VideoEncoderInfo, dma_buf: bool) -> Result<(), Box
// Create pipeline and link elements
let pipeline = gstreamer::Pipeline::new();
if dma_buf && encoder.encoder_api == EncoderAPI::NVENC {
// GL upload element
let glupload = gstreamer::ElementFactory::make("glupload").build()?;
// GL color convert element
let glconvert = gstreamer::ElementFactory::make("glcolorconvert").build()?;
// GL color convert caps
let gl_caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let gl_caps = gstreamer::Caps::from_str("video/x-raw(memory:GLMemory),format=NV12")?;
gl_caps_filter.set_property("caps", &gl_caps);
// CUDA upload element
let cudaupload = gstreamer::ElementFactory::make("cudaupload").build()?;
if zero_copy {
if encoder.encoder_api == EncoderAPI::NVENC {
// NVENC zero-copy path
pipeline.add_many(&[&src, &caps_filter, &enc, &sink])?;
gstreamer::Element::link_many(&[&src, &caps_filter, &enc, &sink])?;
} else {
// VA-API/QSV zero-copy path
let vapostproc = gstreamer::ElementFactory::make("vapostproc").build()?;
let va_caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let va_caps = gstreamer::Caps::from_str("video/x-raw(memory:VAMemory),format=NV12")?;
va_caps_filter.set_property("caps", &va_caps);
pipeline.add_many(&[
&src,
&caps_filter,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&enc,
&sink,
])?;
gstreamer::Element::link_many(&[
&src,
&caps_filter,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&enc,
&sink,
])?;
pipeline.add_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
gstreamer::Element::link_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
}
} else {
let vapostproc = gstreamer::ElementFactory::make("vapostproc").build()?;
// VA caps filter
let va_caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let va_caps = gstreamer::Caps::from_str("video/x-raw(memory:VAMemory),format=NV12")?;
va_caps_filter.set_property("caps", &va_caps);
pipeline.add_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
gstreamer::Element::link_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
// Non-zero-copy path for all encoders - needs videoconvert
let videoconvert = gstreamer::ElementFactory::make("videoconvert").build()?;
pipeline.add_many(&[&src, &caps_filter, &videoconvert, &enc, &sink])?;
gstreamer::Element::link_many(&[&src, &caps_filter, &videoconvert, &enc, &sink])?;
}
let bus = pipeline.bus().ok_or("Pipeline has no bus")?;
let _ = pipeline.set_state(gstreamer::State::Playing);
for msg in bus.iter_timed(gstreamer::ClockTime::from_seconds(2)) {
pipeline.set_state(gstreamer::State::Playing)?;
// Wait for either error or async-done (state change complete)
for msg in bus.iter_timed(gstreamer::ClockTime::from_seconds(10)) {
match msg.view() {
gstreamer::MessageView::Error(err) => {
let err_msg = format!("Pipeline error: {}", err.error());
@@ -694,14 +690,17 @@ pub fn test_encoder(encoder: &VideoEncoderInfo, dma_buf: bool) -> Result<(), Box
let _ = pipeline.set_state(gstreamer::State::Null);
return Err(err_msg.into());
}
gstreamer::MessageView::Eos(_) => {
tracing::info!("Pipeline EOS received");
gstreamer::MessageView::AsyncDone(_) => {
// Pipeline successfully reached PLAYING state
tracing::debug!("Pipeline reached PLAYING state successfully");
let _ = pipeline.set_state(gstreamer::State::Null);
return Err("Pipeline EOS received, encoder test failed".into());
return Ok(());
}
_ => {}
}
}
// If we got here, timeout occurred without reaching PLAYING or error
let _ = pipeline.set_state(gstreamer::State::Null);
Ok(())
Err("Encoder test timed out".into())
}

View File

@@ -112,11 +112,25 @@ pub fn get_gpus() -> Result<Vec<GPUInfo>, Box<dyn Error>> {
let minor = &caps[1];
// Read vendor and device ID
let vendor_str = fs::read_to_string(format!("/sys/class/drm/card{}/device/vendor", minor))?;
let vendor_str = fs::read_to_string(format!("/sys/class/drm/card{}/device/vendor", minor));
let vendor_str = match vendor_str {
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to read vendor for card{}: {}", minor, e);
continue;
}
};
let vendor_str = vendor_str.trim_start_matches("0x").trim_end_matches('\n');
let vendor = u16::from_str_radix(vendor_str, 16)?;
let device_str = fs::read_to_string(format!("/sys/class/drm/card{}/device/device", minor))?;
let device_str = fs::read_to_string(format!("/sys/class/drm/card{}/device/device", minor));
let device_str = match device_str {
Ok(d) => d,
Err(e) => {
tracing::warn!("Failed to read device for card{}: {}", minor, e);
continue;
}
};
let device_str = device_str.trim_start_matches("0x").trim_end_matches('\n');
// Look up in hwdata PCI database
@@ -129,7 +143,15 @@ pub fn get_gpus() -> Result<Vec<GPUInfo>, Box<dyn Error>> {
};
// Read PCI bus ID
let pci_bus_id = fs::read_to_string(format!("/sys/class/drm/card{}/device/uevent", minor))?;
let pci_bus_id = fs::read_to_string(format!("/sys/class/drm/card{}/device/uevent", minor));
let pci_bus_id = match pci_bus_id {
Ok(p) => p,
Err(e) => {
tracing::warn!("Failed to read PCI bus ID for card{}: {}", minor, e);
continue;
}
};
// Extract PCI_SLOT_NAME from uevent content
let pci_bus_id = pci_bus_id
.lines()
.find_map(|line| {
@@ -191,7 +213,6 @@ fn parse_pci_ids(pci_data: &str, vendor_id: &str, device_id: &str) -> Option<Str
fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
let entries = fs::read_dir("/sys/bus/pci/devices").ok()?;
for entry in entries.flatten() {
if !entry.path().to_string_lossy().contains(&pci_addr) {
continue;

View File

@@ -0,0 +1 @@
pub mod controller;

View File

@@ -0,0 +1,205 @@
use crate::proto::proto::proto_input::InputType::{
ControllerAttach, ControllerAxis, ControllerButton, ControllerDetach, ControllerRumble,
ControllerStick, ControllerTrigger,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
fn controller_string_to_type(controller_type: &str) -> Result<vimputti::DeviceConfig> {
match controller_type.to_lowercase().as_str() {
"ps4" => Ok(vimputti::ControllerTemplates::ps4()),
"ps5" => Ok(vimputti::ControllerTemplates::ps5()),
"xbox360" => Ok(vimputti::ControllerTemplates::xbox360()),
"xboxone" => Ok(vimputti::ControllerTemplates::xbox_one()),
"switchpro" => Ok(vimputti::ControllerTemplates::switch_pro()),
_ => Err(anyhow::anyhow!(
"Unsupported controller type: {}",
controller_type
)),
}
}
pub struct ControllerInput {
config: vimputti::DeviceConfig,
device: vimputti::client::VirtualController,
}
impl ControllerInput {
pub async fn new(
controller_type: String,
client: &vimputti::client::VimputtiClient,
) -> Result<Self> {
let config = controller_string_to_type(&controller_type)?;
Ok(Self {
config: config.clone(),
device: client.create_device(config).await?,
})
}
pub fn device_mut(&mut self) -> &mut vimputti::client::VirtualController {
&mut self.device
}
pub fn device(&self) -> &vimputti::client::VirtualController {
&self.device
}
}
pub struct ControllerManager {
vimputti_client: Arc<vimputti::client::VimputtiClient>,
cmd_tx: mpsc::Sender<crate::proto::proto::ProtoInput>,
rumble_tx: mpsc::Sender<(u32, u16, u16, u16)>, // (slot, strong, weak, duration_ms)
}
impl ControllerManager {
pub fn new(
vimputti_client: Arc<vimputti::client::VimputtiClient>,
) -> Result<(Self, mpsc::Receiver<(u32, u16, u16, u16)>)> {
let (cmd_tx, cmd_rx) = mpsc::channel(100);
let (rumble_tx, rumble_rx) = mpsc::channel(100);
tokio::spawn(command_loop(
cmd_rx,
vimputti_client.clone(),
rumble_tx.clone(),
));
Ok((
Self {
vimputti_client,
cmd_tx,
rumble_tx,
},
rumble_rx,
))
}
pub async fn send_command(&self, input: crate::proto::proto::ProtoInput) -> Result<()> {
self.cmd_tx.send(input).await?;
Ok(())
}
}
async fn command_loop(
mut cmd_rx: mpsc::Receiver<crate::proto::proto::ProtoInput>,
vimputti_client: Arc<vimputti::client::VimputtiClient>,
rumble_tx: mpsc::Sender<(u32, u16, u16, u16)>,
) {
let mut controllers: HashMap<u32, ControllerInput> = HashMap::new();
while let Some(input) = cmd_rx.recv().await {
if let Some(input_type) = input.input_type {
match input_type {
ControllerAttach(data) => {
// Check if controller already exists in the slot, if so, ignore
if controllers.contains_key(&(data.slot as u32)) {
tracing::warn!(
"Controller slot {} already occupied, ignoring attach",
data.slot
);
} else {
if let Ok(mut controller) =
ControllerInput::new(data.id.clone(), &vimputti_client).await
{
let slot = data.slot as u32;
let rumble_tx = rumble_tx.clone();
controller
.device_mut()
.on_rumble(move |strong, weak, duration_ms| {
let _ = rumble_tx.try_send((slot, strong, weak, duration_ms));
})
.await
.map_err(|e| {
tracing::warn!(
"Failed to register rumble callback for slot {}: {}",
slot,
e
);
})
.ok();
controllers.insert(data.slot as u32, controller);
tracing::info!("Controller {} attached to slot {}", data.id, data.slot);
} else {
tracing::error!(
"Failed to create controller of type {} for slot {}",
data.id,
data.slot
);
}
}
}
ControllerDetach(data) => {
if controllers.remove(&(data.slot as u32)).is_some() {
tracing::info!("Controller detached from slot {}", data.slot);
} else {
tracing::warn!("No controller found in slot {} to detach", data.slot);
}
}
ControllerButton(data) => {
if let Some(controller) = controllers.get(&(data.slot as u32)) {
if let Some(button) = vimputti::Button::from_ev_code(data.button as u16) {
let device = controller.device();
device.button(button, data.pressed);
device.sync();
}
} else {
tracing::warn!("Controller slot {} not found for button event", data.slot);
}
}
ControllerStick(data) => {
if let Some(controller) = controllers.get(&(data.slot as u32)) {
let device = controller.device();
if data.stick == 0 {
// Left stick
device.axis(vimputti::Axis::LeftStickX, data.x);
device.sync();
device.axis(vimputti::Axis::LeftStickY, data.y);
} else if data.stick == 1 {
// Right stick
device.axis(vimputti::Axis::RightStickX, data.x);
device.sync();
device.axis(vimputti::Axis::RightStickY, data.y);
}
device.sync();
} else {
tracing::warn!("Controller slot {} not found for stick event", data.slot);
}
}
ControllerTrigger(data) => {
if let Some(controller) = controllers.get(&(data.slot as u32)) {
let device = controller.device();
if data.trigger == 0 {
// Left trigger
device.axis(vimputti::Axis::LowerLeftTrigger, data.value);
} else if data.trigger == 1 {
// Right trigger
device.axis(vimputti::Axis::LowerRightTrigger, data.value);
}
device.sync();
} else {
tracing::warn!("Controller slot {} not found for trigger event", data.slot);
}
}
ControllerAxis(data) => {
if let Some(controller) = controllers.get(&(data.slot as u32)) {
let device = controller.device();
if data.axis == 0 {
// dpad x
device.axis(vimputti::Axis::DPadX, data.value);
} else if data.axis == 1 {
// dpad y
device.axis(vimputti::Axis::DPadY, data.value);
}
device.sync();
}
}
// Rumble will be outgoing event..
ControllerRumble(_) => {
//no-op
}
_ => {
//no-op
}
}
}
}
}

View File

@@ -1,6 +1,7 @@
mod args;
mod enc_helper;
mod gpu;
mod input;
mod latency;
mod messages;
mod nestrisink;
@@ -10,6 +11,7 @@ mod proto;
use crate::args::encoding_args;
use crate::enc_helper::{EncoderAPI, EncoderType};
use crate::gpu::{GPUInfo, GPUVendor};
use crate::input::controller::ControllerManager;
use crate::nestrisink::NestriSignaller;
use crate::p2p::p2p::NestriP2P;
use gstreamer::prelude::*;
@@ -118,7 +120,7 @@ fn handle_encoder_video(
&video_encoders,
&args.encoding.video.codec,
&args.encoding.video.encoder_type,
args.app.dma_buf,
args.app.zero_copy,
)?;
}
tracing::info!("Selected video encoder: '{}'", video_encoder.name);
@@ -174,9 +176,6 @@ fn handle_encoder_audio(args: &args::Args) -> String {
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let mut args = args::Args::new();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::builder()
@@ -185,6 +184,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
)
.init();
// Parse command line arguments
let mut args = args::Args::new();
if args.app.verbose {
args.debug_print();
}
@@ -199,13 +201,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
gstreamer::init()?;
let _ = gstrswebrtc::plugin_register_static(); // Might be already registered, so we'll pass..
if args.app.dma_buf {
if args.app.zero_copy {
if args.encoding.video.encoder_type != EncoderType::HARDWARE {
tracing::warn!("DMA-BUF is only supported with hardware encoders, disabling DMA-BUF..");
args.app.dma_buf = false;
tracing::warn!(
"zero-copy is only supported with hardware encoders, disabling zero-copy.."
);
args.app.zero_copy = false;
} else {
tracing::warn!(
"DMA-BUF is experimental, it may or may not improve performance, or even work at all."
"zero-copy is experimental, it may or may not improve performance, or even work at all."
);
}
}
@@ -238,6 +242,28 @@ async fn main() -> Result<(), Box<dyn Error>> {
let nestri_p2p = Arc::new(NestriP2P::new().await?);
let p2p_conn = nestri_p2p.connect(relay_url).await?;
// Get vimputti manager connection if available
let vpath = match args.app.vimputti_path {
Some(ref path) => path.clone(),
None => "/tmp/vimputti-0".to_string(),
};
let vimputti_client = match vimputti::VimputtiClient::connect(vpath).await {
Ok(client) => {
tracing::info!("Connected to vimputti manager");
Some(Arc::new(client))
}
Err(e) => {
tracing::warn!("Failed to connect to vimputti manager: {}", e);
None
}
};
let (controller_manager, rumble_rx) = if let Some(vclient) = vimputti_client {
let (controller_manager, rumble_rx) = ControllerManager::new(vclient)?;
(Some(Arc::new(controller_manager)), Some(rumble_rx))
} else {
(None, None)
};
/*** PIPELINE CREATION ***/
// Create the pipeline
let pipeline = Arc::new(gstreamer::Pipeline::new());
@@ -266,7 +292,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Required to fix gstreamer opus issue, where quality sounds off (due to wrong sample rate)
let audio_capsfilter = gstreamer::ElementFactory::make("capsfilter").build()?;
let audio_caps = gstreamer::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
let audio_caps = gstreamer::Caps::from_str("audio/x-raw,rate=48000,channels=2")?;
audio_capsfilter.set_property("caps", &audio_caps);
// Audio Encoder Element
@@ -302,22 +328,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let caps = gstreamer::Caps::from_str(&format!(
"{},width={},height={},framerate={}/1{}",
if args.app.dma_buf {
"video/x-raw(memory:DMABuf)"
if args.app.zero_copy {
if video_encoder_info.encoder_api == EncoderAPI::NVENC {
"video/x-raw(memory:CUDAMemory)"
} else {
"video/x-raw(memory:DMABuf)"
}
} else {
"video/x-raw"
},
args.app.resolution.0,
args.app.resolution.1,
args.app.framerate,
if args.app.dma_buf { "" } else { ",format=RGBx" }
if args.app.zero_copy {
""
} else {
",format=RGBx"
}
))?;
caps_filter.set_property("caps", &caps);
// Get bit-depth and choose appropriate format (NV12 or P010_10LE)
// H.264 does not support above 8-bit. Also we require DMA-BUF.
let video_format = if args.encoding.video.bit_depth == 10
&& args.app.dma_buf
&& args.app.zero_copy
&& video_encoder_info.codec != enc_helper::VideoCodec::H264
{
"P010_10LE"
@@ -325,27 +359,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
"NV12"
};
// GL and CUDA elements (NVIDIA only..)
let mut glupload = None;
let mut glconvert = None;
let mut gl_caps_filter = None;
let mut cudaupload = None;
if args.app.dma_buf && video_encoder_info.encoder_api == EncoderAPI::NVENC {
// GL upload element
glupload = Some(gstreamer::ElementFactory::make("glupload").build()?);
// GL color convert element
glconvert = Some(gstreamer::ElementFactory::make("glcolorconvert").build()?);
// GL color convert caps
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let gl_caps = gstreamer::Caps::from_str(
format!("video/x-raw(memory:GLMemory),format={video_format}").as_str(),
)?;
caps_filter.set_property("caps", &gl_caps);
gl_caps_filter = Some(caps_filter);
// CUDA upload element
cudaupload = Some(gstreamer::ElementFactory::make("cudaupload").build()?);
}
// vapostproc for VA compatible encoders
let mut vapostproc = None;
let mut va_caps_filter = None;
@@ -364,7 +377,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Video Converter Element
let mut video_converter = None;
if !args.app.dma_buf {
if !args.app.zero_copy {
video_converter = Some(gstreamer::ElementFactory::make("videoconvert").build()?);
}
@@ -397,24 +410,34 @@ async fn main() -> Result<(), Box<dyn Error>> {
/* Output */
// WebRTC sink Element
let signaller =
NestriSignaller::new(args.app.room, p2p_conn.clone(), video_source.clone()).await?;
let signaller = NestriSignaller::new(
args.app.room,
p2p_conn.clone(),
video_source.clone(),
controller_manager,
rumble_rx,
)
.await?;
let webrtcsink = BaseWebRTCSink::with_signaller(Signallable::from(signaller.clone()));
webrtcsink.set_property_from_str("stun-server", "stun://stun.l.google.com:19302");
webrtcsink.set_property_from_str("congestion-control", "disabled");
webrtcsink.set_property("do-retransmission", false);
/* Queues */
let video_queue = gstreamer::ElementFactory::make("queue2")
.property("max-size-buffers", 3u32)
.property("max-size-time", 0u64)
.property("max-size-bytes", 0u32)
let video_source_queue = gstreamer::ElementFactory::make("queue")
.property("max-size-buffers", 5u32)
.build()?;
let audio_queue = gstreamer::ElementFactory::make("queue2")
.property("max-size-buffers", 3u32)
.property("max-size-time", 0u64)
.property("max-size-bytes", 0u32)
let audio_source_queue = gstreamer::ElementFactory::make("queue")
.property("max-size-buffers", 5u32)
.build()?;
let video_queue = gstreamer::ElementFactory::make("queue")
.property("max-size-buffers", 5u32)
.build()?;
let audio_queue = gstreamer::ElementFactory::make("queue")
.property("max-size-buffers", 5u32)
.build()?;
/* Clock Sync */
@@ -433,6 +456,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
&caps_filter,
&video_queue,
&video_clocksync,
&video_source_queue,
&video_source,
&audio_encoder,
&audio_capsfilter,
@@ -440,6 +464,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
&audio_clocksync,
&audio_rate,
&audio_converter,
&audio_source_queue,
&audio_source,
])?;
@@ -455,24 +480,18 @@ async fn main() -> Result<(), Box<dyn Error>> {
pipeline.add(parser)?;
}
// If DMA-BUF..
if args.app.dma_buf {
// If zero-copy..
if args.app.zero_copy {
// VA-API / QSV pipeline
if let (Some(vapostproc), Some(va_caps_filter)) = (&vapostproc, &va_caps_filter) {
pipeline.add_many(&[vapostproc, va_caps_filter])?;
} else {
// NVENC pipeline
if let (Some(glupload), Some(glconvert), Some(gl_caps_filter), Some(cudaupload)) =
(&glupload, &glconvert, &gl_caps_filter, &cudaupload)
{
pipeline.add_many(&[glupload, glconvert, gl_caps_filter, cudaupload])?;
}
}
}
// Link main audio branch
gstreamer::Element::link_many(&[
&audio_source,
&audio_source_queue,
&audio_converter,
&audio_rate,
&audio_capsfilter,
@@ -488,12 +507,13 @@ async fn main() -> Result<(), Box<dyn Error>> {
gstreamer::Element::link_many(&[&audio_encoder, webrtcsink.upcast_ref()])?;
}
// With DMA-BUF..
if args.app.dma_buf {
// With zero-copy..
if args.app.zero_copy {
// VA-API / QSV pipeline
if let (Some(vapostproc), Some(va_caps_filter)) = (&vapostproc, &va_caps_filter) {
gstreamer::Element::link_many(&[
&video_source,
&video_source_queue,
&caps_filter,
&video_queue,
&video_clocksync,
@@ -501,27 +521,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
&va_caps_filter,
&video_encoder,
])?;
} else {
} else if video_encoder_info.encoder_api == EncoderAPI::NVENC {
// NVENC pipeline
if let (Some(glupload), Some(glconvert), Some(gl_caps_filter), Some(cudaupload)) =
(&glupload, &glconvert, &gl_caps_filter, &cudaupload)
{
gstreamer::Element::link_many(&[
&video_source,
&caps_filter,
&video_queue,
&video_clocksync,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&video_encoder,
])?;
}
gstreamer::Element::link_many(&[
&video_source,
&video_source_queue,
&caps_filter,
&video_encoder,
])?;
}
} else {
gstreamer::Element::link_many(&[
&video_source,
&video_source_queue,
&caps_filter,
&video_queue,
&video_clocksync,
@@ -537,8 +549,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
gstreamer::Element::link_many(&[&video_encoder, webrtcsink.upcast_ref()])?;
}
// Set QOS
video_encoder.set_property("qos", true);
// Make sure QOS is disabled to avoid latency
video_encoder.set_property("qos", false);
// Optimize latency of pipeline
video_source

View File

@@ -1,3 +1,4 @@
use crate::input::controller::ControllerManager;
use crate::messages::{MessageBase, MessageICE, MessageRaw, MessageSDP};
use crate::p2p::p2p::NestriConnection;
use crate::p2p::p2p_protocol_stream::NestriStreamProtocol;
@@ -5,7 +6,7 @@ use crate::proto::proto::proto_input::InputType::{
KeyDown, KeyUp, MouseKeyDown, MouseKeyUp, MouseMove, MouseMoveAbs, MouseWheel,
};
use crate::proto::proto::{ProtoInput, ProtoMessageInput};
use atomic_refcell::AtomicRefCell;
use anyhow::Result;
use glib::subclass::prelude::*;
use gstreamer::glib;
use gstreamer::prelude::*;
@@ -14,6 +15,7 @@ use gstrswebrtc::signaller::{Signallable, SignallableImpl};
use parking_lot::RwLock as PLRwLock;
use prost::Message;
use std::sync::{Arc, LazyLock};
use tokio::sync::{Mutex, mpsc};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
@@ -21,7 +23,9 @@ pub struct Signaller {
stream_room: PLRwLock<Option<String>>,
stream_protocol: PLRwLock<Option<Arc<NestriStreamProtocol>>>,
wayland_src: PLRwLock<Option<Arc<gstreamer::Element>>>,
data_channel: AtomicRefCell<Option<gstreamer_webrtc::WebRTCDataChannel>>,
data_channel: PLRwLock<Option<Arc<gstreamer_webrtc::WebRTCDataChannel>>>,
controller_manager: PLRwLock<Option<Arc<ControllerManager>>>,
rumble_rx: Mutex<Option<mpsc::Receiver<(u32, u16, u16, u16)>>>,
}
impl Default for Signaller {
fn default() -> Self {
@@ -29,15 +33,14 @@ impl Default for Signaller {
stream_room: PLRwLock::new(None),
stream_protocol: PLRwLock::new(None),
wayland_src: PLRwLock::new(None),
data_channel: AtomicRefCell::new(None),
data_channel: PLRwLock::new(None),
controller_manager: PLRwLock::new(None),
rumble_rx: Mutex::new(None),
}
}
}
impl Signaller {
pub async fn set_nestri_connection(
&self,
nestri_conn: NestriConnection,
) -> Result<(), Box<dyn std::error::Error>> {
pub async fn set_nestri_connection(&self, nestri_conn: NestriConnection) -> Result<()> {
let stream_protocol = NestriStreamProtocol::new(nestri_conn).await?;
*self.stream_protocol.write() = Some(Arc::new(stream_protocol));
Ok(())
@@ -59,14 +62,29 @@ impl Signaller {
self.wayland_src.read().clone()
}
pub fn set_controller_manager(&self, controller_manager: Arc<ControllerManager>) {
*self.controller_manager.write() = Some(controller_manager);
}
pub fn get_controller_manager(&self) -> Option<Arc<ControllerManager>> {
self.controller_manager.read().clone()
}
pub async fn set_rumble_rx(&self, rumble_rx: mpsc::Receiver<(u32, u16, u16, u16)>) {
*self.rumble_rx.lock().await = Some(rumble_rx);
}
// Change getter to take ownership:
pub async fn take_rumble_rx(&self) -> Option<mpsc::Receiver<(u32, u16, u16, u16)>> {
self.rumble_rx.lock().await.take()
}
pub fn set_data_channel(&self, data_channel: gstreamer_webrtc::WebRTCDataChannel) {
match self.data_channel.try_borrow_mut() {
Ok(mut dc) => *dc = Some(data_channel),
Err(_) => gstreamer::warning!(
gstreamer::CAT_DEFAULT,
"Failed to set data channel - already borrowed"
),
}
*self.data_channel.write() = Some(Arc::new(data_channel));
}
pub fn get_data_channel(&self) -> Option<Arc<gstreamer_webrtc::WebRTCDataChannel>> {
self.data_channel.read().clone()
}
/// Helper method to clean things up
@@ -79,15 +97,15 @@ impl Signaller {
let self_obj = self.obj().clone();
stream_protocol.register_callback("answer", move |data| {
if let Ok(message) = serde_json::from_slice::<MessageSDP>(&data) {
let sdp =
gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes()).unwrap();
let sdp = gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes())
.map_err(|e| anyhow::anyhow!("Invalid SDP in 'answer': {e:?}"))?;
let answer = WebRTCSessionDescription::new(WebRTCSDPType::Answer, sdp);
self_obj.emit_by_name::<()>(
Ok(self_obj.emit_by_name::<()>(
"session-description",
&[&"unique-session-id", &answer],
);
))
} else {
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode SDP message");
anyhow::bail!("Failed to decode SDP message");
}
});
}
@@ -98,7 +116,7 @@ impl Signaller {
let candidate = message.candidate;
let sdp_m_line_index = candidate.sdp_mline_index.unwrap_or(0) as u32;
let sdp_mid = candidate.sdp_mid;
self_obj.emit_by_name::<()>(
Ok(self_obj.emit_by_name::<()>(
"handle-ice",
&[
&"unique-session-id",
@@ -106,9 +124,9 @@ impl Signaller {
&sdp_mid,
&candidate.candidate,
],
);
))
} else {
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode ICE message");
anyhow::bail!("Failed to decode ICE message");
}
});
}
@@ -131,16 +149,16 @@ impl Signaller {
}
// Send our SDP offer
self_obj.emit_by_name::<()>(
Ok(self_obj.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
))
} else {
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode answer");
anyhow::bail!("Failed to decode answer");
}
});
}
@@ -173,8 +191,25 @@ impl Signaller {
if let Some(data_channel) = data_channel {
gstreamer::info!(gstreamer::CAT_DEFAULT, "Data channel created");
if let Some(wayland_src) = signaller.imp().get_wayland_src() {
setup_data_channel(&data_channel, &*wayland_src);
signaller.imp().set_data_channel(data_channel);
signaller.imp().set_data_channel(data_channel.clone());
let signaller = signaller.clone();
let data_channel = Arc::new(data_channel);
let wayland_src = wayland_src.clone();
// Spawn async task to take the receiver and set up
tokio::spawn(async move {
let rumble_rx = signaller.imp().take_rumble_rx().await;
let controller_manager =
signaller.imp().get_controller_manager();
setup_data_channel(
controller_manager,
rumble_rx,
data_channel,
&wayland_src,
);
});
} else {
gstreamer::error!(
gstreamer::CAT_DEFAULT,
@@ -315,31 +350,83 @@ impl ObjectImpl for Signaller {
}
fn setup_data_channel(
data_channel: &gstreamer_webrtc::WebRTCDataChannel,
controller_manager: Option<Arc<ControllerManager>>,
rumble_rx: Option<mpsc::Receiver<(u32, u16, u16, u16)>>, // (slot, strong, weak, duration_ms)
data_channel: Arc<gstreamer_webrtc::WebRTCDataChannel>,
wayland_src: &gstreamer::Element,
) {
let wayland_src = wayland_src.clone();
let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
data_channel.connect_on_message_data(move |_data_channel, data| {
if let Some(data) = data {
match ProtoMessageInput::decode(data.to_vec().as_slice()) {
// Spawn async processor
tokio::spawn(async move {
while let Some(data) = rx.recv().await {
match ProtoMessageInput::decode(data.as_slice()) {
Ok(message_input) => {
if let Some(input_msg) = message_input.data {
// Process the input message and create an event
if let Some(event) = handle_input_message(input_msg) {
// Send the event to wayland source, result bool is ignored
let _ = wayland_src.send_event(event);
if let Some(message_base) = message_input.message_base {
if message_base.payload_type == "input" {
if let Some(input_data) = message_input.data {
if let Some(event) = handle_input_message(input_data) {
// Send the event to wayland source, result bool is ignored
let _ = wayland_src.send_event(event);
}
}
} else if message_base.payload_type == "controllerInput" {
if let Some(controller_manager) = &controller_manager {
if let Some(input_data) = message_input.data {
let _ = controller_manager.send_command(input_data).await;
}
}
}
} else {
tracing::error!("Failed to parse InputMessage");
}
}
Err(e) => {
tracing::error!("Failed to decode MessageInput: {:?}", e);
tracing::error!("Failed to decode input message: {:?}", e);
}
}
}
});
// Spawn rumble sender
if let Some(mut rumble_rx) = rumble_rx {
let data_channel_clone = data_channel.clone();
tokio::spawn(async move {
while let Some((slot, strong, weak, duration_ms)) = rumble_rx.recv().await {
let rumble_msg = ProtoMessageInput {
message_base: Some(crate::proto::proto::ProtoMessageBase {
payload_type: "controllerInput".to_string(),
latency: None,
}),
data: Some(ProtoInput {
input_type: Some(
crate::proto::proto::proto_input::InputType::ControllerRumble(
crate::proto::proto::ProtoControllerRumble {
r#type: "ControllerRumble".to_string(),
slot: slot as i32,
low_frequency: weak as i32,
high_frequency: strong as i32,
duration: duration_ms as i32,
},
),
),
}),
};
let data = rumble_msg.encode_to_vec();
let bytes = glib::Bytes::from_owned(data);
if let Err(e) = data_channel_clone.send_data_full(Some(&bytes)) {
tracing::warn!("Failed to send rumble data: {}", e);
}
}
});
}
data_channel.connect_on_message_data(move |_data_channel, data| {
if let Some(data) = data {
let _ = tx.send(data.to_vec());
}
});
}
fn handle_input_message(input_msg: ProtoInput) -> Option<gstreamer::Event> {
@@ -401,6 +488,7 @@ fn handle_input_message(input_msg: ProtoInput) -> Option<gstreamer::Event> {
Some(gstreamer::event::CustomUpstream::new(structure))
}
_ => None,
}
} else {
None

View File

@@ -1,8 +1,10 @@
use crate::input::controller::ControllerManager;
use crate::p2p::p2p::NestriConnection;
use gstreamer::glib;
use gstreamer::subclass::prelude::*;
use gstrswebrtc::signaller::Signallable;
use std::sync::Arc;
use tokio::sync::mpsc;
mod imp;
@@ -15,11 +17,19 @@ impl NestriSignaller {
room: String,
nestri_conn: NestriConnection,
wayland_src: Arc<gstreamer::Element>,
controller_manager: Option<Arc<ControllerManager>>,
rumble_rx: Option<mpsc::Receiver<(u32, u16, u16, u16)>>,
) -> Result<Self, Box<dyn std::error::Error>> {
let obj: Self = glib::Object::new();
obj.imp().set_stream_room(room);
obj.imp().set_nestri_connection(nestri_conn).await?;
obj.imp().set_wayland_src(wayland_src);
if let Some(controller_manager) = controller_manager {
obj.imp().set_controller_manager(controller_manager);
}
if let Some(rumble_rx) = rumble_rx {
obj.imp().set_rumble_rx(rumble_rx).await;
}
Ok(obj)
}
}

View File

@@ -1,3 +1,4 @@
use anyhow::Result;
use libp2p::futures::StreamExt;
use libp2p::multiaddr::Protocol;
use libp2p::{
@@ -11,7 +12,6 @@ use libp2p_ping as ping;
use libp2p_stream as stream;
use libp2p_tcp as tcp;
use libp2p_yamux as yamux;
use std::error::Error;
use std::sync::Arc;
use tokio::sync::Mutex;
@@ -46,7 +46,7 @@ pub struct NestriP2P {
swarm: Arc<Mutex<Swarm<NestriBehaviour>>>,
}
impl NestriP2P {
pub async fn new() -> Result<Self, Box<dyn Error>> {
pub async fn new() -> Result<Self> {
let swarm = Arc::new(Mutex::new(
libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
@@ -69,14 +69,16 @@ impl NestriP2P {
Ok(NestriP2P { swarm })
}
pub async fn connect(&self, conn_url: &str) -> Result<NestriConnection, Box<dyn Error>> {
pub async fn connect(&self, conn_url: &str) -> Result<NestriConnection> {
let conn_addr: Multiaddr = conn_url.parse()?;
let mut swarm_lock = self.swarm.lock().await;
swarm_lock.dial(conn_addr.clone())?;
let Some(Protocol::P2p(peer_id)) = conn_addr.clone().iter().last() else {
return Err("Invalid connection URL: missing peer ID".into());
return Err(anyhow::Error::msg(
"Invalid multiaddr: missing /p2p/<peer_id>",
));
};
Ok(NestriConnection {
@@ -88,10 +90,7 @@ impl NestriP2P {
async fn swarm_loop(swarm: Arc<Mutex<Swarm<NestriBehaviour>>>) {
loop {
let event = {
let mut swarm_lock = swarm.lock().await;
swarm_lock.select_next_some().await
};
let event = swarm.lock().await.select_next_some().await;
match event {
/* Ping Events */
SwarmEvent::Behaviour(NestriBehaviourEvent::Ping(ping::Event {

View File

@@ -1,23 +1,23 @@
use crate::p2p::p2p::NestriConnection;
use crate::p2p::p2p_safestream::SafeStream;
use anyhow::Result;
use dashmap::DashMap;
use libp2p::StreamProtocol;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::{self, Duration};
// Cloneable callback type
pub type CallbackInner = dyn Fn(Vec<u8>) + Send + Sync + 'static;
pub type CallbackInner = dyn Fn(Vec<u8>) -> Result<()> + Send + Sync + 'static;
pub struct Callback(Arc<CallbackInner>);
impl Callback {
pub fn new<F>(f: F) -> Self
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
F: Fn(Vec<u8>) -> Result<()> + Send + Sync + 'static,
{
Callback(Arc::new(f))
}
pub fn call(&self, data: Vec<u8>) {
pub fn call(&self, data: Vec<u8>) -> Result<()> {
self.0(data)
}
}
@@ -44,9 +44,7 @@ impl NestriStreamProtocol {
const NESTRI_PROTOCOL_STREAM_PUSH: StreamProtocol =
StreamProtocol::new("/nestri-relay/stream-push/1.0.0");
pub async fn new(
nestri_connection: NestriConnection,
) -> Result<Self, Box<dyn std::error::Error>> {
pub async fn new(nestri_connection: NestriConnection) -> Result<Self> {
let mut nestri_connection = nestri_connection.clone();
let push_stream = match nestri_connection
.control
@@ -55,7 +53,10 @@ impl NestriStreamProtocol {
{
Ok(stream) => stream,
Err(e) => {
return Err(Box::new(e));
return Err(anyhow::Error::msg(format!(
"Failed to open push stream: {}",
e
)));
}
};
@@ -73,7 +74,7 @@ impl NestriStreamProtocol {
Ok(sp)
}
pub fn restart(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub fn restart(&mut self) -> Result<()> {
// Return if tx and handles are already initialized
if self.tx.is_some() && self.read_handle.is_some() && self.write_handle.is_some() {
tracing::warn!("NestriStreamProtocol is already running, restart skipped");
@@ -111,13 +112,9 @@ impl NestriStreamProtocol {
// we just get the callback directly if it exists
if let Some(callback) = callbacks.get(&response_type) {
// Execute the callback
if let Err(e) =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
callback.call(data.clone())
}))
{
if let Err(e) = callback.call(data.clone()) {
tracing::error!(
"Callback for response type '{}' panicked: {:?}",
"Callback for response type '{}' errored: {:?}",
response_type,
e
);
@@ -133,9 +130,6 @@ impl NestriStreamProtocol {
tracing::error!("Failed to decode message: {}", e);
}
}
// Add a small sleep to reduce CPU usage
time::sleep(Duration::from_micros(100)).await;
}
})
}
@@ -156,27 +150,20 @@ impl NestriStreamProtocol {
break;
}
}
// Add a small sleep to reduce CPU usage
time::sleep(Duration::from_micros(100)).await;
}
})
}
pub fn send_message<M: serde::Serialize>(
&self,
message: &M,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn send_message<M: serde::Serialize>(&self, message: &M) -> Result<()> {
let json_data = serde_json::to_vec(message)?;
let Some(tx) = &self.tx else {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::NotConnected,
return Err(anyhow::Error::msg(
if self.read_handle.is_none() && self.write_handle.is_none() {
"NestriStreamProtocol has been shutdown"
} else {
"NestriStreamProtocol is not properly initialized"
},
)));
));
};
tx.try_send(json_data)?;
Ok(())
@@ -184,7 +171,7 @@ impl NestriStreamProtocol {
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
F: Fn(Vec<u8>) -> Result<()> + Send + Sync + 'static,
{
self.callbacks
.insert(response_type.to_string(), Callback::new(callback));

View File

@@ -1,3 +1,4 @@
use anyhow::Result;
use byteorder::{BigEndian, ByteOrder};
use libp2p::futures::io::{ReadHalf, WriteHalf};
use libp2p::futures::{AsyncReadExt, AsyncWriteExt};
@@ -19,17 +20,17 @@ impl SafeStream {
}
}
pub async fn send_raw(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
pub async fn send_raw(&self, data: &[u8]) -> Result<()> {
self.send_with_length_prefix(data).await
}
pub async fn receive_raw(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
pub async fn receive_raw(&self) -> Result<Vec<u8>> {
self.receive_with_length_prefix().await
}
async fn send_with_length_prefix(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
async fn send_with_length_prefix(&self, data: &[u8]) -> Result<()> {
if data.len() > MAX_SIZE {
return Err("Data exceeds maximum size".into());
anyhow::bail!("Data exceeds maximum size");
}
let mut buffer = Vec::with_capacity(4 + data.len());
@@ -42,7 +43,7 @@ impl SafeStream {
Ok(())
}
async fn receive_with_length_prefix(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
async fn receive_with_length_prefix(&self) -> Result<Vec<u8>> {
let mut stream_read = self.stream_read.lock().await;
// Read length prefix + data in one syscall
@@ -51,7 +52,7 @@ impl SafeStream {
let length = BigEndian::read_u32(&length_prefix) as usize;
if length > MAX_SIZE {
return Err("Data exceeds maximum size".into());
anyhow::bail!("Received data exceeds maximum size");
}
let mut buffer = vec![0u8; length];

View File

@@ -3,29 +3,31 @@
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoTimestampEntry {
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub stage: ::prost::alloc::string::String,
#[prost(message, optional, tag = "2")]
#[prost(message, optional, tag="2")]
pub time: ::core::option::Option<::prost_types::Timestamp>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoLatencyTracker {
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub sequence_id: ::prost::alloc::string::String,
#[prost(message, repeated, tag = "2")]
#[prost(message, repeated, tag="2")]
pub timestamps: ::prost::alloc::vec::Vec<ProtoTimestampEntry>,
}
// Mouse messages
/// MouseMove message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseMove {
/// Fixed value "MouseMove"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag = "3")]
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseMoveAbs message
@@ -33,11 +35,11 @@ pub struct ProtoMouseMove {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseMoveAbs {
/// Fixed value "MouseMoveAbs"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag = "3")]
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseWheel message
@@ -45,11 +47,11 @@ pub struct ProtoMouseMoveAbs {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseWheel {
/// Fixed value "MouseWheel"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag = "3")]
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseKeyDown message
@@ -57,9 +59,9 @@ pub struct ProtoMouseWheel {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseKeyDown {
/// Fixed value "MouseKeyDown"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub key: i32,
}
/// MouseKeyUp message
@@ -67,19 +69,21 @@ pub struct ProtoMouseKeyDown {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseKeyUp {
/// Fixed value "MouseKeyUp"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub key: i32,
}
// Keyboard messages
/// KeyDown message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoKeyDown {
/// Fixed value "KeyDown"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub key: i32,
}
/// KeyUp message
@@ -87,53 +91,185 @@ pub struct ProtoKeyDown {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoKeyUp {
/// Fixed value "KeyUp"
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
#[prost(int32, tag="2")]
pub key: i32,
}
// Controller messages
/// ControllerAttach message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerAttach {
/// Fixed value "ControllerAttach"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// One of the following enums: "ps", "xbox" or "switch"
#[prost(string, tag="2")]
pub id: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="3")]
pub slot: i32,
}
/// ControllerDetach message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerDetach {
/// Fixed value "ControllerDetach"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
}
/// ControllerButton message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerButton {
/// Fixed value "ControllerButtons"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
/// Button code (linux input event code)
#[prost(int32, tag="3")]
pub button: i32,
/// true if pressed, false if released
#[prost(bool, tag="4")]
pub pressed: bool,
}
/// ControllerTriggers message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerTrigger {
/// Fixed value "ControllerTriggers"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
/// Trigger number (0 for left, 1 for right)
#[prost(int32, tag="3")]
pub trigger: i32,
/// trigger value (-32768 to 32767)
#[prost(int32, tag="4")]
pub value: i32,
}
/// ControllerSticks message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerStick {
/// Fixed value "ControllerStick"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
/// Stick number (0 for left, 1 for right)
#[prost(int32, tag="3")]
pub stick: i32,
/// X axis value (-32768 to 32767)
#[prost(int32, tag="4")]
pub x: i32,
/// Y axis value (-32768 to 32767)
#[prost(int32, tag="5")]
pub y: i32,
}
/// ControllerAxis message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerAxis {
/// Fixed value "ControllerAxis"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
/// Axis number (0 for d-pad horizontal, 1 for d-pad vertical)
#[prost(int32, tag="3")]
pub axis: i32,
/// axis value (-1 to 1)
#[prost(int32, tag="4")]
pub value: i32,
}
/// ControllerRumble message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoControllerRumble {
/// Fixed value "ControllerRumble"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
/// Slot number (0-3)
#[prost(int32, tag="2")]
pub slot: i32,
/// Low frequency rumble (0-65535)
#[prost(int32, tag="3")]
pub low_frequency: i32,
/// High frequency rumble (0-65535)
#[prost(int32, tag="4")]
pub high_frequency: i32,
/// Duration in milliseconds
#[prost(int32, tag="5")]
pub duration: i32,
}
/// Union of all Input types
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoInput {
#[prost(oneof = "proto_input::InputType", tags = "1, 2, 3, 4, 5, 6, 7")]
#[prost(oneof="proto_input::InputType", tags="1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14")]
pub input_type: ::core::option::Option<proto_input::InputType>,
}
/// Nested message and enum types in `ProtoInput`.
pub mod proto_input {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Oneof)]
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum InputType {
#[prost(message, tag = "1")]
#[prost(message, tag="1")]
MouseMove(super::ProtoMouseMove),
#[prost(message, tag = "2")]
#[prost(message, tag="2")]
MouseMoveAbs(super::ProtoMouseMoveAbs),
#[prost(message, tag = "3")]
#[prost(message, tag="3")]
MouseWheel(super::ProtoMouseWheel),
#[prost(message, tag = "4")]
#[prost(message, tag="4")]
MouseKeyDown(super::ProtoMouseKeyDown),
#[prost(message, tag = "5")]
#[prost(message, tag="5")]
MouseKeyUp(super::ProtoMouseKeyUp),
#[prost(message, tag = "6")]
#[prost(message, tag="6")]
KeyDown(super::ProtoKeyDown),
#[prost(message, tag = "7")]
#[prost(message, tag="7")]
KeyUp(super::ProtoKeyUp),
#[prost(message, tag="8")]
ControllerAttach(super::ProtoControllerAttach),
#[prost(message, tag="9")]
ControllerDetach(super::ProtoControllerDetach),
#[prost(message, tag="10")]
ControllerButton(super::ProtoControllerButton),
#[prost(message, tag="11")]
ControllerTrigger(super::ProtoControllerTrigger),
#[prost(message, tag="12")]
ControllerStick(super::ProtoControllerStick),
#[prost(message, tag="13")]
ControllerAxis(super::ProtoControllerAxis),
#[prost(message, tag="14")]
ControllerRumble(super::ProtoControllerRumble),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMessageBase {
#[prost(string, tag = "1")]
#[prost(string, tag="1")]
pub payload_type: ::prost::alloc::string::String,
#[prost(message, optional, tag = "2")]
#[prost(message, optional, tag="2")]
pub latency: ::core::option::Option<ProtoLatencyTracker>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMessageInput {
#[prost(message, optional, tag = "1")]
#[prost(message, optional, tag="1")]
pub message_base: ::core::option::Option<ProtoMessageBase>,
#[prost(message, optional, tag = "2")]
#[prost(message, optional, tag="2")]
pub data: ::core::option::Option<ProtoInput>,
}
// @@protoc_insertion_point(module)