feat(runner): More runner improvements (#294)

## Description
Whew..

- Steam can now run without namespaces using a live-patcher (because Docker..)
- Improved NVIDIA GPU selection and handling
- Pipeline tests for GPU picking logic
- Optimizations and cleanup all around
- SSH (disabled by default) for easier instance debugging (see the example below this list).
- CachyOS' Proton, since it works without namespaces (couldn't figure out how to enable it automatically in Steam yet..)
- Package updates and partial removal of futures (libp2p will hopefully switch to Tokio in the next release)
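
As a hedged illustration of the new SSH debugging option (the image name, host port, and key path below are placeholders, not taken from this PR), the SSH server only starts when both variables read by the new `configure_ssh` entrypoint function are set:

```bash
# Sketch only - image name, port and key path are assumptions.
# SSH stays disabled unless BOTH variables are provided.
docker run \
  -e SSH_ENABLE_PORT=2222 \
  -e SSH_ALLOWED_KEY="$(cat ~/.ssh/id_ed25519.pub)" \
  -p 2222:2222 \
  <nestri-runner-image>
```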



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- SSH server can now be enabled within the container for remote access
when configured.
- Added persistent live patching for Steam runtime entrypoints to
improve compatibility with namespace-less applications.
- Enhanced GPU selection with multi-GPU support and PCI bus ID matching
for improved hardware compatibility.
- Improved encoder selection by runtime testing of video encoders for
better reliability.
  - Added WebSocket transport support in peer-to-peer networking.
- Added flexible compositor and application launching with configurable
commands and improved socket handling.

- **Bug Fixes**
- Addressed NVIDIA-specific GStreamer issues by setting new environment
variables.
  - Improved error handling and logging for GPU and encoder selection.
- Fixed process monitoring to handle patcher restarts and added cleanup
logic.
- Added GStreamer cache clearing workaround for Wayland socket failures.

- **Improvements**
- Real-time logging of container processes to standard output and error
for easier monitoring.
- Enhanced process management and reduced CPU usage in protocol handling
loops.
- Updated dependency versions for greater stability and feature support.
  - Improved audio capture defaults and expanded audio pipeline support.
- Enhanced video pipeline setup with conditional handling for different
encoder APIs and DMA-BUF support.
- Refined concurrency and lifecycle management in protocol messaging for
increased robustness.
- Consistent namespace usage and updated crate references across the
codebase.
- Enhanced SSH configuration with key management, port customization,
and startup verification.
  - Improved GPU and video encoder integration in pipeline construction.
- Simplified error handling and consolidated write operations in
protocol streams.
- Removed Ludusavi installation from container image and updated package
installations.

- **Other**
- Minor formatting and style changes for better code readability and
maintainability.
- Docker build context now ignores `.idea` directory to streamline
builds.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
Author: Kristian Ollikainen
Date: 2025-07-07 09:06:48 +03:00
Committed by: GitHub
Parent: 191c59d230
Commit: 41dca22d9d
21 changed files with 2049 additions and 641 deletions

View File

@@ -1,3 +1,4 @@
**/target/
**/.git
**/.env
**/.idea

View File

@@ -86,7 +86,7 @@ RUN --mount=type=cache,target=/var/cache/pacman/pkg \
libxkbcommon wayland gstreamer gst-plugins-base gst-plugins-good libinput
# Clone repository
RUN git clone -b dev-dmabuf https://github.com/DatCaptainHorse/gst-wayland-display.git
RUN git clone --depth 1 -b "dev-dmabuf" https://github.com/DatCaptainHorse/gst-wayland-display.git
#--------------------------------------------------------------------
FROM gst-wayland-deps AS gst-wayland-planner
@@ -135,11 +135,12 @@ RUN --mount=type=cache,target=/var/cache/pacman/pkg \
vulkan-intel lib32-vulkan-intel vpl-gpu-rt \
vulkan-radeon lib32-vulkan-radeon \
mesa \
steam steam-native-runtime gtk3 lib32-gtk3 \
sudo xorg-xwayland seatd libinput gamescope mangohud \
steam steam-native-runtime proton-cachyos gtk3 lib32-gtk3 \
sudo xorg-xwayland seatd libinput gamescope mangohud wlr-randr \
libssh2 curl wget \
pipewire pipewire-pulse pipewire-alsa wireplumber \
noto-fonts-cjk supervisor jq chwd lshw pacman-contrib && \
noto-fonts-cjk supervisor jq chwd lshw pacman-contrib \
openssh && \
# GStreamer stack
pacman -Sy --needed --noconfirm \
gstreamer gst-plugins-base gst-plugins-good \
@@ -153,14 +154,6 @@ RUN --mount=type=cache,target=/var/cache/pacman/pkg \
paccache -rk1 && \
rm -rf /usr/share/{info,man,doc}/*
### Application Installation ###
ARG LUDUSAVI_VERSION="0.28.0"
RUN curl -fsSL -o ludusavi.tar.gz \
"https://github.com/mtkennerly/ludusavi/releases/download/v${LUDUSAVI_VERSION}/ludusavi-v${LUDUSAVI_VERSION}-linux.tar.gz" && \
tar -xzvf ludusavi.tar.gz && \
mv ludusavi /usr/bin/ && \
rm ludusavi.tar.gz
### User Configuration ###
ENV USER="nestri" \
UID=1000 \

View File

@@ -0,0 +1,10 @@
#!/bin/bash
while true; do
case "$1" in
--*) shift ;;
*) break ;;
esac
done
exec "$@"

View File

@@ -13,7 +13,7 @@ log() {
# Ensures user directory ownership
chown_user_directory() {
local user_group="${USER}:${GID}"
if ! chown -R -h --no-preserve-root "$user_group" "${HOME}" 2>/dev/null; then
if ! chown -h --no-preserve-root "$user_group" "${HOME}" 2>/dev/null; then
echo "Error: Failed to change ownership of ${HOME} to ${user_group}" >&2
return 1
fi
@@ -36,6 +36,12 @@ wait_for_socket() {
return 1
}
# Prepares environment for namespace-less applications (like Steam)
setup_namespaceless() {
rm -f /run/systemd/container || true
mkdir -p /run/pressure-vessel || true
}
# Ensures cache directory exists
setup_cache() {
log "Setting up NVIDIA driver cache directory at $CACHE_DIR..."
@@ -103,8 +109,10 @@ get_nvidia_installer() {
install_nvidia_driver() {
local filename="$1"
log "Installing NVIDIA driver components from $filename..."
sudo ./"$filename" \
bash ./"$filename" \
--silent \
--skip-depmod \
--skip-module-unload \
--no-kernel-module \
--install-compat32-libs \
--no-nouveau-check \
@@ -116,18 +124,102 @@ install_nvidia_driver() {
log "Error: NVIDIA driver installation failed."
return 1
}
# Install CUDA package
log "Checking if CUDA is already installed"
if ! pacman -Q cuda &>/dev/null; then
log "Installing CUDA package"
pacman -S --noconfirm cuda --assume-installed opencl-nvidia
else
log "CUDA package is already installed, skipping"
fi
log "NVIDIA driver installation completed."
return 0
}
function log_gpu_info {
log_gpu_info() {
if ! declare -p vendor_devices &>/dev/null; then
log "Warning: vendor_devices array is not defined"
return
fi
log "Detected GPUs:"
for vendor in "${!vendor_devices[@]}"; do
log "> $vendor: ${vendor_devices[$vendor]}"
done
}
configure_ssh() {
# Return early if SSH not enabled
if [ -z "${SSH_ENABLE_PORT+x}" ] || [ "${SSH_ENABLE_PORT:-0}" -eq 0 ]; then
return 0
fi
# Check if we have required key
if [ -z "${SSH_ALLOWED_KEY+x}" ] || [ -z "${SSH_ALLOWED_KEY}" ]; then
return 0
fi
log "Configuring SSH server on port ${SSH_ENABLE_PORT} with public key authentication"
# Ensure SSH host keys exist
ssh-keygen -A 2>/dev/null || {
log "Error: Failed to generate SSH host keys"
return 1
}
# Create .ssh directory and authorized_keys file for nestri user
mkdir -p /home/nestri/.ssh
echo "${SSH_ALLOWED_KEY}" > /home/nestri/.ssh/authorized_keys
chmod 700 /home/nestri/.ssh
chmod 600 /home/nestri/.ssh/authorized_keys
chown -R nestri:nestri /home/nestri/.ssh
# Update SSHD config
sed -i -E "s/^#?Port .*/Port ${SSH_ENABLE_PORT}/" /etc/ssh/sshd_config || {
log "Error: Failed to update SSH port configuration"
return 1
}
# Configure secure SSH settings
{
echo "PasswordAuthentication no"
echo "PermitRootLogin no"
echo "ChallengeResponseAuthentication no"
echo "UsePAM no"
echo "PubkeyAuthentication yes"
} | while read -r line; do
grep -qF "$line" /etc/ssh/sshd_config || echo "$line" >> /etc/ssh/sshd_config
done
# Start SSH server
log "Starting SSH server on port ${SSH_ENABLE_PORT}"
/usr/sbin/sshd -D -p "${SSH_ENABLE_PORT}" &
SSH_PID=$!
# Verify the process started
if ! ps -p $SSH_PID > /dev/null 2>&1; then
log "Error: SSH server failed to start"
return 1
fi
log "SSH server started with PID ${SSH_PID}"
return 0
}
main() {
# Configure SSH
if [ -n "${SSH_ENABLE_PORT+x}" ] && [ "${SSH_ENABLE_PORT:-0}" -ne 0 ] && \
[ -n "${SSH_ALLOWED_KEY+x}" ] && [ -n "${SSH_ALLOWED_KEY}" ]; then
if ! configure_ssh; then
log "Error: SSH configuration failed with given variables - exiting"
exit 1
fi
else
log "SSH not configured (missing SSH_ENABLE_PORT or SSH_ALLOWED_KEY)"
fi
# Wait for required sockets
wait_for_socket "/run/dbus/system_bus_socket" "DBus" || exit 1
wait_for_socket "/run/user/${UID}/pipewire-0" "PipeWire" || exit 1
@@ -215,6 +307,10 @@ main() {
log "Ensuring user directory permissions..."
chown_user_directory || exit 1
# Setup namespaceless env
log "Applying namespace-less configuration"
setup_namespaceless
# Switch to nestri user
log "Switching to nestri user for application startup..."
if [[ ! -x /etc/nestri/entrypoint_nestri.sh ]]; then
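Assuming a container started with `SSH_ENABLE_PORT` and `SSH_ALLOWED_KEY` as above, connecting looks roughly like this (host and port are placeholders; `nestri` is the user whose `authorized_keys` is populated by `configure_ssh`):

```bash
# Sketch only - replace host, port and key with your own values.
ssh -i ~/.ssh/id_ed25519 -p 2222 nestri@<container-host>
```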

View File

@@ -1,5 +1,4 @@
#!/bin/bash
set -euo pipefail
log() {
echo "[$(date +'%Y-%m-%d %H:%M:%S')] $1"
@@ -50,6 +49,70 @@ kill_if_running() {
fi
}
# Starts up Steam namespace-less live-patcher
start_steam_namespaceless_patcher() {
kill_if_running "${PATCHER_PID:-}" "steam-patcher"
local entrypoints=(
"${HOME}/.local/share/Steam/steamrt64/steam-runtime-steamrt/_v2-entry-point"
"${HOME}/.local/share/Steam/steamapps/common/SteamLinuxRuntime_soldier/_v2-entry-point"
"${HOME}/.local/share/Steam/steamapps/common/SteamLinuxRuntime_sniper/_v2-entry-point"
# < Add more entrypoints here if needed >
)
local custom_entrypoint="/etc/nestri/_v2-entry-point"
local temp_entrypoint="/tmp/_v2-entry-point.padded"
if [[ ! -f "$custom_entrypoint" ]]; then
log "Error: Custom _v2-entry-point not found at $custom_entrypoint"
exit 1
fi
log "Starting Steam _v2-entry-point patcher..."
(
while true; do
for i in "${!entrypoints[@]}"; do
local steam_entrypoint="${entrypoints[$i]}"
if [[ -f "$steam_entrypoint" ]]; then
# Get original file size
local original_size
original_size=$(stat -c %s "$steam_entrypoint" 2>/dev/null)
if [[ -z "$original_size" ]] || [[ "$original_size" -eq 0 ]]; then
log "Warning: Could not determine size of $steam_entrypoint, retrying..."
continue
fi
# Copy custom entrypoint to temp location
cp "$custom_entrypoint" "$temp_entrypoint" 2>/dev/null || {
log "Warning: Failed to copy custom entrypoint to $temp_entrypoint"
continue
}
# Pad the temporary file to match original size
if (( $(stat -c %s "$temp_entrypoint") < original_size )); then
truncate -s "$original_size" "$temp_entrypoint" 2>/dev/null || {
log "Warning: Failed to pad $temp_entrypoint to $original_size bytes"
continue
}
fi
# Copy padded file to Steam's entrypoint, if contents differ
if ! cmp -s "$temp_entrypoint" "$steam_entrypoint"; then
cp "$temp_entrypoint" "$steam_entrypoint" 2>/dev/null || {
log "Warning: Failed to patch $steam_entrypoint"
}
fi
fi
done
# Sleep for 1s
sleep 1
done
) &
PATCHER_PID=$!
log "Steam _v2-entry-point patcher started (PID: $PATCHER_PID)"
}
# Starts nestri-server
start_nestri_server() {
kill_if_running "${NESTRI_PID:-}" "nestri-server"
@@ -71,33 +134,93 @@ start_nestri_server() {
done
log "Error: Wayland display 'wayland-1' not available."
# Workaround for gstreamer being bit slow at times
log "Clearing gstreamer cache.."
rm -rf "${HOME}/.cache/gstreamer-1.0" 2>/dev/null || true
increment_retry "nestri-server"
restart_chain
}
# Starts compositor (gamescope) with Steam
# Starts compositor with optional application
start_compositor() {
kill_if_running "${COMPOSITOR_PID:-}" "compositor"
kill_if_running "${APP_PID:-}" "application"
log "Starting compositor with Steam..."
rm -rf /tmp/.X11-unix && mkdir -p /tmp/.X11-unix && chown nestri:nestri /tmp/.X11-unix
WAYLAND_DISPLAY=wayland-1 gamescope --backend wayland -g -f -e --rt --mangoapp -W "${WIDTH}" -H "${HEIGHT}" -- steam-native -tenfoot -cef-force-gpu &
COMPOSITOR_PID=$!
# Set default values only if variables are unset (not empty)
if [[ -z "${NESTRI_LAUNCH_CMD+x}" ]]; then
NESTRI_LAUNCH_CMD="steam-native -tenfoot -cef-force-gpu"
fi
if [[ -z "${NESTRI_LAUNCH_COMPOSITOR+x}" ]]; then
NESTRI_LAUNCH_COMPOSITOR="gamescope --backend wayland --force-grab-cursor -g -f --rt --mangoapp -W ${WIDTH} -H ${HEIGHT} -r ${FRAMERATE:-60}"
fi
log "Waiting for compositor to initialize..."
COMPOSITOR_SOCKET="${XDG_RUNTIME_DIR}/gamescope-0"
for ((i=1; i<=15; i++)); do
if [[ -e "$COMPOSITOR_SOCKET" ]]; then
log "Compositor initialized, gamescope-0 ready."
sleep 2
return
# Start Steam patcher only if Steam command is present
if [[ -n "${NESTRI_LAUNCH_CMD}" ]] && [[ "$NESTRI_LAUNCH_CMD" == *"steam"* ]]; then
start_steam_namespaceless_patcher
fi
# Launch compositor if configured
if [[ -n "${NESTRI_LAUNCH_COMPOSITOR}" ]]; then
local compositor_cmd="$NESTRI_LAUNCH_COMPOSITOR"
local is_gamescope=false
# Check if this is a gamescope command
if [[ "$compositor_cmd" == *"gamescope"* ]]; then
is_gamescope=true
# Append application command for gamescope if needed
if [[ -n "$NESTRI_LAUNCH_CMD" ]] && [[ "$compositor_cmd" != *" -- "* ]]; then
# If steam in launch command, enable gamescope integration via -e
if [[ "$NESTRI_LAUNCH_CMD" == *"steam"* ]]; then
compositor_cmd+=" -e"
fi
compositor_cmd+=" -- $NESTRI_LAUNCH_CMD"
fi
fi
sleep 1
done
log "Error: Compositor did not initialize."
increment_retry "compositor"
start_compositor
log "Starting compositor: $compositor_cmd"
WAYLAND_DISPLAY=wayland-1 /bin/bash -c "$compositor_cmd" &
COMPOSITOR_PID=$!
# Wait for appropriate socket based on compositor type
if $is_gamescope; then
COMPOSITOR_SOCKET="${XDG_RUNTIME_DIR}/gamescope-0"
log "Waiting for gamescope socket..."
else
COMPOSITOR_SOCKET="${XDG_RUNTIME_DIR}/wayland-0"
log "Waiting for wayland-0 socket..."
fi
for ((i=1; i<=15; i++)); do
if [[ -e "$COMPOSITOR_SOCKET" ]]; then
log "Compositor socket ready ($COMPOSITOR_SOCKET)."
# Patch resolution with wlr-randr for non-gamescope compositors
if ! $is_gamescope; then
local OUTPUT_NAME
OUTPUT_NAME=$(WAYLAND_DISPLAY=wayland-0 wlr-randr --json | jq -r '.[] | select(.enabled == true) | .name' | head -n 1)
if [ -z "$OUTPUT_NAME" ]; then
log "Warning: No enabled outputs detected. Skipping wlr-randr resolution patch."
return
fi
WAYLAND_DISPLAY=wayland-0 wlr-randr --output "$OUTPUT_NAME" --custom-mode "$WIDTH"x"$HEIGHT"
log "Patched resolution with wlr-randr."
fi
return
fi
sleep 1
done
log "Warning: Compositor socket not found after 15 seconds ($COMPOSITOR_SOCKET)."
else
# Launch standalone application if no compositor
if [[ -n "${NESTRI_LAUNCH_CMD}" ]]; then
log "Starting application: $NESTRI_LAUNCH_CMD"
WAYLAND_DISPLAY=wayland-1 /bin/bash -c "$NESTRI_LAUNCH_CMD" &
APP_PID=$!
else
log "No compositor or application configured."
fi
fi
}
# Increments retry counter
@@ -113,21 +236,24 @@ increment_retry() {
# Restarts the chain
restart_chain() {
log "Restarting nestri-server and compositor..."
RETRY_COUNT=0
start_nestri_server
}
# Cleans up processes
cleanup() {
local exit_code=$?
log "Terminating processes..."
kill_if_running "${NESTRI_PID:-}" "nestri-server"
kill_if_running "${COMPOSITOR_PID:-}" "compositor"
exit 0
kill_if_running "${APP_PID:-}" "application"
kill_if_running "${PATCHER_PID:-}" "steam-patcher"
rm -f "/tmp/_v2-entry-point.padded" 2>/dev/null
exit $exit_code
}
# Monitor processes for unexpected exits
main_loop() {
trap cleanup SIGINT SIGTERM
trap cleanup SIGINT SIGTERM EXIT
while true; do
sleep 1
@@ -141,6 +267,16 @@ main_loop() {
log "compositor died."
increment_retry "compositor"
start_compositor
# Check application
elif [[ -n "${APP_PID:-}" ]] && ! kill -0 "${APP_PID}" 2>/dev/null; then
log "application died."
increment_retry "application"
start_compositor
# Check patcher
elif [[ -n "${PATCHER_PID:-}" ]] && ! kill -0 "${PATCHER_PID}" 2>/dev/null; then
log "steam-patcher died."
increment_retry "steam-patcher"
start_steam_namespaceless_patcher
fi
done
}
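The compositor and application are now driven by two environment variables with Steam/gamescope defaults; a sketch of overriding them (the replacement values are purely illustrative):

```bash
# Defaults (set by start_compositor when the variables are unset):
#   NESTRI_LAUNCH_CMD="steam-native -tenfoot -cef-force-gpu"
#   NESTRI_LAUNCH_COMPOSITOR="gamescope --backend wayland ... -W $WIDTH -H $HEIGHT -r ${FRAMERATE:-60}"
# Keep gamescope but tweak its arguments (" -- $NESTRI_LAUNCH_CMD" is appended automatically):
export NESTRI_LAUNCH_COMPOSITOR="gamescope --backend wayland -f -W 2560 -H 1440 -r 120"
# Or set it to an empty string to skip the compositor and run an app directly on wayland-1:
export NESTRI_LAUNCH_COMPOSITOR=""
export NESTRI_LAUNCH_CMD="my-wayland-app --some-flag"
```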

View File

@@ -1,5 +1,4 @@
#!/bin/bash
set -euo pipefail
export XDG_RUNTIME_DIR=/run/user/${UID}/
export XDG_SESSION_TYPE=x11
@@ -11,3 +10,7 @@ export PROTON_NO_FSYNC=1
# Sleeker Mangohud preset :)
export MANGOHUD_CONFIG=preset=2
# Make gstreamer GL elements work without display output (NVIDIA issue..)
export GST_GL_API=gles2
export GST_GL_WINDOW=surfaceless

View File

@@ -54,3 +54,6 @@ autorestart=false
autostart=true
startretries=0
priority=10
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true

File diff suppressed because it is too large

View File

@@ -8,25 +8,26 @@ name = "nestri-server"
path = "src/main.rs"
[dependencies]
gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_26"] }
gst-webrtc = { package = "gstreamer-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_26"] }
gstrswebrtc = { package = "gst-plugin-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs", branch = "main" }
gstreamer = { version = "0.23", features = ["v1_26"] }
gstreamer-webrtc = { version = "0.23", features = ["v1_26"] }
gst-plugin-webrtc = { version = "0.13", features = ["v1_22"] }
serde = {version = "1.0", features = ["derive"] }
tokio = { version = "1.44", features = ["full"] }
clap = { version = "4.5", features = ["env"] }
tokio = { version = "1.45", features = ["full"] }
tokio-stream = { version = "0.1", features = ["full"] }
clap = { version = "4.5", features = ["env", "derive"] }
serde_json = "1.0"
webrtc = "0.13"
regex = "1.11"
rand = "0.9"
rustls = { version = "0.23", features = ["ring"] }
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
chrono = "0.4"
futures-util = "0.3"
prost = "0.13"
prost-types = "0.13"
prost = "0.14"
prost-types = "0.14"
parking_lot = "0.12"
atomic_refcell = "0.1"
byteorder = "1.5"
libp2p = { version = "0.55", features = ["identify", "dns", "tcp", "noise", "ping", "tokio", "serde", "yamux", "macros"] }
libp2p-stream = "0.3.0-alpha"
libp2p = { version = "0.55", features = ["identify", "dns", "tcp", "noise", "ping", "tokio", "serde", "yamux", "macros", "websocket", "autonat"] }
libp2p-stream = { version = "0.3.0-alpha" }
dashmap = "6.1"

View File

@@ -1,7 +1,7 @@
use crate::args::encoding_args::AudioCaptureMethod;
use crate::enc_helper::{AudioCodec, EncoderType, VideoCodec};
use clap::{Arg, Command, value_parser};
use clap::builder::{BoolishValueParser, NonEmptyStringValueParser};
use clap::{Arg, Command, value_parser};
pub mod app_args;
pub mod device_args;
@@ -89,7 +89,7 @@ impl Args {
.env("GPU_INDEX")
.help("GPU to use by index")
.value_parser(value_parser!(i32).range(-1..))
.default_value("-1")
.default_value("-1"),
)
.arg(
Arg::new("gpu-card-path")
@@ -160,7 +160,7 @@ impl Args {
.env("AUDIO_CAPTURE_METHOD")
.help("Audio capture method")
.value_parser(value_parser!(AudioCaptureMethod))
.default_value("pipewire"),
.default_value("pulseaudio"),
)
.arg(
Arg::new("audio-codec")

View File

@@ -55,7 +55,11 @@ impl AppArgs {
tracing::info!("AppArgs:");
tracing::info!("> verbose: {}", self.verbose);
tracing::info!("> debug: {}", self.debug);
tracing::info!("> resolution: '{}x{}'", self.resolution.0, self.resolution.1);
tracing::info!(
"> resolution: '{}x{}'",
self.resolution.0,
self.resolution.1
);
tracing::info!("> framerate: {}", self.framerate);
tracing::info!("> relay_url: '{}'", self.relay_url);
tracing::info!("> room: '{}'", self.room);

View File

@@ -19,10 +19,7 @@ impl DeviceArgs {
.get_one::<String>("gpu-name")
.unwrap_or(&"".to_string())
.clone(),
gpu_index: matches
.get_one::<i32>("gpu-index")
.unwrap_or(&-1)
.clone(),
gpu_index: matches.get_one::<i32>("gpu-index").unwrap_or(&-1).clone(),
gpu_card_path: matches
.get_one::<String>("gpu-card-path")
.unwrap_or(&"".to_string())

View File

@@ -120,20 +120,11 @@ impl VideoEncodingOptions {
.unwrap(),
}),
RateControlMethod::CBR => RateControl::CBR(RateControlCBR {
target_bitrate: matches
.get_one::<u32>("video-bitrate")
.unwrap()
.clone(),
target_bitrate: matches.get_one::<u32>("video-bitrate").unwrap().clone(),
}),
RateControlMethod::VBR => RateControl::VBR(RateControlVBR {
target_bitrate: matches
.get_one::<u32>("video-bitrate")
.unwrap()
.clone(),
max_bitrate: matches
.get_one::<u32>("video-bitrate-max")
.unwrap()
.clone(),
target_bitrate: matches.get_one::<u32>("video-bitrate").unwrap().clone(),
max_bitrate: matches.get_one::<u32>("video-bitrate-max").unwrap().clone(),
}),
},
},
@@ -209,20 +200,11 @@ impl AudioEncodingOptions {
.unwrap_or(&RateControlMethod::CBR)
{
RateControlMethod::CBR => RateControl::CBR(RateControlCBR {
target_bitrate: matches
.get_one::<u32>("audio-bitrate")
.unwrap()
.clone(),
target_bitrate: matches.get_one::<u32>("audio-bitrate").unwrap().clone(),
}),
RateControlMethod::VBR => RateControl::VBR(RateControlVBR {
target_bitrate: matches
.get_one::<u32>("audio-bitrate")
.unwrap()
.clone(),
max_bitrate: matches
.get_one::<u32>("audio-bitrate-max")
.unwrap()
.clone(),
target_bitrate: matches.get_one::<u32>("audio-bitrate").unwrap().clone(),
max_bitrate: matches.get_one::<u32>("audio-bitrate-max").unwrap().clone(),
}),
wot => panic!("Invalid rate control method for audio: {}", wot.as_str()),
},

View File

@@ -1,7 +1,7 @@
use crate::args::encoding_args::RateControl;
use crate::gpu::{self, GPUInfo, get_gpu_by_card_path, get_gpus_by_vendor};
use crate::gpu::{GPUInfo, get_gpu_by_card_path, get_gpus_by_vendor, get_nvidia_gpu_by_cuda_id};
use clap::ValueEnum;
use gst::prelude::*;
use gstreamer::prelude::*;
use std::error::Error;
use std::str::FromStr;
@@ -107,7 +107,7 @@ impl EncoderType {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct VideoEncoderInfo {
pub name: String,
pub codec: VideoCodec,
@@ -146,9 +146,9 @@ impl VideoEncoderInfo {
self.parameters.push((key.into(), value.into()));
}
pub fn apply_parameters(&self, element: &gst::Element, verbose: bool) {
pub fn apply_parameters(&self, element: &gstreamer::Element, verbose: bool) {
for (key, value) in &self.parameters {
if element.has_property(key) {
if element.has_property(key, None) {
if verbose {
tracing::debug!("Setting property {} to {}", key, value);
}
@@ -191,7 +191,7 @@ where
F: FnMut(&str) -> Option<(String, String)>,
{
let mut encoder_optz = encoder.clone();
let element = match gst::ElementFactory::make(&encoder_optz.name).build() {
let element = match gstreamer::ElementFactory::make(&encoder_optz.name).build() {
Ok(e) => e,
Err(_) => return encoder_optz, // Return original if element creation fails
};
@@ -329,16 +329,15 @@ pub fn encoder_low_latency_params(
encoder_optz
}
pub fn get_compatible_encoders() -> Vec<VideoEncoderInfo> {
pub fn get_compatible_encoders(gpus: &Vec<GPUInfo>) -> Vec<VideoEncoderInfo> {
let mut encoders = Vec::new();
let registry = gst::Registry::get();
let gpus = gpu::get_gpus();
let registry = gstreamer::Registry::get();
for plugin in registry.plugins() {
for feature in registry.features_by_plugin(plugin.plugin_name().as_str()) {
let encoder_name = feature.name();
let factory = match gst::ElementFactory::find(encoder_name.as_str()) {
let factory = match gstreamer::ElementFactory::find(encoder_name.as_str()) {
Some(f) => f,
None => continue,
};
@@ -376,9 +375,9 @@ pub fn get_compatible_encoders() -> Vec<VideoEncoderInfo> {
match api {
EncoderAPI::QSV | EncoderAPI::VAAPI => {
// Safe property access with panic protection, gstreamer-rs is fun
let path = if element.has_property("device-path") {
let path = if element.has_property("device-path", None) {
Some(element.property::<String>("device-path"))
} else if element.has_property("device") {
} else if element.has_property("device", None) {
Some(element.property::<String>("device"))
} else {
None
@@ -386,13 +385,11 @@ pub fn get_compatible_encoders() -> Vec<VideoEncoderInfo> {
path.and_then(|p| get_gpu_by_card_path(&gpus, &p))
}
EncoderAPI::NVENC if element.has_property("cuda-device-id") => {
EncoderAPI::NVENC if element.has_property("cuda-device-id", None) => {
let cuda_id = element.property::<u32>("cuda-device-id");
get_gpus_by_vendor(&gpus, "nvidia")
.get(cuda_id as usize)
.cloned()
get_nvidia_gpu_by_cuda_id(&gpus, cuda_id as usize)
}
EncoderAPI::AMF if element.has_property("device") => {
EncoderAPI::AMF if element.has_property("device", None) => {
let device_id = element.property::<u32>("device");
get_gpus_by_vendor(&gpus, "amd")
.get(device_id as usize)
@@ -540,3 +537,140 @@ pub fn get_best_compatible_encoder(
Err("No compatible encoder found".into())
}
}
/// Returns the best compatible encoder that also passes test_encoder
pub fn get_best_working_encoder(
encoders: &Vec<VideoEncoderInfo>,
codec: &Codec,
encoder_type: &EncoderType,
dma_buf: bool,
) -> Result<VideoEncoderInfo, Box<dyn Error>> {
let mut candidates = get_encoders_by_videocodec(
encoders,
match codec {
Codec::Video(c) => c,
Codec::Audio(_) => {
return Err("Audio codec not supported for video encoder selection".into());
}
},
);
candidates = get_encoders_by_type(&candidates, encoder_type);
let mut tried = Vec::new();
while !candidates.is_empty() {
let best = get_best_compatible_encoder(&candidates, codec, encoder_type)?;
tracing::info!("Testing encoder: {}", best.name,);
if test_encoder(&best, dma_buf).is_ok() {
return Ok(best);
} else {
// Remove this encoder and try next best
candidates.retain(|e| e != &best);
tried.push(best.name.clone());
}
}
Err(format!("No working encoder found (tried: {:?})", tried).into())
}
/// Test if a pipeline with the given encoder can be created and set to Playing
pub fn test_encoder(encoder: &VideoEncoderInfo, dma_buf: bool) -> Result<(), Box<dyn Error>> {
let src = gstreamer::ElementFactory::make("waylanddisplaysrc").build()?;
if let Some(gpu_info) = &encoder.gpu_info {
src.set_property_from_str("render-node", gpu_info.render_path());
}
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let caps = gstreamer::Caps::from_str(&format!(
"{},width=1280,height=720,framerate=30/1{}",
if dma_buf {
"video/x-raw(memory:DMABuf)"
} else {
"video/x-raw"
},
if dma_buf { "" } else { ",format=RGBx" }
))?;
caps_filter.set_property("caps", &caps);
let enc = gstreamer::ElementFactory::make(&encoder.name).build()?;
let sink = gstreamer::ElementFactory::make("fakesink").build()?;
// Apply encoder parameters
encoder.apply_parameters(&enc, false);
// Create pipeline and link elements
let pipeline = gstreamer::Pipeline::new();
if dma_buf && encoder.encoder_api == EncoderAPI::NVENC {
// GL upload element
let glupload = gstreamer::ElementFactory::make("glupload").build()?;
// GL color convert element
let glconvert = gstreamer::ElementFactory::make("glcolorconvert").build()?;
// GL color convert caps
let gl_caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let gl_caps = gstreamer::Caps::from_str("video/x-raw(memory:GLMemory),format=NV12")?;
gl_caps_filter.set_property("caps", &gl_caps);
// CUDA upload element
let cudaupload = gstreamer::ElementFactory::make("cudaupload").build()?;
pipeline.add_many(&[
&src,
&caps_filter,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&enc,
&sink,
])?;
gstreamer::Element::link_many(&[
&src,
&caps_filter,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&enc,
&sink,
])?;
} else {
let vapostproc = gstreamer::ElementFactory::make("vapostproc").build()?;
// VA caps filter
let va_caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let va_caps = gstreamer::Caps::from_str("video/x-raw(memory:VAMemory),format=NV12")?;
va_caps_filter.set_property("caps", &va_caps);
pipeline.add_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
gstreamer::Element::link_many(&[
&src,
&caps_filter,
&vapostproc,
&va_caps_filter,
&enc,
&sink,
])?;
}
let bus = pipeline.bus().ok_or("Pipeline has no bus")?;
let _ = pipeline.set_state(gstreamer::State::Playing);
for msg in bus.iter_timed(gstreamer::ClockTime::from_seconds(2)) {
match msg.view() {
gstreamer::MessageView::Error(err) => {
let err_msg = format!("Pipeline error: {}", err.error());
tracing::error!("Pipeline error, encoder test failed: {}", err_msg);
let _ = pipeline.set_state(gstreamer::State::Null);
return Err(err_msg.into());
}
gstreamer::MessageView::Eos(_) => {
tracing::info!("Pipeline EOS received");
let _ = pipeline.set_state(gstreamer::State::Null);
return Err("Pipeline EOS received, encoder test failed".into());
}
_ => {}
}
}
let _ = pipeline.set_state(gstreamer::State::Null);
Ok(())
}
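For reference, a rough `gst-launch-1.0` equivalent of the DMA-BUF + NVENC test pipeline that `test_encoder` builds above (`nvh264enc` is only an example; the actual encoder element comes from the candidate list):

```bash
# Sketch of the encoder smoke test for the NVENC/DMA-BUF branch.
gst-launch-1.0 waylanddisplaysrc \
  ! "video/x-raw(memory:DMABuf),width=1280,height=720,framerate=30/1" \
  ! glupload ! glcolorconvert ! "video/x-raw(memory:GLMemory),format=NV12" \
  ! cudaupload ! nvh264enc ! fakesink
```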

View File

@@ -3,7 +3,7 @@ use std::fs;
use std::process::Command;
use std::str;
#[derive(Debug, Eq, PartialEq, Clone)]
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum GPUVendor {
UNKNOWN,
INTEL,
@@ -11,12 +11,13 @@ pub enum GPUVendor {
AMD,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct GPUInfo {
vendor: GPUVendor,
card_path: String,
render_path: String,
device_name: String,
pci_bus_id: String,
}
impl GPUInfo {
@@ -44,6 +45,10 @@ impl GPUInfo {
pub fn device_name(&self) -> &str {
&self.device_name
}
pub fn pci_bus_id(&self) -> &str {
&self.pci_bus_id
}
}
fn get_gpu_vendor(vendor_id: &str) -> GPUVendor {
@@ -71,14 +76,17 @@ pub fn get_gpus() -> Vec<GPUInfo> {
.filter(|(class_id, _, _, _)| matches!(class_id.as_str(), "0300" | "0302" | "0380"))
.filter_map(|(_, vendor_id, device_name, pci_addr)| {
get_dri_device_path(&pci_addr)
.map(|(card, render)| (vendor_id, card, render, device_name))
})
.map(|(vid, card_path, render_path, device_name)| GPUInfo {
vendor: get_gpu_vendor(&vid),
card_path,
render_path,
device_name,
.map(|(card, render)| (vendor_id, card, render, device_name, pci_addr))
})
.map(
|(vid, card_path, render_path, device_name, pci_bus_id)| GPUInfo {
vendor: get_gpu_vendor(&vid),
card_path,
render_path,
device_name,
pci_bus_id,
},
)
.collect()
}
@@ -137,7 +145,6 @@ fn get_dri_device_path(pci_addr: &str) -> Option<(String, String)> {
None
}
// Helper functions remain similar with improved readability:
pub fn get_gpus_by_vendor(gpus: &[GPUInfo], vendor: &str) -> Vec<GPUInfo> {
let target = vendor.to_lowercase();
gpus.iter()
@@ -162,10 +169,42 @@ pub fn get_gpu_by_card_path(gpus: &[GPUInfo], path: &str) -> Option<GPUInfo> {
.cloned()
}
pub fn get_gpu_by_index(gpus: &[GPUInfo], index: i32) -> Option<GPUInfo> {
if index < 0 || index as usize >= gpus.len() {
None
} else {
Some(gpus[index as usize].clone())
pub fn get_nvidia_gpu_by_cuda_id(gpus: &[GPUInfo], cuda_device_id: usize) -> Option<GPUInfo> {
// Check if nvidia-smi is available
if Command::new("nvidia-smi").arg("--help").output().is_err() {
tracing::warn!("nvidia-smi is not available");
return None;
}
// Run nvidia-smi to get information about the CUDA device
let output = Command::new("nvidia-smi")
.args([
"--query-gpu=pci.bus_id",
"--format=csv,noheader",
"-i",
&cuda_device_id.to_string(),
])
.output()
.ok()?;
if !output.status.success() {
return None;
}
// Parse the output to get the PCI bus ID
let pci_bus_id = str::from_utf8(&output.stdout).ok()?.trim().to_uppercase(); // nvidia-smi returns uppercase PCI IDs
// Convert from 00000000:05:00.0 to 05:00.0 if needed
let pci_bus_id = if pci_bus_id.starts_with("00000000:") {
pci_bus_id[9..].to_string() // Skip the domain part
} else if pci_bus_id.starts_with("0000:") {
pci_bus_id[5..].to_string() // Alternate check for older nvidia-smi versions
} else {
pci_bus_id
};
// Find the GPU with the matching PCI bus ID
gpus.iter()
.find(|gpu| gpu.vendor == GPUVendor::NVIDIA && gpu.pci_bus_id.to_uppercase() == pci_bus_id)
.cloned()
}
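The new CUDA-id-to-GPU mapping shells out to `nvidia-smi`; a sketch of the query it issues and the normalization it performs (the bus ID shown is made up):

```bash
# Same arguments as get_nvidia_gpu_by_cuda_id uses, here for CUDA device 0:
nvidia-smi --query-gpu=pci.bus_id --format=csv,noheader -i 0
# Example output (illustrative): 00000000:05:00.0
# The leading PCI domain is stripped, leaving "05:00.0" to compare against
# the pci_bus_id recorded in each GPUInfo.
```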

View File

@@ -8,24 +8,24 @@ mod p2p;
mod proto;
use crate::args::encoding_args;
use crate::enc_helper::EncoderType;
use crate::gpu::GPUVendor;
use crate::enc_helper::{EncoderAPI, EncoderType};
use crate::gpu::{GPUInfo, GPUVendor};
use crate::nestrisink::NestriSignaller;
use crate::p2p::p2p::NestriP2P;
use futures_util::StreamExt;
use gst::prelude::*;
use gstreamer::prelude::*;
use gstrswebrtc::signaller::Signallable;
use gstrswebrtc::webrtcsink::BaseWebRTCSink;
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;
use tokio_stream::StreamExt;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
// Handles gathering GPU information and selecting the most suitable GPU
fn handle_gpus(args: &args::Args) -> Result<gpu::GPUInfo, Box<dyn Error>> {
fn handle_gpus(args: &args::Args) -> Result<Vec<gpu::GPUInfo>, Box<dyn Error>> {
tracing::info!("Gathering GPU information..");
let gpus = gpu::get_gpus();
let mut gpus = gpu::get_gpus();
if gpus.is_empty() {
return Err("No GPUs found".into());
}
@@ -40,10 +40,11 @@ fn handle_gpus(args: &args::Args) -> Result<gpu::GPUInfo, Box<dyn Error>> {
);
}
// Based on available arguments, pick a GPU
let gpu;
// Additional GPU filtering
if !args.device.gpu_card_path.is_empty() {
gpu = gpu::get_gpu_by_card_path(&gpus, &args.device.gpu_card_path);
if let Some(gpu) = gpu::get_gpu_by_card_path(&gpus, &args.device.gpu_card_path) {
return Ok(Vec::from([gpu]));
}
} else {
// Run all filters that are not empty
let mut filtered_gpus = gpus.clone();
@@ -55,35 +56,43 @@ fn handle_gpus(args: &args::Args) -> Result<gpu::GPUInfo, Box<dyn Error>> {
}
if args.device.gpu_index > -1 {
// get single GPU by index
gpu = gpu::get_gpu_by_index(&filtered_gpus, args.device.gpu_index).or_else(|| {
tracing::warn!("GPU index {} is out of range", args.device.gpu_index);
None
});
let gpu_index = args.device.gpu_index as usize;
if gpu_index >= filtered_gpus.len() {
return Err(format!(
"GPU index {} is out of bounds for available GPUs (0-{})",
gpu_index,
filtered_gpus.len() - 1
)
.into());
}
gpus = Vec::from([filtered_gpus[gpu_index].clone()]);
} else {
// get first GPU
gpu = filtered_gpus
// Filter out unknown vendor GPUs
gpus = filtered_gpus
.into_iter()
.find(|g| *g.vendor() != GPUVendor::UNKNOWN);
.filter(|gpu| *gpu.vendor() != GPUVendor::UNKNOWN)
.collect();
}
}
if gpu.is_none() {
if gpus.is_empty() {
return Err(format!(
"No GPU found with the specified parameters: vendor='{}', name='{}', index='{}', card_path='{}'",
"No GPU(s) found with the specified parameters: vendor='{}', name='{}', index='{}', card_path='{}'",
args.device.gpu_vendor,
args.device.gpu_name,
args.device.gpu_index,
args.device.gpu_card_path
).into());
}
let gpu = gpu.unwrap();
tracing::info!("Selected GPU: '{}'", gpu.device_name());
Ok(gpu)
Ok(gpus)
}
// Handles picking video encoder
fn handle_encoder_video(args: &args::Args) -> Result<enc_helper::VideoEncoderInfo, Box<dyn Error>> {
fn handle_encoder_video(
args: &args::Args,
gpus: &Vec<GPUInfo>,
) -> Result<enc_helper::VideoEncoderInfo, Box<dyn Error>> {
tracing::info!("Getting compatible video encoders..");
let video_encoders = enc_helper::get_compatible_encoders();
let video_encoders = enc_helper::get_compatible_encoders(gpus);
if video_encoders.is_empty() {
return Err("No compatible video encoders found".into());
}
@@ -107,10 +116,11 @@ fn handle_encoder_video(args: &args::Args) -> Result<enc_helper::VideoEncoderInf
video_encoder =
enc_helper::get_encoder_by_name(&video_encoders, &args.encoding.video.encoder)?;
} else {
video_encoder = enc_helper::get_best_compatible_encoder(
video_encoder = enc_helper::get_best_working_encoder(
&video_encoders,
&args.encoding.video.codec,
&args.encoding.video.encoder_type,
args.app.dma_buf,
)?;
}
tracing::info!("Selected video encoder: '{}'", video_encoder.name);
@@ -191,17 +201,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
let nestri_p2p = Arc::new(NestriP2P::new().await?);
let p2p_conn = nestri_p2p.connect(relay_url).await?;
gst::init()?;
gstrswebrtc::plugin_register_static()?;
// Handle GPU selection
let gpu = match handle_gpus(&args) {
Ok(gpu) => gpu,
Err(e) => {
tracing::error!("Failed to find a suitable GPU: {}", e);
return Err(e);
}
};
gstreamer::init()?;
let _ = gstrswebrtc::plugin_register_static(); // Might be already registered, so we'll pass..
if args.app.dma_buf {
if args.encoding.video.encoder_type != EncoderType::HARDWARE {
@@ -214,8 +215,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
}
// Handle GPU selection
let gpus = match handle_gpus(&args) {
Ok(gpu) => gpu,
Err(e) => {
tracing::error!("Failed to find a suitable GPU: {}", e);
return Err(e);
}
};
// Handle video encoder selection
let mut video_encoder_info = match handle_encoder_video(&args) {
let mut video_encoder_info = match handle_encoder_video(&args, &gpus) {
Ok(encoder) => encoder,
Err(e) => {
tracing::error!("Failed to find a suitable video encoder: {}", e);
@@ -231,33 +241,35 @@ async fn main() -> Result<(), Box<dyn Error>> {
/*** PIPELINE CREATION ***/
// Create the pipeline
let pipeline = Arc::new(gst::Pipeline::new());
let pipeline = Arc::new(gstreamer::Pipeline::new());
/* Audio */
// Audio Source Element
let audio_source = match args.encoding.audio.capture_method {
encoding_args::AudioCaptureMethod::PULSEAUDIO => {
gst::ElementFactory::make("pulsesrc").build()?
gstreamer::ElementFactory::make("pulsesrc").build()?
}
encoding_args::AudioCaptureMethod::PIPEWIRE => {
gst::ElementFactory::make("pipewiresrc").build()?
gstreamer::ElementFactory::make("pipewiresrc").build()?
}
encoding_args::AudioCaptureMethod::ALSA => {
gstreamer::ElementFactory::make("alsasrc").build()?
}
encoding_args::AudioCaptureMethod::ALSA => gst::ElementFactory::make("alsasrc").build()?,
};
// Audio Converter Element
let audio_converter = gst::ElementFactory::make("audioconvert").build()?;
let audio_converter = gstreamer::ElementFactory::make("audioconvert").build()?;
// Audio Rate Element
let audio_rate = gst::ElementFactory::make("audiorate").build()?;
let audio_rate = gstreamer::ElementFactory::make("audiorate").build()?;
// Required to fix gstreamer opus issue, where quality sounds off (due to wrong sample rate)
let audio_capsfilter = gst::ElementFactory::make("capsfilter").build()?;
let audio_caps = gst::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
let audio_capsfilter = gstreamer::ElementFactory::make("capsfilter").build()?;
let audio_caps = gstreamer::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
audio_capsfilter.set_property("caps", &audio_caps);
// Audio Encoder Element
let audio_encoder = gst::ElementFactory::make(audio_encoder.as_str()).build()?;
let audio_encoder = gstreamer::ElementFactory::make(audio_encoder.as_str()).build()?;
audio_encoder.set_property(
"bitrate",
&match &args.encoding.audio.rate_control {
@@ -267,18 +279,27 @@ async fn main() -> Result<(), Box<dyn Error>> {
},
);
// If has "frame-size" (opus), set to 10 for lower latency (below 10 seems to be too low?)
if audio_encoder.has_property("frame-size") {
if audio_encoder.has_property("frame-size", None) {
audio_encoder.set_property_from_str("frame-size", "10");
}
// Audio parse Element
let mut audio_parser = None;
if audio_encoder.name() == "opusenc" {
// Opus encoder requires a parser
audio_parser = Some(gstreamer::ElementFactory::make("opusparse").build()?);
}
/* Video */
// Video Source Element
let video_source = Arc::new(gst::ElementFactory::make("waylanddisplaysrc").build()?);
video_source.set_property_from_str("render-node", gpu.render_path());
let video_source = Arc::new(gstreamer::ElementFactory::make("waylanddisplaysrc").build()?);
if let Some(gpu_info) = &video_encoder_info.gpu_info {
video_source.set_property_from_str("render-node", gpu_info.render_path());
}
// Caps Filter Element (resolution, fps)
let caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let caps = gst::Caps::from_str(&format!(
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let caps = gstreamer::Caps::from_str(&format!(
"{},width={},height={},framerate={}/1{}",
if args.app.dma_buf {
"video/x-raw(memory:DMABuf)"
@@ -292,37 +313,70 @@ async fn main() -> Result<(), Box<dyn Error>> {
))?;
caps_filter.set_property("caps", &caps);
// GL Upload element
let glupload = gst::ElementFactory::make("glupload").build()?;
// GL and CUDA elements (NVIDIA only..)
let mut glupload = None;
let mut glconvert = None;
let mut gl_caps_filter = None;
let mut cudaupload = None;
if args.app.dma_buf && video_encoder_info.encoder_api == EncoderAPI::NVENC {
// GL upload element
glupload = Some(gstreamer::ElementFactory::make("glupload").build()?);
// GL color convert element
glconvert = Some(gstreamer::ElementFactory::make("glcolorconvert").build()?);
// GL color convert caps
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let gl_caps = gstreamer::Caps::from_str("video/x-raw(memory:GLMemory),format=NV12")?;
caps_filter.set_property("caps", &gl_caps);
gl_caps_filter = Some(caps_filter);
// CUDA upload element
cudaupload = Some(gstreamer::ElementFactory::make("cudaupload").build()?);
}
// GL color convert element
let glcolorconvert = gst::ElementFactory::make("glcolorconvert").build()?;
// GL upload caps filter
let gl_caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let gl_caps = gst::Caps::from_str("video/x-raw(memory:GLMemory),format=NV12")?;
gl_caps_filter.set_property("caps", &gl_caps);
// GL download element (needed only for DMA-BUF outside NVIDIA GPUs)
let gl_download = gst::ElementFactory::make("gldownload").build()?;
// vapostproc for VA compatible encoders
let mut vapostproc = None;
let mut va_caps_filter = None;
if video_encoder_info.encoder_api == EncoderAPI::VAAPI
|| video_encoder_info.encoder_api == EncoderAPI::QSV
{
vapostproc = Some(gstreamer::ElementFactory::make("vapostproc").build()?);
// VA caps filter
let caps_filter = gstreamer::ElementFactory::make("capsfilter").build()?;
let va_caps = gstreamer::Caps::from_str("video/x-raw(memory:VAMemory),format=NV12")?;
caps_filter.set_property("caps", &va_caps);
va_caps_filter = Some(caps_filter);
}
// Video Converter Element
let video_converter = gst::ElementFactory::make("videoconvert").build()?;
let mut video_converter = None;
if !args.app.dma_buf {
video_converter = Some(gstreamer::ElementFactory::make("videoconvert").build()?);
}
// Video Encoder Element
let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
let video_encoder =
gstreamer::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
video_encoder_info.apply_parameters(&video_encoder, args.app.verbose);
// Video parser Element, required for GStreamer 1.26 as it broke some things..
// Video parser Element
let video_parser;
if video_encoder_info.codec == enc_helper::VideoCodec::H264 {
video_parser = Some(
gst::ElementFactory::make("h264parse")
.property("config-interval", -1i32)
.build()?,
);
} else {
video_parser = None;
match video_encoder_info.codec {
enc_helper::VideoCodec::H264 => {
video_parser = Some(
gstreamer::ElementFactory::make("h264parse")
.property("config-interval", -1i32)
.build()?,
);
}
enc_helper::VideoCodec::H265 => {
video_parser = Some(
gstreamer::ElementFactory::make("h265parse")
.property("config-interval", -1i32)
.build()?,
);
}
_ => {
video_parser = None;
}
}
/* Output */
@@ -335,24 +389,24 @@ async fn main() -> Result<(), Box<dyn Error>> {
webrtcsink.set_property("do-retransmission", false);
/* Queues */
let video_queue = gst::ElementFactory::make("queue2")
let video_queue = gstreamer::ElementFactory::make("queue2")
.property("max-size-buffers", 3u32)
.property("max-size-time", 0u64)
.property("max-size-bytes", 0u32)
.build()?;
let audio_queue = gst::ElementFactory::make("queue2")
let audio_queue = gstreamer::ElementFactory::make("queue2")
.property("max-size-buffers", 3u32)
.property("max-size-time", 0u64)
.property("max-size-bytes", 0u32)
.build()?;
/* Clock Sync */
let video_clocksync = gst::ElementFactory::make("clocksync")
let video_clocksync = gstreamer::ElementFactory::make("clocksync")
.property("sync-to-first", true)
.build()?;
let audio_clocksync = gst::ElementFactory::make("clocksync")
let audio_clocksync = gstreamer::ElementFactory::make("clocksync")
.property("sync-to-first", true)
.build()?;
@@ -360,7 +414,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
pipeline.add_many(&[
webrtcsink.upcast_ref(),
&video_encoder,
&video_converter,
&caps_filter,
&video_queue,
&video_clocksync,
@@ -374,21 +427,35 @@ async fn main() -> Result<(), Box<dyn Error>> {
&audio_source,
])?;
if let Some(video_converter) = &video_converter {
pipeline.add(video_converter)?;
}
if let Some(parser) = &audio_parser {
pipeline.add(parser)?;
}
if let Some(parser) = &video_parser {
pipeline.add(parser)?;
}
// If DMA-BUF is enabled, add glupload, color conversion and caps filter
// If DMA-BUF..
if args.app.dma_buf {
if *gpu.vendor() == GPUVendor::NVIDIA {
pipeline.add_many(&[&glupload, &glcolorconvert, &gl_caps_filter])?;
// VA-API / QSV pipeline
if let (Some(vapostproc), Some(va_caps_filter)) = (&vapostproc, &va_caps_filter) {
pipeline.add_many(&[vapostproc, va_caps_filter])?;
} else {
pipeline.add_many(&[&glupload, &glcolorconvert, &gl_caps_filter, &gl_download])?;
// NVENC pipeline
if let (Some(glupload), Some(glconvert), Some(gl_caps_filter), Some(cudaupload)) =
(&glupload, &glconvert, &gl_caps_filter, &cudaupload)
{
pipeline.add_many(&[glupload, glconvert, gl_caps_filter, cudaupload])?;
}
}
}
// Link main audio branch
gst::Element::link_many(&[
gstreamer::Element::link_many(&[
&audio_source,
&audio_converter,
&audio_rate,
@@ -396,51 +463,62 @@ async fn main() -> Result<(), Box<dyn Error>> {
&audio_queue,
&audio_clocksync,
&audio_encoder,
webrtcsink.upcast_ref(),
])?;
// With DMA-BUF, also link glupload and it's caps
// Link audio parser to audio encoder if present, otherwise just webrtcsink
if let Some(parser) = &audio_parser {
gstreamer::Element::link_many(&[&audio_encoder, parser, webrtcsink.upcast_ref()])?;
} else {
gstreamer::Element::link_many(&[&audio_encoder, webrtcsink.upcast_ref()])?;
}
// With DMA-BUF..
if args.app.dma_buf {
if *gpu.vendor() == GPUVendor::NVIDIA {
gst::Element::link_many(&[
// VA-API / QSV pipeline
if let (Some(vapostproc), Some(va_caps_filter)) = (&vapostproc, &va_caps_filter) {
gstreamer::Element::link_many(&[
&video_source,
&caps_filter,
&video_queue,
&video_clocksync,
&glupload,
&glcolorconvert,
&gl_caps_filter,
&vapostproc,
&va_caps_filter,
&video_encoder,
])?;
} else {
gst::Element::link_many(&[
&video_source,
&caps_filter,
&video_queue,
&video_clocksync,
&glupload,
&glcolorconvert,
&gl_caps_filter,
&gl_download,
&video_encoder,
])?;
// NVENC pipeline
if let (Some(glupload), Some(glconvert), Some(gl_caps_filter), Some(cudaupload)) =
(&glupload, &glconvert, &gl_caps_filter, &cudaupload)
{
gstreamer::Element::link_many(&[
&video_source,
&caps_filter,
&video_queue,
&video_clocksync,
&glupload,
&glconvert,
&gl_caps_filter,
&cudaupload,
&video_encoder,
])?;
}
}
} else {
gst::Element::link_many(&[
gstreamer::Element::link_many(&[
&video_source,
&caps_filter,
&video_queue,
&video_clocksync,
&video_converter,
&video_converter.unwrap(),
&video_encoder,
])?;
}
// Link video parser if present with webrtcsink, otherwise just link webrtc sink
if let Some(parser) = &video_parser {
gst::Element::link_many(&[&video_encoder, parser, webrtcsink.upcast_ref()])?;
gstreamer::Element::link_many(&[&video_encoder, parser, webrtcsink.upcast_ref()])?;
} else {
gst::Element::link_many(&[&video_encoder, webrtcsink.upcast_ref()])?;
gstreamer::Element::link_many(&[&video_encoder, webrtcsink.upcast_ref()])?;
}
// Set QOS
@@ -468,14 +546,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
}
// Clean up
tracing::info!("Exiting gracefully..");
Ok(())
}
async fn run_pipeline(pipeline: Arc<gst::Pipeline>) -> Result<(), Box<dyn Error>> {
async fn run_pipeline(pipeline: Arc<gstreamer::Pipeline>) -> Result<(), Box<dyn Error>> {
let bus = { pipeline.bus().ok_or("Pipeline has no bus")? };
{
if let Err(e) = pipeline.set_state(gst::State::Playing) {
if let Err(e) = pipeline.set_state(gstreamer::State::Playing) {
tracing::error!("Failed to start pipeline: {}", e);
return Err("Failed to start pipeline".into());
}
@@ -495,24 +576,24 @@ async fn run_pipeline(pipeline: Arc<gst::Pipeline>) -> Result<(), Box<dyn Error>
}
{
pipeline.set_state(gst::State::Null)?;
pipeline.set_state(gstreamer::State::Null)?;
}
Ok(())
}
async fn listen_for_gst_messages(bus: gst::Bus) -> Result<(), Box<dyn Error>> {
async fn listen_for_gst_messages(bus: gstreamer::Bus) -> Result<(), Box<dyn Error>> {
let bus_stream = bus.stream();
tokio::pin!(bus_stream);
while let Some(msg) = bus_stream.next().await {
match msg.view() {
gst::MessageView::Eos(_) => {
gstreamer::MessageView::Eos(_) => {
tracing::info!("Received EOS");
break;
}
gst::MessageView::Error(err) => {
gstreamer::MessageView::Error(err) => {
let err_msg = format!(
"Error from {:?}: {:?}",
err.src().map(|s| s.path_string()),

View File

@@ -7,9 +7,9 @@ use crate::proto::proto::proto_input::InputType::{
use crate::proto::proto::{ProtoInput, ProtoMessageInput};
use atomic_refcell::AtomicRefCell;
use glib::subclass::prelude::*;
use gst::glib;
use gst::prelude::*;
use gst_webrtc::{WebRTCSDPType, WebRTCSessionDescription, gst_sdp};
use gstreamer::glib;
use gstreamer::prelude::*;
use gstreamer_webrtc::{gst_sdp, WebRTCSDPType, WebRTCSessionDescription};
use gstrswebrtc::signaller::{Signallable, SignallableImpl};
use parking_lot::RwLock as PLRwLock;
use prost::Message;
@@ -20,8 +20,8 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub struct Signaller {
stream_room: PLRwLock<Option<String>>,
stream_protocol: PLRwLock<Option<Arc<NestriStreamProtocol>>>,
wayland_src: PLRwLock<Option<Arc<gst::Element>>>,
data_channel: AtomicRefCell<Option<gst_webrtc::WebRTCDataChannel>>,
wayland_src: PLRwLock<Option<Arc<gstreamer::Element>>>,
data_channel: AtomicRefCell<Option<gstreamer_webrtc::WebRTCDataChannel>>,
}
impl Default for Signaller {
fn default() -> Self {
@@ -51,19 +51,19 @@ impl Signaller {
self.stream_protocol.read().clone()
}
pub fn set_wayland_src(&self, wayland_src: Arc<gst::Element>) {
pub fn set_wayland_src(&self, wayland_src: Arc<gstreamer::Element>) {
*self.wayland_src.write() = Some(wayland_src);
}
pub fn get_wayland_src(&self) -> Option<Arc<gst::Element>> {
pub fn get_wayland_src(&self) -> Option<Arc<gstreamer::Element>> {
self.wayland_src.read().clone()
}
pub fn set_data_channel(&self, data_channel: gst_webrtc::WebRTCDataChannel) {
pub fn set_data_channel(&self, data_channel: gstreamer_webrtc::WebRTCDataChannel) {
match self.data_channel.try_borrow_mut() {
Ok(mut dc) => *dc = Some(data_channel),
Err(_) => gst::warning!(
gst::CAT_DEFAULT,
Err(_) => gstreamer::warning!(
gstreamer::CAT_DEFAULT,
"Failed to set data channel - already borrowed"
),
}
@@ -72,7 +72,7 @@ impl Signaller {
/// Helper method to clean things up
fn register_callbacks(&self) {
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Stream protocol not set");
return;
};
{
@@ -87,7 +87,7 @@ impl Signaller {
&[&"unique-session-id", &answer],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode SDP message");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode SDP message");
}
});
}
@@ -108,7 +108,7 @@ impl Signaller {
],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode ICE message");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode ICE message");
}
});
}
@@ -118,13 +118,16 @@ impl Signaller {
if let Ok(answer) = serde_json::from_slice::<MessageRaw>(&data) {
// Decode room name string
if let Some(room_name) = answer.data.as_str() {
gst::info!(
gst::CAT_DEFAULT,
gstreamer::info!(
gstreamer::CAT_DEFAULT,
"Received OK answer for room: {}",
room_name
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode room name from answer");
gstreamer::error!(
gstreamer::CAT_DEFAULT,
"Failed to decode room name from answer"
);
}
// Send our SDP offer
@@ -137,7 +140,7 @@ impl Signaller {
],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode answer");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Failed to decode answer");
}
});
}
@@ -147,44 +150,52 @@ impl Signaller {
self_obj.connect_closure(
"webrtcbin-ready",
false,
glib::closure!(move |signaller: &super::NestriSignaller,
_consumer_identifier: &str,
webrtcbin: &gst::Element| {
gst::info!(gst::CAT_DEFAULT, "Adding data channels");
// Create data channels on webrtcbin
let data_channel = Some(
webrtcbin.emit_by_name::<gst_webrtc::WebRTCDataChannel>(
"create-data-channel",
&[
&"nestri-data-channel",
&gst::Structure::builder("config")
.field("ordered", &true)
.field("max-retransmits", &2u32)
.field("priority", "high")
.field("protocol", "raw")
.build(),
],
),
);
if let Some(data_channel) = data_channel {
gst::info!(gst::CAT_DEFAULT, "Data channel created");
if let Some(wayland_src) = signaller.imp().get_wayland_src() {
setup_data_channel(&data_channel, &*wayland_src);
signaller.imp().set_data_channel(data_channel);
glib::closure!(
move |signaller: &super::NestriSignaller,
_consumer_identifier: &str,
webrtcbin: &gstreamer::Element| {
gstreamer::info!(gstreamer::CAT_DEFAULT, "Adding data channels");
// Create data channels on webrtcbin
let data_channel = Some(
webrtcbin.emit_by_name::<gstreamer_webrtc::WebRTCDataChannel>(
"create-data-channel",
&[
&"nestri-data-channel",
&gstreamer::Structure::builder("config")
.field("ordered", &true)
.field("max-retransmits", &2u32)
.field("priority", "high")
.field("protocol", "raw")
.build(),
],
),
);
if let Some(data_channel) = data_channel {
gstreamer::info!(gstreamer::CAT_DEFAULT, "Data channel created");
if let Some(wayland_src) = signaller.imp().get_wayland_src() {
setup_data_channel(&data_channel, &*wayland_src);
signaller.imp().set_data_channel(data_channel);
} else {
gstreamer::error!(
gstreamer::CAT_DEFAULT,
"Wayland display source not set"
);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Wayland display source not set");
gstreamer::error!(
gstreamer::CAT_DEFAULT,
"Failed to create data channel"
);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to create data channel");
}
}),
),
);
}
}
}
impl SignallableImpl for Signaller {
fn start(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller started");
gstreamer::info!(gstreamer::CAT_DEFAULT, "Signaller started");
// Register message callbacks
self.register_callbacks();
@@ -193,7 +204,7 @@ impl SignallableImpl for Signaller {
// TODO: Re-implement reconnection handling
let Some(stream_room) = self.stream_room.read().clone() else {
gst::error!(gst::CAT_DEFAULT, "Stream room not set");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Stream room not set");
return;
};
@@ -206,7 +217,7 @@ impl SignallableImpl for Signaller {
};
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Stream protocol not set");
return;
};
@@ -216,7 +227,7 @@ impl SignallableImpl for Signaller {
}
fn stop(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller stopped");
gstreamer::info!(gstreamer::CAT_DEFAULT, "Signaller stopped");
}
fn send_sdp(&self, _session_id: &str, sdp: &WebRTCSessionDescription) {
@@ -229,7 +240,7 @@ impl SignallableImpl for Signaller {
};
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Stream protocol not set");
return;
};
@@ -260,7 +271,7 @@ impl SignallableImpl for Signaller {
};
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
gstreamer::error!(gstreamer::CAT_DEFAULT, "Stream protocol not set");
return;
};
@@ -270,7 +281,7 @@ impl SignallableImpl for Signaller {
}
fn end_session(&self, session_id: &str) {
gst::info!(gst::CAT_DEFAULT, "Ending session: {}", session_id);
gstreamer::info!(gstreamer::CAT_DEFAULT, "Ending session: {}", session_id);
}
}
#[glib::object_subclass]
@@ -303,7 +314,10 @@ impl ObjectImpl for Signaller {
}
}
fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, wayland_src: &gst::Element) {
fn setup_data_channel(
data_channel: &gstreamer_webrtc::WebRTCDataChannel,
wayland_src: &gstreamer::Element,
) {
let wayland_src = wayland_src.clone();
data_channel.connect_on_message_data(move |_data_channel, data| {
@@ -328,64 +342,64 @@ fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, wayland_src:
});
}
fn handle_input_message(input_msg: ProtoInput) -> Option<gst::Event> {
fn handle_input_message(input_msg: ProtoInput) -> Option<gstreamer::Event> {
if let Some(input_type) = input_msg.input_type {
match input_type {
MouseMove(data) => {
let structure = gst::Structure::builder("MouseMoveRelative")
let structure = gstreamer::Structure::builder("MouseMoveRelative")
.field("pointer_x", data.x as f64)
.field("pointer_y", data.y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
MouseMoveAbs(data) => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
let structure = gstreamer::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", data.x as f64)
.field("pointer_y", data.y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
KeyDown(data) => {
let structure = gst::Structure::builder("KeyboardKey")
let structure = gstreamer::Structure::builder("KeyboardKey")
.field("key", data.key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
KeyUp(data) => {
let structure = gst::Structure::builder("KeyboardKey")
let structure = gstreamer::Structure::builder("KeyboardKey")
.field("key", data.key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
MouseWheel(data) => {
let structure = gst::Structure::builder("MouseAxis")
let structure = gstreamer::Structure::builder("MouseAxis")
.field("x", data.x as f64)
.field("y", data.y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
MouseKeyDown(data) => {
let structure = gst::Structure::builder("MouseButton")
let structure = gstreamer::Structure::builder("MouseButton")
.field("button", data.key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
MouseKeyUp(data) => {
let structure = gst::Structure::builder("MouseButton")
let structure = gstreamer::Structure::builder("MouseButton")
.field("button", data.key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gstreamer::event::CustomUpstream::new(structure))
}
}
} else {
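
For context (an editor's sketch, not part of the diff): each decoded input message above becomes a custom upstream GStreamer event that the Wayland display source consumes. Assuming the same structure names built by `handle_input_message`, pushing such an event into the source element could look roughly like this:

```rust
// Illustrative sketch only; not code from this commit.
use gstreamer::prelude::*;

fn push_mouse_move(wayland_src: &gstreamer::Element, x: f64, y: f64) {
    // Same structure name the handler above uses for relative mouse motion.
    let structure = gstreamer::Structure::builder("MouseMoveRelative")
        .field("pointer_x", x)
        .field("pointer_y", y)
        .build();
    let event = gstreamer::event::CustomUpstream::new(structure);
    // send_event returns false if no element handled the event.
    if !wayland_src.send_event(event) {
        gstreamer::warning!(gstreamer::CAT_DEFAULT, "input event was not handled");
    }
}
```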

View File

@@ -1,6 +1,6 @@
use crate::p2p::p2p::NestriConnection;
use gst::glib;
use gst::subclass::prelude::*;
use gstreamer::glib;
use gstreamer::subclass::prelude::*;
use gstrswebrtc::signaller::Signallable;
use std::sync::Arc;
@@ -14,7 +14,7 @@ impl NestriSignaller {
pub async fn new(
room: String,
nestri_conn: NestriConnection,
wayland_src: Arc<gst::Element>,
wayland_src: Arc<gstreamer::Element>,
) -> Result<Self, Box<dyn std::error::Error>> {
let obj: Self = glib::Object::new();
obj.imp().set_stream_room(room);

View File

@@ -1,4 +1,4 @@
use futures_util::StreamExt;
use libp2p::futures::StreamExt;
use libp2p::multiaddr::Protocol;
use libp2p::{
Multiaddr, PeerId, Swarm, identify, noise, ping,
@@ -20,6 +20,7 @@ struct NestriBehaviour {
identify: identify::Behaviour,
ping: ping::Behaviour,
stream: libp2p_stream::Behaviour,
autonatv2: libp2p::autonat::v2::client::Behaviour,
}
pub struct NestriP2P {
@@ -36,6 +37,8 @@ impl NestriP2P {
yamux::Config::default,
)?
.with_dns()?
.with_websocket(noise::Config::new, yamux::Config::default)
.await?
.with_behaviour(|key| {
let identify_behaviour = identify::Behaviour::new(identify::Config::new(
"/ipfs/id/1.0.0".to_string(),
@@ -43,11 +46,13 @@ impl NestriP2P {
));
let ping_behaviour = ping::Behaviour::default();
let stream_behaviour = libp2p_stream::Behaviour::default();
let autonatv2_behaviour = libp2p::autonat::v2::client::Behaviour::default();
Ok(NestriBehaviour {
identify: identify_behaviour,
ping: ping_behaviour,
stream: stream_behaviour,
autonatv2: autonatv2_behaviour,
})
})?
.build(),
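
Aside (not part of the diff): with `.with_websocket(...)` chained onto the builder, the swarm can dial WebSocket multiaddrs in addition to plain TCP. A rough usage sketch, where the relay address is a made-up placeholder:

```rust
// Illustrative sketch only; the relay address below is hypothetical.
use libp2p::{Multiaddr, Swarm};

fn dial_over_websocket(
    swarm: &mut Swarm<NestriBehaviour>,
) -> Result<(), Box<dyn std::error::Error>> {
    // A DNS + TLS WebSocket multiaddr; the real address would come from configuration.
    let addr: Multiaddr = "/dns4/relay.example.org/tcp/443/wss".parse()?;
    swarm.dial(addr)?;
    Ok(())
}
```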

View File

@@ -1,9 +1,10 @@
use crate::p2p::p2p::NestriConnection;
use crate::p2p::p2p_safestream::SafeStream;
use dashmap::DashMap;
use libp2p::StreamProtocol;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::{self, Duration};
// Cloneable callback type
pub type CallbackInner = dyn Fn(Vec<u8>) + Send + Sync + 'static;
@@ -33,9 +34,11 @@ impl From<Box<CallbackInner>> for Callback {
/// NestriStreamProtocol manages the stream protocol for Nestri connections.
pub struct NestriStreamProtocol {
tx: mpsc::Sender<Vec<u8>>,
tx: Option<mpsc::Sender<Vec<u8>>>,
safe_stream: Arc<SafeStream>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
callbacks: Arc<DashMap<String, Callback>>,
read_handle: Option<tokio::task::JoinHandle<()>>,
write_handle: Option<tokio::task::JoinHandle<()>>,
}
impl NestriStreamProtocol {
const NESTRI_PROTOCOL_STREAM_PUSH: StreamProtocol =
@@ -56,21 +59,35 @@ impl NestriStreamProtocol {
}
};
let (tx, rx) = mpsc::channel(1000);
let sp = NestriStreamProtocol {
tx,
let mut sp = NestriStreamProtocol {
tx: None,
safe_stream: Arc::new(SafeStream::new(push_stream)),
callbacks: Arc::new(RwLock::new(HashMap::new())),
callbacks: Arc::new(DashMap::new()),
read_handle: None,
write_handle: None,
};
// Spawn the loops
sp.spawn_read_loop();
sp.spawn_write_loop(rx);
// Use restart method to initialize the read and write loops
sp.restart()?;
Ok(sp)
}
pub fn restart(&mut self) -> Result<(), Box<dyn std::error::Error>> {
// Return if tx and handles are already initialized
if self.tx.is_some() && self.read_handle.is_some() && self.write_handle.is_some() {
tracing::warn!("NestriStreamProtocol is already running, restart skipped");
return Ok(());
}
let (tx, rx) = mpsc::channel(1000);
self.tx = Some(tx);
self.read_handle = Some(self.spawn_read_loop());
self.write_handle = Some(self.spawn_write_loop(rx));
Ok(())
}
fn spawn_read_loop(&self) -> tokio::task::JoinHandle<()> {
let safe_stream = self.safe_stream.clone();
let callbacks = self.callbacks.clone();
@@ -89,14 +106,22 @@ impl NestriStreamProtocol {
match serde_json::from_slice::<crate::messages::MessageBase>(&data) {
Ok(base_message) => {
let response_type = base_message.payload_type;
let callback = {
let callbacks_lock = callbacks.read().unwrap();
callbacks_lock.get(&response_type).cloned()
};
if let Some(callback) = callback {
// Call the registered callback with the raw data
callback.call(data);
                            // With DashMap we don't need explicit locking;
                            // we just fetch the callback directly if it exists.
                            if let Some(callback) = callbacks.get(&response_type) {
                                // Execute the callback, catching panics so a
                                // misbehaving handler can't kill the read loop
if let Err(e) =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
callback.call(data.clone())
}))
{
tracing::error!(
"Callback for response type '{}' panicked: {:?}",
response_type,
e
);
}
} else {
tracing::warn!(
"No callback registered for response type: {}",
@@ -108,6 +133,9 @@ impl NestriStreamProtocol {
tracing::error!("Failed to decode message: {}", e);
}
}
// Add a small sleep to reduce CPU usage
time::sleep(Duration::from_micros(100)).await;
}
})
}
@@ -117,14 +145,20 @@ impl NestriStreamProtocol {
tokio::spawn(async move {
loop {
// Wait for a message from the channel
if let Some(tx_data) = rx.recv().await {
if let Err(e) = safe_stream.send_raw(&tx_data).await {
tracing::error!("Error sending data: {:?}", e);
match rx.recv().await {
Some(tx_data) => {
if let Err(e) = safe_stream.send_raw(&tx_data).await {
tracing::error!("Error sending data: {:?}", e);
}
}
None => {
tracing::info!("Receiver closed, exiting write loop");
break;
}
} else {
tracing::info!("Receiver closed, exiting write loop");
break;
}
// Add a small sleep to reduce CPU usage
time::sleep(Duration::from_micros(100)).await;
}
})
}
@@ -134,16 +168,25 @@ impl NestriStreamProtocol {
message: &M,
) -> Result<(), Box<dyn std::error::Error>> {
let json_data = serde_json::to_vec(message)?;
self.tx.try_send(json_data)?;
let Some(tx) = &self.tx else {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::NotConnected,
if self.read_handle.is_none() && self.write_handle.is_none() {
"NestriStreamProtocol has been shutdown"
} else {
"NestriStreamProtocol is not properly initialized"
},
)));
};
tx.try_send(json_data)?;
Ok(())
}
/// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Callback::new(callback));
self.callbacks
.insert(response_type.to_string(), Callback::new(callback));
}
}
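
A usage sketch (editor's illustration, not from the commit): callbacks are plain `Fn(Vec<u8>)` closures keyed by payload type in the `DashMap`, and the read loop now wraps each call in `catch_unwind`, so a panicking handler is logged instead of killing the loop. The `"answer"` payload type below is a made-up example:

```rust
// Illustrative sketch only; the "answer" payload type is hypothetical.
fn register_answer_handler(stream_protocol: &NestriStreamProtocol) {
    stream_protocol.register_callback("answer", |raw| {
        // `raw` holds the raw JSON bytes of the incoming message.
        match serde_json::from_slice::<serde_json::Value>(&raw) {
            Ok(value) => tracing::debug!("received answer payload: {}", value),
            Err(e) => tracing::error!("failed to parse answer payload: {}", e),
        }
    });
}
```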

View File

@@ -1,6 +1,6 @@
use byteorder::{BigEndian, ByteOrder};
use futures_util::io::{ReadHalf, WriteHalf};
use futures_util::{AsyncReadExt, AsyncWriteExt};
use libp2p::futures::io::{ReadHalf, WriteHalf};
use libp2p::futures::{AsyncReadExt, AsyncWriteExt};
use prost::Message;
use serde::Serialize;
use serde::de::DeserializeOwned;
@@ -63,21 +63,15 @@ impl SafeStream {
async fn send_with_length_prefix(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
if data.len() > MAX_SIZE {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Data exceeds maximum size",
)));
return Err("Data exceeds maximum size".into());
}
let mut buffer = Vec::with_capacity(4 + data.len());
buffer.extend_from_slice(&(data.len() as u32).to_be_bytes()); // Length prefix
buffer.extend_from_slice(data); // Payload
let mut stream_write = self.stream_write.lock().await;
// Write the 4-byte length prefix
let mut length_prefix = [0u8; 4];
BigEndian::write_u32(&mut length_prefix, data.len() as u32);
stream_write.write_all(&length_prefix).await?;
// Write the actual data
stream_write.write_all(data).await?;
stream_write.write_all(&buffer).await?; // Single write
stream_write.flush().await?;
Ok(())
}
@@ -85,20 +79,16 @@ impl SafeStream {
async fn receive_with_length_prefix(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let mut stream_read = self.stream_read.lock().await;
// Read the 4-byte length prefix
        // Read the 4-byte length prefix, validate it, then read the payload
let mut length_prefix = [0u8; 4];
stream_read.read_exact(&mut length_prefix).await?;
let length = BigEndian::read_u32(&length_prefix) as usize;
if length > MAX_SIZE {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Data exceeds maximum size",
)));
return Err("Data exceeds maximum size".into());
}
// Read the actual data
let mut buffer = vec![0; length];
let mut buffer = vec![0u8; length];
stream_read.read_exact(&mut buffer).await?;
Ok(buffer)
}
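
To make the consolidated framing concrete, here is a standalone sketch (editor's illustration, not part of the diff) of the same wire format: a 4-byte big-endian length prefix followed by the payload, with an assumed `MAX_SIZE`:

```rust
// Illustrative sketch of SafeStream's framing; MAX_SIZE here is an assumption,
// the real limit is defined in the module above.
const MAX_SIZE: usize = 4 * 1024 * 1024;

fn frame(data: &[u8]) -> Result<Vec<u8>, String> {
    if data.len() > MAX_SIZE {
        return Err("Data exceeds maximum size".into());
    }
    let mut buffer = Vec::with_capacity(4 + data.len());
    buffer.extend_from_slice(&(data.len() as u32).to_be_bytes()); // length prefix
    buffer.extend_from_slice(data); // payload
    Ok(buffer)
}

fn unframe(framed: &[u8]) -> Result<&[u8], String> {
    if framed.len() < 4 {
        return Err("buffer too short for length prefix".into());
    }
    let (prefix, rest) = framed.split_at(4);
    let length = u32::from_be_bytes(prefix.try_into().expect("4-byte prefix")) as usize;
    if length > MAX_SIZE {
        return Err("Data exceeds maximum size".into());
    }
    rest.get(..length).ok_or_else(|| "truncated payload".to_string())
}

fn main() {
    let framed = frame(b"hello").unwrap();
    assert_eq!(unframe(&framed).unwrap(), &b"hello"[..]);
}
```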