feat: Custom gst webrtc signaller, runtime GPU driver package install and more (#140)

🔥 🔥

Yes lots of commits because rebasing and all.. thankfully I know Git
just enough to have backups 😅

---------

Co-authored-by: Wanjohi <elviswanjohi47@gmail.com>
Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Wanjohi <71614375+wanjohiryan@users.noreply.github.com>
Co-authored-by: AquaWolf <3daquawolf@gmail.com>
This commit is contained in:
Kristian Ollikainen
2024-12-08 15:37:36 +02:00
committed by GitHub
parent 20d5ff511e
commit b6196b1c69
27 changed files with 3402 additions and 1349 deletions

2
.gitignore vendored
View File

@@ -48,5 +48,5 @@ bun.lockb
#tests #tests
id_* id_*
# Rust #Rust
target target

2036
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,5 +12,6 @@ edition = "2021"
rust-version = "1.80" rust-version = "1.80"
[workspace.dependencies] [workspace.dependencies]
gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", version = "0.24.0" } gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_24"] }
gst-app = { package = "gstreamer-app", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", version = "0.24.0" } gst-webrtc = { package = "gstreamer-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_24"] }
gstrswebrtc = { package = "gst-plugin-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs", branch = "main", features = ["v1_22"] }

206
Containerfile.master Normal file
View File

@@ -0,0 +1,206 @@
#! Runs the docker server that handles everything else
#******************************************************************************
# base
#******************************************************************************
FROM archlinux:base-20241027.0.273886 AS base
# How to run - docker run -it --rm --device /dev/dri nestri /bin/bash - DO NOT forget the ports
# TODO: Migrate XDG_RUNTIME_DIR to /run/user/1000
# TODO: Add nestri-server to pulseaudio.conf
# TODO: Add our own entrypoint, with our very own zombie ripper 🧟🏾‍♀️
# FIXME: Add user root to `pulse-access` group as well :D
# TODO: Test the whole damn thing
# Update the pacman repo
RUN \
pacman -Syu --noconfirm
#******************************************************************************
# builder
#******************************************************************************
FROM base AS builder
RUN \
pacman -Su --noconfirm \
base-devel \
git \
sudo \
vim
WORKDIR /scratch
# Allow nobody user to invoke pacman to install packages (as part of makepkg) and modify the system.
# This should never exist in a running image, just used by *-build Docker stages.
RUN \
echo "nobody ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers;
ENV ARTIFACTS=/artifacts \
CARGO_TARGET_DIR=/build
RUN \
mkdir -p /artifacts \
&& mkdir -p /build
RUN \
chgrp nobody /scratch /artifacts /build \
&& chmod g+ws /scratch /artifacts /build
#******************************************************************************
# rust-builder
#******************************************************************************
FROM builder AS rust-builder
RUN \
pacman -Su --noconfirm \
rustup
RUN \
rustup default stable
#******************************************************************************
# nestri-server-builder
#******************************************************************************
# Builds nestri server binary
FROM rust-builder AS nestri-server-builder
RUN \
pacman -Su --noconfirm \
wayland \
vpl-gpu-rt \
gstreamer \
gst-plugin-va \
gst-plugins-base \
gst-plugins-good \
mesa-utils \
weston \
xorg-xwayland
#******************************************************************************
# nestri-server-build
#******************************************************************************
FROM nestri-server-builder AS nestri-server-build
#Allow makepkg to be run as nobody.
RUN chgrp -R nobody /scratch && chmod -R g+ws /scratch
# USER nobody
# Perform the server build.
WORKDIR /scratch/server
RUN \
git clone https://github.com/nestriness/nestri
WORKDIR /scratch/server/nestri
RUN \
git checkout feat/stream \
&& cargo build -j$(nproc) --release
# COPY packages/server/build/ /scratch/server/
# RUN makepkg && cp *.zst "$ARTIFACTS"
#******************************************************************************
# runtime_base_pkgs
#******************************************************************************
FROM base AS runtime_base_pkgs
COPY --from=nestri-server-build /build/release/nestri-server /usr/bin/
#******************************************************************************
# runtime_base
#******************************************************************************
FROM runtime_base_pkgs AS runtime_base
RUN \
pacman -Su --noconfirm \
weston \
sudo \
xorg-xwayland \
gstreamer \
gst-plugins-base \
gst-plugins-good \
gst-plugin-qsv \
gst-plugin-va \
gst-plugin-fmp4 \
mesa \
# Grab GPU encoding packages
# Intel (modern VPL + VA-API)
vpl-gpu-rt \
intel-media-driver \
# AMD/ATI (VA-API)
libva-mesa-driver \
# NVIDIA (proprietary)
nvidia-utils \
# Audio
pulseaudio \
# Supervisor
supervisor
RUN \
# Set up our non-root user $(nestri)
groupadd -g 1000 nestri \
&& useradd -ms /bin/bash nestri -u 1000 -g 1000 \
&& passwd -d nestri \
# Setup Pulseaudio
&& useradd -d /var/run/pulse -s /usr/bin/nologin -G audio pulse \
&& groupadd pulse-access \
&& usermod -aG audio,input,render,video,pulse-access nestri \
&& echo "nestri ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers \
&& echo "Users created" \
# Create an empty machine-id file
&& touch /etc/machine-id
ENV \
XDG_RUNTIME_DIR=/tmp
#******************************************************************************
# runtime
#******************************************************************************
FROM runtime_base AS runtime
# Setup supervisor #
RUN <<-EOF
echo -e "
[supervisord]
user=root
nodaemon=true
loglevel=info
logfile=/tmp/supervisord.log
pidfile=/tmp/supervisord.pid
[program:dbus]
user=root
command=dbus-daemon --system --nofork
logfile=/tmp/dbus.log
pidfile=/tmp/dbus.pid
stopsignal=INT
autostart=true
autorestart=true
priority=1
[program:pulseaudio]
user=root
command=pulseaudio --daemonize=no --system --disallow-module-loading --disallow-exit --exit-idle-time=-1
logfile=/tmp/pulseaudio.log
pidfile=/tmp/pulseaudio.pid
stopsignal=INT
autostart=true
autorestart=true
priority=10
" | tee /etc/supervisord.conf
EOF
RUN \
chown -R nestri:nestri /tmp /etc/supervisord.conf
ENV USER=nestri
USER 1000
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisord.conf"]
# Debug - pactl list

20
Containerfile.relay Normal file
View File

@@ -0,0 +1,20 @@
FROM docker.io/golang:1.23-alpine AS go-build
WORKDIR /builder
COPY packages/relay/ /builder/
RUN go build
FROM docker.io/golang:1.23-alpine
COPY --from=go-build /builder/relay /relay/relay
WORKDIR /relay
# ENV flags
ENV VERBOSE=false
ENV ENDPOINT_PORT=8088
ENV WEBRTC_UDP_START=10000
ENV WEBRTC_UDP_END=20000
ENV STUN_SERVER="stun.l.google.com:19302"
EXPOSE $ENDPOINT_PORT
EXPOSE $WEBRTC_UDP_START-$WEBRTC_UDP_END/udp
ENTRYPOINT ["/relay/relay"]

119
Containerfile.runner Normal file
View File

@@ -0,0 +1,119 @@
# Container build arguments #
ARG BASE_IMAGE=docker.io/cachyos/cachyos-v3:latest
#******************************************************************************
# gst-builder
#******************************************************************************
FROM ${BASE_IMAGE} AS gst-builder
WORKDIR /builder/
# Grab build and rust packages #
RUN pacman -Syu --noconfirm meson pkgconf cmake git gcc make rustup \
gstreamer gst-plugins-base gst-plugins-good gst-plugin-rswebrtc
# Setup stable rust toolchain #
RUN rustup default stable
# Clone nestri source #
RUN git clone -b feat/stream https://github.com/DatCaptainHorse/nestri.git
# Build nestri #
RUN cd nestri/packages/server/ && \
cargo build --release
#******************************************************************************
# gstwayland-builder
#******************************************************************************
FROM ${BASE_IMAGE} AS gstwayland-builder
WORKDIR /builder/
# Grab build and rust packages #
RUN pacman -Syu --noconfirm meson pkgconf cmake git gcc make rustup \
libxkbcommon wayland gstreamer gst-plugins-base gst-plugins-good libinput
# Setup stable rust toolchain #
RUN rustup default stable
# Build required cargo-c package #
RUN cargo install cargo-c
# Clone gst plugin source #
RUN git clone https://github.com/games-on-whales/gst-wayland-display.git
# Build gst plugin #
RUN mkdir plugin && \
cd gst-wayland-display && \
cargo cinstall --prefix=/builder/plugin/
#******************************************************************************
# runtime
#******************************************************************************
FROM ${BASE_IMAGE} AS runtime
## Install Graphics, Media, and Audio packages ##
RUN pacman -Syu --noconfirm --needed \
# Graphics packages
sudo xorg-xwayland labwc wlr-randr mangohud \
# GStreamer and plugins
gstreamer gst-plugins-base gst-plugins-good \
gst-plugins-bad gst-plugin-pipewire \
gst-plugin-rswebrtc gst-plugin-rsrtp \
# Audio packages
pipewire pipewire-pulse pipewire-alsa wireplumber \
# Other requirements
supervisor jq chwd lshw pacman-contrib && \
# Clean up pacman cache
paccache -rk1
## User ##
# Create and setup user #
ENV USER="nestri" \
UID=99 \
GID=100 \
USER_PASSWORD="nestri1234"
RUN mkdir -p /home/${USER} && \
groupadd -g ${GID} ${USER} && \
useradd -d /home/${USER} -u ${UID} -g ${GID} -s /bin/bash ${USER} && \
chown -R ${USER}:${USER} /home/${USER} && \
echo "${USER} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers && \
echo "${USER}:${USER_PASSWORD}" | chpasswd
# Run directory #
RUN mkdir -p /run/user/${UID} && \
chown ${USER}:${USER} /run/user/${UID}
# Groups #
RUN usermod -aG input root && usermod -aG input ${USER} && \
usermod -aG video root && usermod -aG video ${USER} && \
usermod -aG render root && usermod -aG render ${USER}
## Copy files from builders ##
# this is done here at end to not trigger full rebuild on changes to builder
# nestri
COPY --from=gst-builder /builder/nestri/target/release/nestri-server /usr/bin/nestri-server
# gstwayland
COPY --from=gstwayland-builder /builder/plugin/include/libgstwaylanddisplay /usr/include/
COPY --from=gstwayland-builder /builder/plugin/lib/*libgstwayland* /usr/lib/
COPY --from=gstwayland-builder /builder/plugin/lib/gstreamer-1.0/libgstwayland* /usr/lib/gstreamer-1.0/
COPY --from=gstwayland-builder /builder/plugin/lib/pkgconfig/gstwayland* /usr/lib/pkgconfig/
COPY --from=gstwayland-builder /builder/plugin/lib/pkgconfig/libgstwayland* /usr/lib/pkgconfig/
## Copy scripts ##
COPY packages/scripts/ /etc/nestri/
# Set scripts as executable #
RUN chmod +x /etc/nestri/envs.sh /etc/nestri/entrypoint.sh /etc/nestri/entrypoint_nestri.sh
## Set runtime envs ##
ENV XDG_RUNTIME_DIR=/run/user/${UID} \
HOME=/home/${USER}
# Required for NVIDIA.. they want to be special like that #
ENV NVIDIA_DRIVER_CAPABILITIES=all
# Wireplumber disable suspend #
# Remove suspend node
RUN sed -z -i 's/{[[:space:]]*name = node\/suspend-node\.lua,[[:space:]]*type = script\/lua[[:space:]]*provides = hooks\.node\.suspend[[:space:]]*}[[:space:]]*//g' /usr/share/wireplumber/wireplumber.conf
# Remove "hooks.node.suspend" want
RUN sed -i '/wants = \[/{s/hooks\.node\.suspend\s*//; s/,\s*\]/]/}' /usr/share/wireplumber/wireplumber.conf
ENTRYPOINT ["supervisord", "-c", "/etc/nestri/supervisord.conf"]

View File

@@ -4,8 +4,20 @@
Nestri is in a **very early-beta phase**, so errors and bugs may occur. Nestri is in a **very early-beta phase**, so errors and bugs may occur.
:: ::
### Step 0: Construct Your Docker Image
Checkout your branch with the latest version of nestri and build the image `<your-nestri-image>` within git root folder:
```bash
docker buildx build -t <your-nestri-image>:latest -f Containerfile.runner .
```
## Step 1: Navigate to Your Game Directory ::alert{type="info"}
You can right now also pull the docker image from DatHorse GitHub Containter Registry with:
```bash
docker pull ghcr.io/datcaptainhorse/nestri-cachyos:latest
```
::
### Step 1: Navigate to Your Game Directory
First, change your directory to the location of your `.exe` file. For Steam games, this typically means: First, change your directory to the location of your `.exe` file. For Steam games, this typically means:
```bash ```bash
cd $HOME/.steam/steam/steamapps cd $HOME/.steam/steam/steamapps
@@ -18,28 +30,49 @@ echo "$(head /dev/urandom | LC_ALL=C tr -dc 'a-zA-Z0-9' | head -c 16)"
``` ```
This command generates a random 16-character string. Be sure to note this string carefully, as you'll need it for the next step. This command generates a random 16-character string. Be sure to note this string carefully, as you'll need it for the next step.
### Step 3: Launch the Nestri Server ### Step 3: Launch the Nestri Server
With your SESSION_ID ready, insert it into the command below, replacing `<paste here>` with your actual session ID. Then, run the command to start the Nestri server: With your SESSION_ID ready, insert it into the command below, replacing `<your_session_id>` with your actual session ID, also replace `<relay_url>` with your relay URL and `<your-nestri-image>` with your build nestri image or nestri remote image. Then run the command to start the Nestri server:
```bash
docker run --rm -it --shm-size=1g --gpus all -e NVIDIA_DRIVER_CAPABILITIES=all --runtime=nvidia -e RELAY_URL='<relay_url>' -e NESTRI_ROOM=<your_session_id> -e RESOLUTION=1920x1080 -e FRAMERATE=60 -e NESTRI_PARAMS='--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000'--name nestri -d -v "$(pwd)":/mnt/game/ <your-nestri-image>:latest
``` ```
docker run --gpus all --device=/dev/dri --name nestri -it --entrypoint /bin/bash -e SESSION_ID=<paste here> -v "$(pwd)":/game -p 8080:8080/udp --cap-add=SYS_NICE --cap-add=SYS_ADMIN ghcr.io/nestriness/nestri/server:nightly
### Step 4: Get Into your container
Get into your container to start your game:
```bash
sudo docker exec -it nestri bash
``` ```
> \[!TIP] ### Step 5: Installing a Launcher
> For most games that are not DRM free you need a launcher. In this case use the umu launcher and optional mangohud:
> Ensure UDP port 8080 is accessible from the internet. Use `ufw allow 8080/udp` or adjust your cloud provider's security group settings accordingly. ```bash
### Step 4: Configure the Game within the Container pacman -S --overwrite="*" umu-launcher mangohud
After executing the previous command, you'll be in a new shell within the container (example: `nestri@3f199ee68c01:~$`). Perform the following checks: ```
1. Verify the game is mounted by executing `ls -la /game`. If not, exit and ensure you've correctly mounted the game directory as a volume.
2. Then, start the Nestri server by running `/etc/startup.sh > /dev/null &`.
### Step 5: Running Your Game ### Step 5: Running Your Game
Wait for the `.X11-unix` directory to appear in `/tmp` (check with `ls -la /tmp`). Once it appears, you're ready to launch your game. You have to execute your game now with nestri user. If you have a linux game just execute it with the nestri user
- With Proton-GE: `nestri-proton -pr <game>.exe` ```bash
- With Wine: `nestri-proton -wr <game>.exe` su nestri
source /etc/nestri/envs.sh
GAMEID=0 PROTONPATH=GE-Proton mangohud umu-run /mnt/game/<your-game.exe>
```
### Step 6: Begin Playing ### Step 6: Begin Playing
Finally, construct the play URL with your session ID: Finally, construct the play URL with your session ID:
``` `https://nestri.io/play/<your_session_id>`
echo "https://nestri.io/play/$SESSION_ID"
```
Navigate to this URL in your browser, click on the page to capture your mouse pointer, and start playing! Navigate to this URL in your browser, click on the page to capture your mouse pointer, and start playing!
::alert{type="info"}
You can also use other relays/frontends depending on your choosen `<relay_url>`
For testing you can use DatHorse Relay and Frontend:
| **Placeholder** | **URL** |
| ---------------------------- | ---------- |
| `<relay_url>` | `https://relay.dathorse.com/` |
| `<frontend_url>` | `https://nestritest.dathorse.com/play/<your_session_id>` |
::
<!-- <!--
Nestri Node is easy to install using the provided installation script. Follow the steps below to get started. Nestri Node is easy to install using the provided installation script. Follow the steps below to get started.

View File

@@ -17,7 +17,7 @@ export default component$(() => {
video = document.createElement("video"); video = document.createElement("video");
video.id = "stream-video-player"; video.id = "stream-video-player";
video.style.visibility = "hidden"; video.style.visibility = "hidden";
const webrtc = new WebRTCStream("https://relay.dathorse.com", id, (mediaStream) => { const webrtc = new WebRTCStream("http://localhost:8088", id, (mediaStream) => {
if (video && mediaStream && (video as HTMLVideoElement).srcObject === null) { if (video && mediaStream && (video as HTMLVideoElement).srcObject === null) {
console.log("Setting mediastream"); console.log("Setting mediastream");
(video as HTMLVideoElement).srcObject = mediaStream; (video as HTMLVideoElement).srcObject = mediaStream;
@@ -26,6 +26,8 @@ export default component$(() => {
window.hasstream = true; window.hasstream = true;
// @ts-ignore // @ts-ignore
window.roomOfflineElement?.remove(); window.roomOfflineElement?.remove();
// @ts-ignore
window.playbtnelement?.remove();
const playbtn = document.createElement("button"); const playbtn = document.createElement("button");
playbtn.style.position = "absolute"; playbtn.style.position = "absolute";
@@ -81,8 +83,14 @@ export default component$(() => {
}); });
}; };
document.body.append(playbtn); document.body.append(playbtn);
// @ts-ignore
window.playbtnelement = playbtn;
} else if (mediaStream === null) { } else if (mediaStream === null) {
console.log("MediaStream is null, Room is offline"); console.log("MediaStream is null, Room is offline");
// @ts-ignore
window.playbtnelement?.remove();
// @ts-ignore
window.roomOfflineElement?.remove();
// Add a message to the screen // Add a message to the screen
const offline = document.createElement("div"); const offline = document.createElement("div");
offline.style.position = "absolute"; offline.style.position = "absolute";
@@ -104,6 +112,30 @@ export default component$(() => {
const ctx = canvas.value.getContext("2d"); const ctx = canvas.value.getContext("2d");
if (ctx) ctx.clearRect(0, 0, canvas.value.width, canvas.value.height); if (ctx) ctx.clearRect(0, 0, canvas.value.width, canvas.value.height);
} }
} else if ((video as HTMLVideoElement).srcObject !== null) {
console.log("Setting new mediastream");
(video as HTMLVideoElement).srcObject = mediaStream;
// @ts-ignore
window.hasstream = true;
// Start video rendering
(video as HTMLVideoElement).play().then(() => {
// @ts-ignore
window.roomOfflineElement?.remove();
if (canvas.value) {
canvas.value.width = (video as HTMLVideoElement).videoWidth;
canvas.value.height = (video as HTMLVideoElement).videoHeight;
const ctx = canvas.value.getContext("2d");
const renderer = () => {
// @ts-ignore
if (ctx && window.hasstream) {
ctx.drawImage((video as HTMLVideoElement), 0, 0);
(video as HTMLVideoElement).requestVideoFrameCallback(renderer);
}
}
(video as HTMLVideoElement).requestVideoFrameCallback(renderer);
}
});
} }
}); });
} }

View File

@@ -25,14 +25,107 @@ export class WebRTCStream {
} }
this._onConnected = connectedCallback; this._onConnected = connectedCallback;
this._setup(serverURL, roomName);
}
private _setup(serverURL: string, roomName: string) {
console.log("Setting up WebSocket"); console.log("Setting up WebSocket");
// Replace http/https with ws/wss // Replace http/https with ws/wss
const wsURL = serverURL.replace(/^http/, "ws"); const wsURL = serverURL.replace(/^http/, "ws");
this._ws = new WebSocket(`${wsURL}/api/ws/${roomName}`); this._ws = new WebSocket(`${wsURL}/api/ws/${roomName}`);
this._ws.onopen = async () => { this._ws.onopen = async () => {
console.log("WebSocket opened"); console.log("WebSocket opened");
// Send join message
const joinMessage: MessageJoin = {
payload_type: "join",
joiner_type: JoinerType.JoinerClient
};
this._ws!.send(encodeMessage(joinMessage));
}
let iceHolder: RTCIceCandidateInit[] = [];
this._ws.onmessage = async (e) => {
// allow only binary
if (typeof e.data !== "object") return;
if (!e.data) return;
const message = await decodeMessage<MessageBase>(e.data);
switch (message.payload_type) {
case "sdp":
if (!this._pc) {
// Setup peer connection now
this._setupPeerConnection();
}
console.log("Received SDP: ", (message as MessageSDP).sdp);
await this._pc!.setRemoteDescription((message as MessageSDP).sdp);
// Create our answer
const answer = await this._pc!.createAnswer();
// Force stereo in Chromium browsers
answer.sdp = this.forceOpusStereo(answer.sdp!);
await this._pc!.setLocalDescription(answer);
this._ws!.send(encodeMessage({
payload_type: "sdp",
sdp: answer
}));
break;
case "ice":
if (!this._pc) break;
// If remote description is not set yet, hold the ICE candidates
if (this._pc.remoteDescription) {
await this._pc.addIceCandidate((message as MessageICE).candidate);
// Add held ICE candidates
for (const ice of iceHolder) {
await this._pc.addIceCandidate(ice);
}
iceHolder = [];
} else {
iceHolder.push((message as MessageICE).candidate);
}
break;
case "answer":
switch ((message as MessageAnswer).answer_type) {
case AnswerType.AnswerOffline:
console.log("Room is offline");
// Call callback with null stream
if (this._onConnected)
this._onConnected(null);
break;
case AnswerType.AnswerInUse:
console.warn("Room is in use, we shouldn't even be getting this message");
break;
case AnswerType.AnswerOK:
console.log("Joining Room was successful");
break;
}
break;
default:
console.error("Unknown message type: ", message);
}
}
this._ws.onclose = () => {
console.log("WebSocket closed, reconnecting in 3 seconds");
if (this._onConnected)
this._onConnected(null);
// Clear PeerConnection
if (this._pc) {
this._pc.close();
this._pc = undefined;
}
setTimeout(() => {
this._setup(serverURL, roomName);
}, 3000);
}
this._ws.onerror = (e) => {
console.error("WebSocket error: ", e);
}
}
private _setupPeerConnection() {
console.log("Setting up PeerConnection"); console.log("Setting up PeerConnection");
this._pc = new RTCPeerConnection({ this._pc = new RTCPeerConnection({
iceServers: [ iceServers: [
@@ -69,77 +162,6 @@ export class WebRTCStream {
this._dataChannel = e.channel; this._dataChannel = e.channel;
this._setupDataChannelEvents(); this._setupDataChannelEvents();
} }
// Send join message
const joinMessage: MessageJoin = {
payload_type: "join",
joiner_type: JoinerType.JoinerClient
};
this._ws!.send(encodeMessage(joinMessage));
}
let iceHolder: RTCIceCandidateInit[] = [];
this._ws.onmessage = async (e) => {
// allow only binary
if (typeof e.data !== "object") return;
if (!e.data) return;
const message = await decodeMessage<MessageBase>(e.data);
switch (message.payload_type) {
case "sdp":
await this._pc!.setRemoteDescription((message as MessageSDP).sdp);
// Create our answer
const answer = await this._pc!.createAnswer();
// Force stereo in Chromium browsers
answer.sdp = this.forceOpusStereo(answer.sdp!);
await this._pc!.setLocalDescription(answer);
this._ws!.send(encodeMessage({
payload_type: "sdp",
sdp: answer
}));
break;
case "ice":
// If remote description is not set yet, hold the ICE candidates
if (this._pc!.remoteDescription) {
await this._pc!.addIceCandidate((message as MessageICE).candidate);
// Add held ICE candidates
for (const ice of iceHolder) {
await this._pc!.addIceCandidate(ice);
}
iceHolder = [];
} else {
iceHolder.push((message as MessageICE).candidate);
}
break;
case "answer":
switch ((message as MessageAnswer).answer_type) {
case AnswerType.AnswerOffline:
console.log("Room is offline");
// Call callback with null stream
if (this._onConnected)
this._onConnected(null);
break;
case AnswerType.AnswerInUse:
console.warn("Room is in use, we shouldn't even be getting this message");
break;
case AnswerType.AnswerOK:
console.log("Joining Room was successful");
break;
}
break;
default:
console.error("Unknown message type: ", message);
}
}
this._ws.onclose = () => {
console.log("WebSocket closed");
}
this._ws.onerror = (e) => {
console.error("WebSocket error: ", e);
}
} }
// Forces opus to stereo in Chromium browsers, because of course // Forces opus to stereo in Chromium browsers, because of course

View File

@@ -200,11 +200,15 @@ func ingestHandler(room *Room) {
}) })
room.WebSocket.RegisterOnClose(func() { room.WebSocket.RegisterOnClose(func() {
// If PeerConnection is not open or does not exist, delete room // If PeerConnection is still open, close it
if (room.PeerConnection != nil && room.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateConnected) || if room.PeerConnection != nil {
room.PeerConnection == nil { if err = room.PeerConnection.Close(); err != nil {
DeleteRoomIfEmpty(room) log.Printf("Failed to close PeerConnection for room: '%s' - reason: %s\n", room.Name, err)
} }
room.PeerConnection = nil
}
room.Online = false
DeleteRoomIfEmpty(room)
}) })
log.Printf("Room: '%s' is ready, sending an OK\n", room.Name) log.Printf("Room: '%s' is ready, sending an OK\n", room.Name)

View File

@@ -33,7 +33,7 @@ func (vw *Participant) addTrack(trackLocal *webrtc.TrackLocal) error {
rtcpBuffer := make([]byte, 1400) rtcpBuffer := make([]byte, 1400)
for { for {
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil { if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
return break
} }
} }
}() }()

View File

@@ -10,6 +10,7 @@ import (
type SafeWebSocket struct { type SafeWebSocket struct {
*websocket.Conn *websocket.Conn
sync.Mutex sync.Mutex
closeCallback func() // OnClose callback
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
} }
@@ -17,6 +18,7 @@ type SafeWebSocket struct {
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket { func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
ws := &SafeWebSocket{ ws := &SafeWebSocket{
Conn: conn, Conn: conn,
closeCallback: nil,
binaryCallbacks: make(map[string]OnMessageCallback), binaryCallbacks: make(map[string]OnMessageCallback),
} }
@@ -32,10 +34,6 @@ func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
} }
break break
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) { } else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If closing, just break
if GetFlags().Verbose {
log.Printf("WebSocket closing\n")
}
break break
} else if err != nil { } else if err != nil {
log.Printf("Failed to read WebSocket message, reason: %s\n", err) log.Printf("Failed to read WebSocket message, reason: %s\n", err)
@@ -62,6 +60,11 @@ func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
log.Printf("Unknown WebSocket message type: %d\n", kind) log.Printf("Unknown WebSocket message type: %d\n", kind)
} }
} }
// Call close callback
if ws.closeCallback != nil {
ws.closeCallback()
}
}() }()
return ws return ws
@@ -102,13 +105,12 @@ func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
// RegisterOnClose sets the callback for websocket closing // RegisterOnClose sets the callback for websocket closing
func (ws *SafeWebSocket) RegisterOnClose(callback func()) { func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
ws.SetCloseHandler(func(code int, text string) error { ws.closeCallback = func() {
// Clear our callbacks // Clear our callbacks
ws.Lock() ws.Lock()
ws.binaryCallbacks = nil ws.binaryCallbacks = nil
ws.Unlock() ws.Unlock()
// Call the callback // Call the callback
callback() callback()
return nil }
})
} }

View File

@@ -0,0 +1,65 @@
#!/bin/bash
set -euo pipefail
# Wait for dbus socket to be ready
echo "Waiting for DBus system bus socket..."
DBUS_SOCKET="/run/dbus/system_bus_socket"
for _ in {1..10}; do # Wait up to 10 seconds
if [ -e "$DBUS_SOCKET" ]; then
echo "DBus system bus socket is ready."
break
fi
sleep 1
done
if [ ! -e "$DBUS_SOCKET" ]; then
echo "Error: DBus system bus socket did not appear. Exiting."
exit 1
fi
# Wait for PipeWire to be ready
echo "Waiting for PipeWire socket..."
PIPEWIRE_SOCKET="/run/user/${UID}/pipewire-0"
for _ in {1..10}; do # Wait up to 10 seconds
if [ -e "$PIPEWIRE_SOCKET" ]; then
echo "PipeWire socket is ready."
break
fi
sleep 1
done
if [ ! -e "$PIPEWIRE_SOCKET" ]; then
echo "Error: PipeWire socket did not appear. Exiting."
exit 1
fi
# Update system packages before proceeding
echo "Upgrading system packages..."
pacman -Syu --noconfirm
echo "Detecting GPU vendor and installing necessary GStreamer plugins..."
source /etc/nestri/gpu_helpers.sh
get_gpu_info
# Identify vendor
if [[ "${vendor_full_map[0],,}" =~ "intel" ]]; then
echo "Intel GPU detected, installing required packages..."
chwd -a
pacman -Syu --noconfirm gstreamer-vaapi gst-plugin-va gst-plugin-qsv
# chwd missed a thing
pacman -Syu --noconfirm vpl-gpu-rt
elif [[ "${vendor_full_map[0],,}" =~ "amd" ]]; then
echo "AMD GPU detected, installing required packages..."
chwd -a
pacman -Syu --noconfirm gstreamer-vaapi gst-plugin-va
elif [[ "${vendor_full_map[0],,}" =~ "nvidia" ]]; then
echo "NVIDIA GPU detected. Assuming drivers are linked"
else
echo "Unknown GPU vendor. No additional packages will be installed"
fi
# Clean up remainders
echo "Cleaning up old package cache..."
paccache -rk1
echo "Switching to nestri user for application startup..."
exec sudo -E -u nestri /etc/nestri/entrypoint_nestri.sh

View File

@@ -0,0 +1,156 @@
#!/bin/bash
set -euo pipefail
# Source environment variables from envs.sh
if [ -f /etc/nestri/envs.sh ]; then
echo "Sourcing environment variables from envs.sh..."
source /etc/nestri/envs.sh
else
echo "envs.sh not found! Ensure it exists at /etc/nestri/envs.sh."
exit 1
fi
# Configuration
MAX_RETRIES=3
# Helper function to restart the chain
restart_chain() {
echo "Restarting nestri-server, labwc, and wlr-randr..."
# Kill all child processes safely (if running)
if [[ -n "${NESTRI_PID:-}" ]] && kill -0 "${NESTRI_PID}" 2 >/dev/null; then
kill "${NESTRI_PID}"
fi
if [[ -n "${LABWC_PID:-}" ]] && kill -0 "${LABWC_PID}" 2 >/dev/null; then
kill "${LABWC_PID}"
fi
wait || true
# Start nestri-server
start_nestri_server
RETRY_COUNT=0
}
# Function to start nestri-server
start_nestri_server() {
echo "Starting nestri-server..."
nestri-server $(echo $NESTRI_PARAMS) &
NESTRI_PID=$!
# Wait for Wayland display (wayland-1) to be ready
echo "Waiting for Wayland display 'wayland-1' to be ready..."
WAYLAND_SOCKET="/run/user/${UID}/wayland-1"
for _ in {1..15}; do # Wait up to 15 seconds
if [ -e "$WAYLAND_SOCKET" ]; then
echo "Wayland display 'wayland-1' is ready."
start_labwc
return
fi
sleep 1
done
echo "Error: Wayland display 'wayland-1' did not appear. Incrementing retry count..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for nestri-server. Exiting."
exit 1
fi
restart_chain
}
# Function to start labwc
start_labwc() {
echo "Starting labwc..."
rm -rf /tmp/.X11-unix && mkdir -p /tmp/.X11-unix && chown nestri:nestri /tmp/.X11-unix
WAYLAND_DISPLAY=wayland-1 WLR_BACKENDS=wayland labwc &
LABWC_PID=$!
# Wait for labwc to initialize (using `wlr-randr` as an indicator)
echo "Waiting for labwc to initialize..."
for _ in {1..15}; do
if wlr-randr --json | jq -e '.[] | select(.enabled == true)' >/dev/null; then
echo "labwc is initialized and Wayland outputs are ready."
start_wlr_randr
return
fi
sleep 1
done
echo "Error: labwc did not initialize correctly. Incrementing retry count..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for labwc. Exiting."
exit 1
fi
restart_chain
}
# Function to run wlr-randr
start_wlr_randr() {
echo "Configuring resolution with wlr-randr..."
OUTPUT_NAME=$(wlr-randr --json | jq -r '.[] | select(.enabled == true) | .name' | head -n 1)
if [ -z "$OUTPUT_NAME" ]; then
echo "Error: No enabled outputs detected. Skipping wlr-randr."
return
fi
# Retry logic for wlr-randr
local WLR_RETRIES=0
while ! WAYLAND_DISPLAY=wayland-0 wlr-randr --output "$OUTPUT_NAME" --custom-mode "$RESOLUTION"; do
echo "Error: Failed to configure wlr-randr. Retrying..."
((WLR_RETRIES++))
if [ "$WLR_RETRIES" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for wlr-randr. Moving on without resolution setup."
return
fi
sleep 2
done
echo "wlr-randr configuration successful."
}
# Main loop to monitor processes
main_loop() {
trap 'echo "Terminating...";
if [[ -n "${NESTRI_PID:-}" ]] && kill -0 "${NESTRI_PID}" 2>/dev/null; then
kill "${NESTRI_PID}"
fi
if [[ -n "${LABWC_PID:-}" ]] && kill -0 "${LABWC_PID}" 2>/dev/null; then
kill "${LABWC_PID}"
fi
exit 0' SIGINT SIGTERM
while true; do
# Wait for any child process to exit
wait -n
# Check which process exited
if ! kill -0 ${NESTRI_PID:-} 2 >/dev/null; then
echo "nestri-server crashed. Restarting chain..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for nestri-server. Exiting."
exit 1
fi
restart_chain
elif ! kill -0 ${LABWC_PID:-} 2 >/dev/null; then
echo "labwc crashed. Restarting labwc and wlr-randr..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for labwc. Exiting."
exit 1
fi
start_labwc
fi
done
}
# Initialize retry counter
RETRY_COUNT=0
# Start the initial chain
restart_chain
# Enter monitoring loop
main_loop

View File

@@ -1,4 +1,5 @@
#!/bin/bash -e #!/bin/bash
set -euo pipefail
export XDG_RUNTIME_DIR=/run/user/${UID}/ export XDG_RUNTIME_DIR=/run/user/${UID}/
export WAYLAND_DISPLAY=wayland-0 export WAYLAND_DISPLAY=wayland-0
@@ -9,4 +10,4 @@ export $(dbus-launch)
export PROTON_NO_FSYNC=1 export PROTON_NO_FSYNC=1
# Our preferred prefix # Our preferred prefix
export WINEPREFIX=${USER_HOME}/.nestripfx/ export WINEPREFIX=/home/${USER}/.nestripfx/

View File

@@ -1,277 +1,52 @@
#!/bin/bash -e #!/bin/bash
set -euo pipefail
# Various helper functions for handling available GPUs declare -a vendor_full_map=()
declare -a vendor_id_map=()
declare -A vendor_index_map=()
declare -ga gpu_map function get_gpu_info {
declare -gA gpu_bus_map # Initialize arrays/maps to avoid unbound variable errors
declare -gA gpu_card_map
declare -gA gpu_product_map
declare -gA vendor_index_map
declare -gA vendor_full_map
# Map to help get shorter vendor identifiers
declare -A vendor_keywords=(
["advanced micro devices"]='amd'
["ati"]='amd'
["amd"]='amd'
["radeon"]='amd'
["nvidia"]='nvidia'
["intel"]='intel'
)
get_gpu_info() {
# Clear out previous data
gpu_map=()
gpu_bus_map=()
gpu_card_map=()
gpu_product_map=()
vendor_index_map=()
vendor_full_map=() vendor_full_map=()
vendor_id_map=()
vendor_index_map=()
local vendor="" # Use lspci to detect GPU info
local product="" gpu_info=$(lspci | grep -i 'vga\|3d\|display')
local bus_info=""
local vendor_full=""
while read -r line; do # Parse each line of GPU info
line="${line##*( )}" while IFS= read -r line; do
# Extract vendor name and ID from lspci output
vendor=$(echo "$line" | awk -F: '{print $3}' | sed -E 's/^[[:space:]]+//g' | tr '[:upper:]' '[:lower:]')
id=$(echo "$line" | awk '{print $1}')
if [[ "${line,,}" =~ "vendor:" ]]; then # Normalize vendor name
vendor="" if [[ $vendor =~ .*nvidia.* ]]; then
vendor_full=$(echo "$line" | awk '{$1=""; print $0}' | xargs) vendor="nvidia"
elif [[ $vendor =~ .*intel.* ]]; then
# Look for short vendor keyword in line vendor="intel"
for keyword in "${!vendor_keywords[@]}"; do elif [[ $vendor =~ .*advanced[[:space:]]micro[[:space:]]devices.* ]]; then
if [[ "${line,,}" == *"$keyword"* ]]; then vendor="amd"
vendor="${vendor_keywords[$keyword]}" elif [[ $vendor =~ .*ati.* ]]; then
break vendor="amd"
fi
done
# If no vendor keywords match, use first word
if [[ -z "$vendor" ]]; then
vendor=$(echo "$vendor_full" | awk '{print tolower($1)}')
fi
elif [[ "${line,,}" =~ "product:" ]]; then
product=$(echo "$line" | awk '{$1=""; print $0}' | xargs)
elif [[ "${line,,}" =~ "bus info:" ]]; then
bus_info=$(echo "$line" | awk '{print $3}')
fi
if [[ -n "$vendor" && -n "$product" && -n "$bus_info" && ! "${line,,}" =~ \*"-display" ]]; then
# We have gathered all GPU info necessary, store it
# Check if vendor index is being tracked
if [[ -z "${vendor_index_map[$vendor]}" ]]; then
# Start new vendor index tracking
vendor_index_map[$vendor]=0
else else
# Another GPU of same vendor, increment index vendor="unknown"
vendor_index_map[$vendor]="$((vendor_index_map[$vendor] + 1))"
fi fi
# Resolved GPU index # Add to arrays/maps if unique
local gpu_index="${vendor_index_map[$vendor]}" if ! [[ "${vendor_index_map[$vendor]:-}" ]]; then
local gpu_key="$vendor:$gpu_index" vendor_index_map[$vendor]="${#vendor_full_map[@]}"
vendor_full_map+=("$vendor")
# Get /dev/dri/cardN of GPU
local gpu_card=$({ ls -1d /sys/bus/pci/devices/*${bus_info#pci@}/drm/*; } 2>&1 | grep card* | grep -oP '(?<=card)\d+')
# Store info in maps
gpu_map+=("$gpu_key")
gpu_bus_map["$gpu_key"]="$bus_info"
gpu_product_map["$gpu_key"]="$product"
vendor_full_map["$gpu_key"]="$vendor_full"
if [[ -n "$gpu_card" ]]; then
gpu_card_map["$gpu_key"]="$gpu_card"
fi fi
vendor_id_map+=("$id")
# Clear values for additional GPUs done <<< "$gpu_info"
vendor=""
product=""
bus_info=""
vendor_full=""
fi
if [[ "${line,,}" =~ \*"-display" ]]; then
# New GPU found before storing, clear incomplete values to prevent mixing
vendor=""
product=""
bus_info=""
vendor_full=""
fi
done < <(sudo lshw -c video)
} }
check_and_populate_gpus() { function debug_gpu_info {
if [[ "${#gpu_map[@]}" -eq 0 ]]; then echo "Vendor Full Map: ${vendor_full_map[*]}"
get_gpu_info # Gather info incase info not gathered yet echo "Vendor ID Map: ${vendor_id_map[*]}"
if [[ "${#gpu_map[@]}" -eq 0 ]]; then echo "Vendor Index Map:"
echo "No GPUs found on this system" >&2 for key in "${!vendor_index_map[@]}"; do
return 1 echo " $key: ${vendor_index_map[$key]}"
fi
fi
}
check_selected_gpu() {
local selected_gpu="${1,,}"
if [[ ! " ${gpu_map[*]} " =~ " $selected_gpu " ]]; then
echo "No such GPU: '$selected_gpu'" >&2
return 1
fi
echo "$selected_gpu"
}
list_available_gpus() {
if ! check_and_populate_gpus; then
return 1
fi
echo "Available GPUs:" >&2
for gpu in "${gpu_map[@]}"; do
echo " [$gpu] \"${gpu_product_map[$gpu]}\" @[${gpu_bus_map[$gpu]}]"
done done
} }
convert_bus_id_to_xorg() {
local bus_info="$1"
IFS=":." read -ra bus_parts <<< "${bus_info#pci@????:}" # Remove "pci@" and the following 4 characters (domain)
# Check if bus_info has the correct format (at least 3 parts after removing domain)
if [[ "${#bus_parts[@]}" -lt 3 ]]; then
echo "Invalid bus info format: $bus_info" >&2
return 1
fi
# Convert each part from hexadecimal to decimal
bus_info_xorg="PCI:"
for part in "${bus_parts[@]}"; do
bus_info_xorg+="$((16#$part)):"
done
bus_info_xorg="${bus_info_xorg%:}" # Remove the trailing colon
echo "$bus_info_xorg"
}
print_gpu_info() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "[$selected_gpu]"
echo " Vendor: ${vendor_full_map[$selected_gpu]}"
echo " Product: ${gpu_product_map[$selected_gpu]}"
echo " Bus: ${gpu_bus_map[$selected_gpu]}"
# Check if card path was found
if [[ "${gpu_card_map[$selected_gpu]}" ]]; then
echo " Card: /dev/dri/card${gpu_card_map[$selected_gpu]}"
fi
echo
}
get_gpu_vendor() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${selected_gpu%%:*}"
}
get_gpu_vendor_full() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${vendor_full_map[$selected_gpu]}"
}
get_gpu_index() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${selected_gpu#*:}"
}
get_gpu_product() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${gpu_product_map[$selected_gpu]}"
}
get_gpu_bus() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${gpu_bus_map[$selected_gpu]}"
}
get_gpu_bus_xorg() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo $(convert_bus_id_to_xorg "${gpu_bus_map[$selected_gpu]}")
}
get_gpu_card() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
# Check if card path was found
if [[ -z "${gpu_card_map[$selected_gpu]}" ]]; then
echo "No card device found for GPU: $selected_gpu" >&2
return 1
fi
echo "/dev/dri/card${gpu_card_map[$selected_gpu]}"
}

View File

@@ -0,0 +1,53 @@
[supervisord]
user=root
nodaemon=true
loglevel=info
logfile=/tmp/supervisord.log
[program:dbus]
user=root
command=dbus-daemon --system --nofork --nopidfile
autorestart=true
autostart=true
startretries=3
priority=1
[program:seatd]
user=root
command=seatd
autorestart=true
autostart=true
startretries=3
priority=2
[program:pipewire]
user=nestri
command=dbus-launch pipewire
autorestart=true
autostart=true
startretries=3
priority=3
[program:pipewire-pulse]
user=nestri
command=dbus-launch pipewire-pulse
autorestart=true
autostart=true
startretries=3
priority=4
[program:wireplumber]
user=nestri
command=dbus-launch wireplumber
autorestart=true
autostart=true
startretries=3
priority=5
[program:entrypoint]
user=root
command=/etc/nestri/entrypoint.sh
autorestart=false
autostart=true
startretries=0
priority=10

View File

@@ -9,7 +9,8 @@ path = "src/main.rs"
[dependencies] [dependencies]
gst.workspace = true gst.workspace = true
gst-app.workspace = true gst-webrtc.workspace = true
gstrswebrtc.workspace = true
serde = {version = "1.0.214", features = ["derive"] } serde = {version = "1.0.214", features = ["derive"] }
tokio = { version = "1.41.0", features = ["full"] } tokio = { version = "1.41.0", features = ["full"] }
clap = { version = "4.5.20", features = ["env"] } clap = { version = "4.5.20", features = ["env"] }

View File

@@ -42,7 +42,7 @@ impl Args {
.short('u') .short('u')
.long("relay-url") .long("relay-url")
.env("RELAY_URL") .env("RELAY_URL")
.help("Nestri relay URL") .help("Nestri relay URL"),
) )
.arg( .arg(
Arg::new("resolution") Arg::new("resolution")
@@ -64,7 +64,7 @@ impl Args {
Arg::new("room") Arg::new("room")
.long("room") .long("room")
.env("NESTRI_ROOM") .env("NESTRI_ROOM")
.help("Nestri room name/identifier") .help("Nestri room name/identifier"),
) )
.arg( .arg(
Arg::new("gpu-vendor") Arg::new("gpu-vendor")
@@ -110,7 +110,7 @@ impl Args {
Arg::new("video-encoder") Arg::new("video-encoder")
.long("video-encoder") .long("video-encoder")
.env("VIDEO_ENCODER") .env("VIDEO_ENCODER")
.help("Override video encoder (e.g. 'vah264enc')") .help("Override video encoder (e.g. 'vah264enc')"),
) )
.arg( .arg(
Arg::new("video-rate-control") Arg::new("video-rate-control")
@@ -165,7 +165,7 @@ impl Args {
Arg::new("audio-encoder") Arg::new("audio-encoder")
.long("audio-encoder") .long("audio-encoder")
.env("AUDIO_ENCODER") .env("AUDIO_ENCODER")
.help("Override audio encoder (e.g. 'opusenc')") .help("Override audio encoder (e.g. 'opusenc')"),
) )
.arg( .arg(
Arg::new("audio-rate-control") Arg::new("audio-rate-control")
@@ -188,6 +188,13 @@ impl Args {
.help("Maximum bitrate in kbps") .help("Maximum bitrate in kbps")
.default_value("192"), .default_value("192"),
) )
.arg(
Arg::new("dma-buf")
.long("dma-buf")
.env("DMA_BUF")
.help("Use DMA-BUF for pipeline")
.default_value("false"),
)
.get_matches(); .get_matches();
Self { Self {

View File

@@ -15,6 +15,9 @@ pub struct AppArgs {
pub relay_url: String, pub relay_url: String,
/// Nestri room name/identifier /// Nestri room name/identifier
pub room: String, pub room: String,
/// Experimental DMA-BUF support
pub dma_buf: bool,
} }
impl AppArgs { impl AppArgs {
pub fn from_matches(matches: &clap::ArgMatches) -> Self { pub fn from_matches(matches: &clap::ArgMatches) -> Self {
@@ -26,23 +29,33 @@ impl AppArgs {
debug_latency: matches.get_one::<String>("debug-latency").unwrap() == "true" debug_latency: matches.get_one::<String>("debug-latency").unwrap() == "true"
|| matches.get_one::<String>("debug-latency").unwrap() == "1", || matches.get_one::<String>("debug-latency").unwrap() == "1",
resolution: { resolution: {
let res = matches.get_one::<String>("resolution").unwrap().clone(); let res = matches
.get_one::<String>("resolution")
.unwrap_or(&"1280x720".to_string())
.clone();
let parts: Vec<&str> = res.split('x').collect(); let parts: Vec<&str> = res.split('x').collect();
if parts.len() >= 2 {
( (
parts[0].parse::<u32>().unwrap(), parts[0].parse::<u32>().unwrap_or(1280),
parts[1].parse::<u32>().unwrap(), parts[1].parse::<u32>().unwrap_or(720),
) )
} else {
(1280, 720)
}
}, },
framerate: matches framerate: matches
.get_one::<String>("framerate") .get_one::<String>("framerate")
.unwrap() .unwrap()
.parse::<u32>() .parse::<u32>()
.unwrap(), .unwrap_or(60),
relay_url: matches.get_one::<String>("relay-url").unwrap().clone(), relay_url: matches.get_one::<String>("relay-url").unwrap().clone(),
// Generate random room name if not provided // Generate random room name if not provided
room: matches.get_one::<String>("room") room: matches
.get_one::<String>("room")
.unwrap_or(&rand::random::<u32>().to_string()) .unwrap_or(&rand::random::<u32>().to_string())
.clone(), .clone(),
dma_buf: matches.get_one::<String>("dma-buf").unwrap() == "true"
|| matches.get_one::<String>("dma-buf").unwrap() == "1",
} }
} }
@@ -55,5 +68,6 @@ impl AppArgs {
println!("> framerate: {}", self.framerate); println!("> framerate: {}", self.framerate);
println!("> relay_url: {}", self.relay_url); println!("> relay_url: {}", self.relay_url);
println!("> room: {}", self.room); println!("> room: {}", self.room);
println!("> dma_buf: {}", self.dma_buf);
} }
} }

View File

@@ -197,6 +197,11 @@ fn is_encoder_supported(encoder: &String) -> bool {
gst::ElementFactory::find(encoder.as_str()).is_some() gst::ElementFactory::find(encoder.as_str()).is_some()
} }
fn set_element_property(element: &gst::Element, property: &str, value: &dyn ToValue) {
element.set_property(property, value.to_value());
}
/// Helper to set CQP value of known encoder /// Helper to set CQP value of known encoder
/// # Arguments /// # Arguments
/// * `encoder` - Information about the encoder. /// * `encoder` - Information about the encoder.
@@ -316,7 +321,7 @@ pub fn encoder_gop_params(encoder: &VideoEncoderInfo, gop_size: u32) -> VideoEnc
let prop_name = prop.name(); let prop_name = prop.name();
// Look for known keys // Look for known keys
if prop_name.to_lowercase().contains("gop") if prop_name.to_lowercase().contains("gop-size")
|| prop_name.to_lowercase().contains("int-max") || prop_name.to_lowercase().contains("int-max")
|| prop_name.to_lowercase().contains("max-dist") || prop_name.to_lowercase().contains("max-dist")
|| prop_name.to_lowercase().contains("intra-period-length") || prop_name.to_lowercase().contains("intra-period-length")

View File

@@ -1,21 +1,21 @@
mod args; mod args;
mod enc_helper; mod enc_helper;
mod gpu; mod gpu;
mod room;
mod websocket;
mod latency; mod latency;
mod messages; mod messages;
mod nestrisink;
mod websocket;
use crate::args::encoding_args; use crate::args::encoding_args;
use crate::nestrisink::NestriSignaller;
use crate::websocket::NestriWebSocket;
use futures_util::StreamExt;
use gst::prelude::*; use gst::prelude::*;
use gst_app::AppSink; use gstrswebrtc::signaller::Signallable;
use gstrswebrtc::webrtcsink::BaseWebRTCSink;
use std::error::Error; use std::error::Error;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use futures_util::StreamExt;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{Mutex};
use crate::websocket::{NestriWebSocket};
// Handles gathering GPU information and selecting the most suitable GPU // Handles gathering GPU information and selecting the most suitable GPU
fn handle_gpus(args: &args::Args) -> Option<gpu::GPUInfo> { fn handle_gpus(args: &args::Args) -> Option<gpu::GPUInfo> {
@@ -161,20 +161,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Begin connection attempt to the relay WebSocket endpoint // Begin connection attempt to the relay WebSocket endpoint
// replace any http/https with ws/wss // replace any http/https with ws/wss
let replaced_relay_url let replaced_relay_url = args
= args.app.relay_url.replace("http://", "ws://").replace("https://", "wss://"); .app
let ws_url = format!( .relay_url
"{}/api/ws/{}", .replace("http://", "ws://")
replaced_relay_url, .replace("https://", "wss://");
args.app.room, let ws_url = format!("{}/api/ws/{}", replaced_relay_url, args.app.room,);
);
// Setup our websocket // Setup our websocket
let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?); let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
log::set_max_level(log::LevelFilter::Info); log::set_max_level(log::LevelFilter::Info);
log::set_boxed_logger(Box::new(nestri_ws.clone())).unwrap(); log::set_boxed_logger(Box::new(nestri_ws.clone())).unwrap();
let _ = gst::init(); gst::init()?;
gstrswebrtc::plugin_register_static()?;
// Handle GPU selection // Handle GPU selection
let gpu = handle_gpus(&args); let gpu = handle_gpus(&args);
@@ -197,12 +197,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Handle audio encoder selection // Handle audio encoder selection
let audio_encoder = handle_encoder_audio(&args); let audio_encoder = handle_encoder_audio(&args);
/*** ROOM SETUP ***/
let room = Arc::new(Mutex::new(
room::Room::new(nestri_ws.clone()).await?,
));
/*** PIPELINE CREATION ***/ /*** PIPELINE CREATION ***/
// Create the pipeline
let pipeline = Arc::new(gst::Pipeline::new());
/* Audio */ /* Audio */
// Audio Source Element // Audio Source Element
let audio_source = match args.encoding.audio.capture_method { let audio_source = match args.encoding.audio.capture_method {
@@ -237,9 +235,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
}, },
); );
// Audio RTP Payloader Element
let audio_rtp_payloader = gst::ElementFactory::make("rtpopuspay").build()?;
/* Video */ /* Video */
// Video Source Element // Video Source Element
let video_source = gst::ElementFactory::make("waylanddisplaysrc").build()?; let video_source = gst::ElementFactory::make("waylanddisplaysrc").build()?;
@@ -248,13 +243,26 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Caps Filter Element (resolution, fps) // Caps Filter Element (resolution, fps)
let caps_filter = gst::ElementFactory::make("capsfilter").build()?; let caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let caps = gst::Caps::from_str(&format!( let caps = gst::Caps::from_str(&format!(
"video/x-raw,width={},height={},framerate={}/1,format=RGBx", "{},width={},height={},framerate={}/1{}",
args.app.resolution.0, args.app.resolution.1, args.app.framerate if args.app.dma_buf {
"video/x-raw(memory:DMABuf)"
} else {
"video/x-raw"
},
args.app.resolution.0,
args.app.resolution.1,
args.app.framerate,
if args.app.dma_buf { "" } else { ",format=RGBx" }
))?; ))?;
caps_filter.set_property("caps", &caps); caps_filter.set_property("caps", &caps);
// Video Tee Element // GL Upload Element
let video_tee = gst::ElementFactory::make("tee").build()?; let glupload = gst::ElementFactory::make("glupload").build()?;
// GL upload caps filter
let gl_caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let gl_caps = gst::Caps::from_str("video/x-raw(memory:VAMemory)")?;
gl_caps_filter.set_property("caps", &gl_caps);
// Video Converter Element // Video Converter Element
let video_converter = gst::ElementFactory::make("videoconvert").build()?; let video_converter = gst::ElementFactory::make("videoconvert").build()?;
@@ -263,82 +271,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?; let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
video_encoder_info.apply_parameters(&video_encoder, &args.app.verbose); video_encoder_info.apply_parameters(&video_encoder, &args.app.verbose);
// Required for AV1 - av1parse
let av1_parse = gst::ElementFactory::make("av1parse").build()?;
// Video RTP Payloader Element
let video_rtp_payloader = gst::ElementFactory::make(
format!("rtp{}pay", video_encoder_info.codec.to_gst_str()).as_str(),
)
.build()?;
/* Output */ /* Output */
// Audio AppSink Element // WebRTC sink Element
let audio_appsink = gst::ElementFactory::make("appsink").build()?; let signaller = NestriSignaller::new(nestri_ws.clone(), pipeline.clone());
audio_appsink.set_property("emit-signals", &true); let webrtcsink = BaseWebRTCSink::with_signaller(Signallable::from(signaller.clone()));
let audio_appsink = audio_appsink.downcast_ref::<AppSink>().unwrap(); webrtcsink.set_property_from_str("stun-server", "stun://stun.l.google.com:19302");
webrtcsink.set_property_from_str("congestion-control", "disabled");
// Video AppSink Element
let video_appsink = gst::ElementFactory::make("appsink").build()?;
video_appsink.set_property("emit-signals", &true);
let video_appsink = video_appsink.downcast_ref::<AppSink>().unwrap();
/* Debug */
// Debug Feed Element
let debug_latency = gst::ElementFactory::make("timeoverlay").build()?;
debug_latency.set_property_from_str("halignment", &"right");
debug_latency.set_property_from_str("valignment", &"bottom");
// Debug Sink Element
let debug_sink = gst::ElementFactory::make("ximagesink").build()?;
// Debug video converter
let debug_video_converter = gst::ElementFactory::make("videoconvert").build()?;
// Queues with max 2ms latency
let debug_queue = gst::ElementFactory::make("queue2").build()?;
debug_queue.set_property("max-size-time", &1000000u64);
let main_video_queue = gst::ElementFactory::make("queue2").build()?;
main_video_queue.set_property("max-size-time", &1000000u64);
let main_audio_queue = gst::ElementFactory::make("queue2").build()?;
main_audio_queue.set_property("max-size-time", &1000000u64);
// Create the pipeline
let pipeline = gst::Pipeline::new();
// Add elements to the pipeline // Add elements to the pipeline
pipeline.add_many(&[ pipeline.add_many(&[
&video_appsink.upcast_ref(), webrtcsink.upcast_ref(),
&video_rtp_payloader,
&video_encoder, &video_encoder,
&video_converter, &video_converter,
&video_tee,
&caps_filter, &caps_filter,
&video_source, &video_source,
&audio_appsink.upcast_ref(),
&audio_rtp_payloader,
&audio_encoder, &audio_encoder,
&audio_capsfilter, &audio_capsfilter,
&audio_rate, &audio_rate,
&audio_converter, &audio_converter,
&audio_source, &audio_source,
&main_video_queue,
&main_audio_queue,
])?; ])?;
// Add debug elements if debug is enabled // If DMA-BUF is enabled, add glupload and gl caps filter
if args.app.debug_feed { if args.app.dma_buf {
pipeline.add_many(&[&debug_sink, &debug_queue, &debug_video_converter])?; pipeline.add_many(&[&glupload, &gl_caps_filter])?;
}
// Add debug latency element if debug latency is enabled
if args.app.debug_latency {
pipeline.add(&debug_latency)?;
}
// Add AV1 parse element if AV1 is selected
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
pipeline.add(&av1_parse)?;
} }
// Link main audio branch // Link main audio branch
@@ -348,47 +304,29 @@ async fn main() -> Result<(), Box<dyn Error>> {
&audio_rate, &audio_rate,
&audio_capsfilter, &audio_capsfilter,
&audio_encoder, &audio_encoder,
&audio_rtp_payloader, webrtcsink.upcast_ref(),
&main_audio_queue,
&audio_appsink.upcast_ref(),
])?; ])?;
// If debug latency, add time overlay before tee // With DMA-BUF, also link glupload and it's caps
if args.app.debug_latency { if args.app.dma_buf {
gst::Element::link_many(&[&video_source, &caps_filter, &debug_latency, &video_tee])?; // Link video source to caps_filter, glupload, gl_caps_filter, video_converter, video_encoder, webrtcsink
} else {
gst::Element::link_many(&[&video_source, &caps_filter, &video_tee])?;
}
// Link debug branch if debug is enabled
if args.app.debug_feed {
gst::Element::link_many(&[ gst::Element::link_many(&[
&video_tee, &video_source,
&debug_video_converter, &caps_filter,
&debug_queue, &glupload,
&debug_sink, &gl_caps_filter,
])?;
}
// Link main video branch, if AV1, add av1_parse
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
gst::Element::link_many(&[
&video_tee,
&video_converter, &video_converter,
&video_encoder, &video_encoder,
&av1_parse, webrtcsink.upcast_ref(),
&video_rtp_payloader,
&main_video_queue,
&video_appsink.upcast_ref(),
])?; ])?;
} else { } else {
// Link video source to caps_filter, video_converter, video_encoder, webrtcsink
gst::Element::link_many(&[ gst::Element::link_many(&[
&video_tee, &video_source,
&caps_filter,
&video_converter, &video_converter,
&video_encoder, &video_encoder,
&video_rtp_payloader, webrtcsink.upcast_ref(),
&main_video_queue,
&video_appsink.upcast_ref(),
])?; ])?;
} }
@@ -397,21 +335,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
audio_source.set_property("do-timestamp", &true); audio_source.set_property("do-timestamp", &true);
pipeline.set_property("latency", &0u64); pipeline.set_property("latency", &0u64);
// Wrap the pipeline in Arc<Mutex> to safely share it
let pipeline = Arc::new(Mutex::new(pipeline));
// Run both pipeline and websocket tasks concurrently // Run both pipeline and websocket tasks concurrently
let result = tokio::try_join!( let result = run_pipeline(pipeline.clone()).await;
run_room(
room.clone(),
"audio/opus",
video_encoder_info.codec.to_mime_str(),
pipeline.clone(),
Arc::new(Mutex::new(audio_appsink.stream())),
Arc::new(Mutex::new(video_appsink.stream()))
),
run_pipeline(pipeline.clone())
);
match result { match result {
Ok(_) => log::info!("All tasks completed successfully"), Ok(_) => log::info!("All tasks completed successfully"),
@@ -424,53 +349,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} }
async fn run_room( async fn run_pipeline(pipeline: Arc<gst::Pipeline>) -> Result<(), Box<dyn Error>> {
room: Arc<Mutex<room::Room>>, let bus = { pipeline.bus().ok_or("Pipeline has no bus")? };
audio_codec: &str,
video_codec: &str,
pipeline: Arc<Mutex<gst::Pipeline>>,
audio_stream: Arc<Mutex<AppSinkStream>>,
video_stream: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
// Run loop, with recovery on error
loop {
let mut room = room.lock().await;
tokio::select! {
_ = tokio::signal::ctrl_c() => {
log::info!("Room interrupted via Ctrl+C");
return Ok(());
}
result = room.run(
audio_codec,
video_codec,
pipeline.clone(),
audio_stream.clone(),
video_stream.clone(),
) => {
if let Err(e) = result {
log::error!("Room error: {}", e);
// Sleep for a while before retrying
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
} else {
return Ok(());
}
}
}
}
}
async fn run_pipeline(
pipeline: Arc<Mutex<gst::Pipeline>>,
) -> Result<(), Box<dyn Error>> {
// Take ownership of the bus without holding the lock
let bus = {
let pipeline = pipeline.lock().await;
pipeline.bus().ok_or("Pipeline has no bus")?
};
{ {
// Temporarily lock the pipeline to change state
let pipeline = pipeline.lock().await;
if let Err(e) = pipeline.set_state(gst::State::Playing) { if let Err(e) = pipeline.set_state(gst::State::Playing) {
log::error!("Failed to start pipeline: {}", e); log::error!("Failed to start pipeline: {}", e);
return Err("Failed to start pipeline".into()); return Err("Failed to start pipeline".into());
@@ -491,8 +373,6 @@ async fn run_pipeline(
} }
{ {
// Temporarily lock the pipeline to reset state
let pipeline = pipeline.lock().await;
pipeline.set_state(gst::State::Null)?; pipeline.set_state(gst::State::Null)?;
} }

View File

@@ -10,6 +10,31 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::latency::LatencyTracker; use crate::latency::LatencyTracker;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct MessageBase { pub struct MessageBase {
pub payload_type: String, pub payload_type: String,

View File

@@ -0,0 +1,466 @@
use crate::messages::{
decode_message_as, encode_message, AnswerType, InputMessage, JoinerType, MessageAnswer,
MessageBase, MessageICE, MessageInput, MessageJoin, MessageSDP,
};
use crate::websocket::NestriWebSocket;
use glib::subclass::prelude::*;
use gst::glib;
use gst::prelude::*;
use gst_webrtc::{gst_sdp, WebRTCSDPType, WebRTCSessionDescription};
use gstrswebrtc::signaller::{Signallable, SignallableImpl};
use std::collections::HashSet;
use std::sync::{Arc, LazyLock};
use std::sync::{Mutex, RwLock};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub struct Signaller {
nestri_ws: RwLock<Option<Arc<NestriWebSocket>>>,
pipeline: RwLock<Option<Arc<gst::Pipeline>>>,
data_channel: RwLock<Option<gst_webrtc::WebRTCDataChannel>>,
}
impl Default for Signaller {
fn default() -> Self {
Self {
nestri_ws: RwLock::new(None),
pipeline: RwLock::new(None),
data_channel: RwLock::new(None),
}
}
}
impl Signaller {
pub fn set_nestri_ws(&self, nestri_ws: Arc<NestriWebSocket>) {
*self.nestri_ws.write().unwrap() = Some(nestri_ws);
}
pub fn set_pipeline(&self, pipeline: Arc<gst::Pipeline>) {
*self.pipeline.write().unwrap() = Some(pipeline);
}
pub fn get_pipeline(&self) -> Option<Arc<gst::Pipeline>> {
self.pipeline.read().unwrap().clone()
}
pub fn set_data_channel(&self, data_channel: gst_webrtc::WebRTCDataChannel) {
*self.data_channel.write().unwrap() = Some(data_channel);
}
/// Helper method to clean things up
fn register_callbacks(&self) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("sdp", move |data| {
if let Ok(message) = decode_message_as::<MessageSDP>(data) {
let sdp =
gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes()).unwrap();
let answer = WebRTCSessionDescription::new(WebRTCSDPType::Answer, sdp);
self_obj.emit_by_name::<()>(
"session-description",
&[&"unique-session-id", &answer],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode SDP message");
}
});
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("ice", move |data| {
if let Ok(message) = decode_message_as::<MessageICE>(data) {
let candidate = message.candidate;
let sdp_m_line_index = candidate.sdp_mline_index.unwrap_or(0) as u32;
let sdp_mid = candidate.sdp_mid;
self_obj.emit_by_name::<()>(
"handle-ice",
&[
&"unique-session-id",
&sdp_m_line_index,
&sdp_mid,
&candidate.candidate,
],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode ICE message");
}
});
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("answer", move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
gst::info!(gst::CAT_DEFAULT, "Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOK => {
gst::info!(gst::CAT_DEFAULT, "Received OK answer");
// Send our SDP offer
self_obj.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
}
AnswerType::AnswerInUse => {
gst::error!(gst::CAT_DEFAULT, "Room is in use by another node");
}
AnswerType::AnswerOffline => {
gst::warning!(gst::CAT_DEFAULT, "Room is offline");
}
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode answer");
}
});
}
{
let self_obj = self.obj().clone();
// After creating webrtcsink
self_obj.connect_closure(
"webrtcbin-ready",
false,
glib::closure!(move |signaller: &super::NestriSignaller,
_consumer_identifier: &str,
webrtcbin: &gst::Element| {
gst::info!(gst::CAT_DEFAULT, "Adding data channels");
// Create data channels on webrtcbin
let data_channel = Some(
webrtcbin.emit_by_name::<gst_webrtc::WebRTCDataChannel>(
"create-data-channel",
&[
&"nestri-data-channel",
&gst::Structure::builder("config")
.field("ordered", &false)
.field("max-retransmits", &0u32)
.build(),
],
),
);
if let Some(data_channel) = data_channel {
gst::info!(gst::CAT_DEFAULT, "Data channel created");
if let Some(pipeline) = signaller.imp().get_pipeline() {
setup_data_channel(&data_channel, &pipeline);
signaller.imp().set_data_channel(data_channel);
} else {
gst::error!(gst::CAT_DEFAULT, "Wayland display source not set");
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to create data channel");
}
}),
);
}
}
}
impl SignallableImpl for Signaller {
fn start(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller started");
// Get WebSocket connection
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
// Register message callbacks
self.register_callbacks();
// Subscribe to reconnection notifications
let reconnected_notify = nestri_ws.subscribe_reconnected();
// Clone necessary references
let self_clone = self.obj().clone();
let nestri_ws_clone = nestri_ws.clone();
// Spawn a task to handle actions upon reconnection
tokio::spawn(async move {
loop {
// Wait for a reconnection notification
reconnected_notify.notified().await;
println!("Reconnected to relay, re-negotiating...");
gst::warning!(gst::CAT_DEFAULT, "Reconnected to relay, re-negotiating...");
// Emit "session-ended" first to make sure the element is cleaned up
self_clone.emit_by_name::<bool>("session-ended", &[&"unique-session-id"]);
// Send a new join message
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws_clone.send_message(encoded) {
gst::error!(
gst::CAT_DEFAULT,
"Failed to send join message after reconnection: {:?}",
e
);
}
} else {
gst::error!(
gst::CAT_DEFAULT,
"Failed to encode join message after reconnection"
);
}
// If we need to interact with GStreamer or GLib, schedule it on the main thread
let self_clone_for_main = self_clone.clone();
glib::MainContext::default().invoke(move || {
// Emit the "session-requested" signal
self_clone_for_main.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
});
}
});
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send join message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send join message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode join message");
}
}
fn stop(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller stopped");
}
fn send_sdp(&self, _session_id: &str, sdp: &WebRTCSessionDescription) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
let sdp_message = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
},
sdp: RTCSessionDescription::offer(sdp.sdp().as_text().unwrap()).unwrap(),
};
if let Ok(encoded) = encode_message(&sdp_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send SDP message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send SDP message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode SDP message");
}
}
fn add_ice(
&self,
_session_id: &str,
candidate: &str,
sdp_m_line_index: u32,
sdp_mid: Option<String>,
) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
let candidate_init = RTCIceCandidateInit {
candidate: candidate.to_string(),
sdp_mid,
sdp_mline_index: Some(sdp_m_line_index as u16),
..Default::default()
};
let ice_message = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
},
candidate: candidate_init,
};
if let Ok(encoded) = encode_message(&ice_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send ICE message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send ICE message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode ICE message");
}
}
fn end_session(&self, session_id: &str) {
gst::info!(gst::CAT_DEFAULT, "Ending session: {}", session_id);
}
}
#[glib::object_subclass]
impl ObjectSubclass for Signaller {
const NAME: &'static str = "NestriSignaller";
type Type = super::NestriSignaller;
type ParentType = glib::Object;
type Interfaces = (Signallable,);
}
impl ObjectImpl for Signaller {
fn properties() -> &'static [glib::ParamSpec] {
static PROPS: LazyLock<Vec<glib::ParamSpec>> = LazyLock::new(|| {
vec![glib::ParamSpecBoolean::builder("manual-sdp-munging")
.nick("Manual SDP munging")
.blurb("Whether the signaller manages SDP munging itself")
.default_value(false)
.read_only()
.build()]
});
PROPS.as_ref()
}
fn property(&self, _id: usize, pspec: &glib::ParamSpec) -> glib::Value {
match pspec.name() {
"manual-sdp-munging" => false.to_value(),
_ => unimplemented!(),
}
}
}
fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, pipeline: &gst::Pipeline) {
let pipeline = pipeline.clone();
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
data_channel.connect_on_message_data(move |_data_channel, data| {
if let Some(data) = data {
match decode_message_as::<MessageInput>(data.to_vec()) {
Ok(message_input) => {
// Deserialize the input message data
if let Ok(input_msg) = serde_json::from_str::<InputMessage>(&message_input.data)
{
// Process the input message and create an event
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons)
{
// Send the event to pipeline, result bool is ignored
let _ = pipeline.send_event(event);
}
} else {
eprintln!("Failed to parse InputMessage");
}
}
Err(e) => {
eprintln!("Failed to decode MessageInput: {:?}", e);
}
}
}
});
}
fn handle_input_message(
input_msg: InputMessage,
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
match input_msg {
InputMessage::MouseMove { x, y } => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseMoveAbs { x, y } => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().unwrap();
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
}
keys.insert(key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().unwrap();
// Remove the key from the pressed state when released
keys.remove(&key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::Wheel { x, y } => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", x as f64)
.field("y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseDown { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&key) {
return None;
}
buttons.insert(key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseUp { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// Remove the button from the pressed state when released
buttons.remove(&key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
}
}

View File

@@ -0,0 +1,25 @@
use std::sync::Arc;
use gst::glib;
use gst::subclass::prelude::*;
use gstrswebrtc::signaller::Signallable;
use crate::websocket::NestriWebSocket;
mod imp;
glib::wrapper! {
pub struct NestriSignaller(ObjectSubclass<imp::Signaller>) @implements Signallable;
}
impl NestriSignaller {
pub fn new(nestri_ws: Arc<NestriWebSocket>, pipeline: Arc<gst::Pipeline>) -> Self {
let obj: Self = glib::Object::new();
obj.imp().set_nestri_ws(nestri_ws);
obj.imp().set_pipeline(pipeline);
obj
}
}
impl Default for NestriSignaller {
fn default() -> Self {
panic!("Cannot create NestriSignaller without NestriWebSocket");
}
}

View File

@@ -1,540 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::collections::{HashSet};
use std::error::Error;
use std::sync::Arc;
use futures_util::StreamExt;
use gst::prelude::ElementExtManual;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{oneshot, Mutex};
use tokio::sync::{mpsc};
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::ice_transport::ice_gathering_state::RTCIceGatheringState;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::TrackLocalWriter;
use crate::messages::*;
use crate::websocket::NestriWebSocket;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
pub struct Room {
nestri_ws: Arc<NestriWebSocket>,
webrtc_api: webrtc::api::API,
webrtc_config: RTCConfiguration,
}
impl Room {
pub async fn new(
nestri_ws: Arc<NestriWebSocket>,
) -> Result<Room, Box<dyn Error>> {
// Create media engine and register default codecs
let mut media_engine = MediaEngine::default();
media_engine.register_default_codecs()?;
// Registry
let mut registry = Registry::new();
registry = register_default_interceptors(registry, &mut media_engine)?;
// Create the API object with the MediaEngine
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
// Prepare the configuration
let config = RTCConfiguration {
ice_servers: vec![RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".to_owned()],
..Default::default()
}],
..Default::default()
};
Ok(Self {
nestri_ws,
webrtc_api: api,
webrtc_config: config,
})
}
pub async fn run(
&mut self,
audio_codec: &str,
video_codec: &str,
pipeline: Arc<Mutex<gst::Pipeline>>,
audio_sink: Arc<Mutex<AppSinkStream>>,
video_sink: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
let (tx, rx) = oneshot::channel();
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
self.nestri_ws
.register_callback("answer", {
let tx = tx.clone();
move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
log::info!("Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOffline => {
log::warn!("Room is offline, we shouldn't be receiving this");
}
AnswerType::AnswerInUse => {
log::error!("Room is in use by another node!");
}
AnswerType::AnswerOK => {
// Notify that we got an OK answer
if let Some(tx) = tx.lock().unwrap().take() {
if let Err(_) = tx.send(()) {
log::error!("Failed to send OK answer signal");
}
}
}
}
} else {
log::error!("Failed to decode answer");
}
}
})
.await;
// Send a request to join the room
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("Failed to encode join message");
return Err("Failed to encode join message".into());
}
// Wait for the signal indicating that we have received an OK answer
match rx.await {
Ok(()) => {
log::info!("Received OK answer, proceeding...");
}
Err(_) => {
log::error!("Oneshot channel closed unexpectedly");
return Err("Unexpected error while waiting for OK answer".into());
}
}
// Create a new RTCPeerConnection
let config = self.webrtc_config.clone();
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
// Create audio track
let audio_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: audio_codec.to_owned(),
..Default::default()
},
"audio".to_owned(),
"audio-nestri-server".to_owned(),
));
// Create video track
let video_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: video_codec.to_owned(),
..Default::default()
},
"video".to_owned(),
"video-nestri-server".to_owned(),
));
// Cancellation token to stop spawned tasks after peer connection is closed
let cancel_token = tokio_util::sync::CancellationToken::new();
// Add audio track to peer connection
let audio_sender = peer_connection.add_track(audio_track.clone()).await?;
let audio_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = audio_sender_token.cancelled() => {
break;
}
_ = audio_sender.read(&mut rtcp_buf) => {}
}
}
});
// Add video track to peer connection
let video_sender = peer_connection.add_track(video_track.clone()).await?;
let video_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = video_sender_token.cancelled() => {
break;
}
_ = video_sender.read(&mut rtcp_buf) => {}
}
}
});
// Create a datachannel with label 'input'
let data_channel_opts = Some(RTCDataChannelInit {
ordered: Some(false),
max_retransmits: Some(0),
..Default::default()
});
let data_channel
= peer_connection.create_data_channel("input", data_channel_opts).await?;
// PeerConnection state change tracker
let (pc_sndr, mut pc_recv) = mpsc::channel(1);
// Peer connection state change handler
peer_connection.on_peer_connection_state_change(Box::new(
move |s: RTCPeerConnectionState| {
let pc_sndr = pc_sndr.clone();
Box::pin(async move {
log::info!("PeerConnection State has changed: {s}");
if s == RTCPeerConnectionState::Failed
|| s == RTCPeerConnectionState::Disconnected
|| s == RTCPeerConnectionState::Closed
{
// Notify pc_state that the peer connection has closed
if let Err(e) = pc_sndr.send(s).await {
log::error!("Failed to send PeerConnection state: {}", e);
}
}
})
},
));
peer_connection.on_ice_gathering_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Gathering State has changed: {s}");
})
}));
peer_connection.on_ice_connection_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Connection State has changed: {s}");
})
}));
// Trickle ICE over WebSocket
let ws = self.nestri_ws.clone();
peer_connection.on_ice_candidate(Box::new(move |c| {
let nestri_ws = ws.clone();
Box::pin(async move {
if let Some(candidate) = c {
let candidate_json = candidate.to_json().unwrap();
let ice_msg = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
},
candidate: candidate_json,
};
if let Ok(encoded) = encode_message(&ice_msg) {
let _ = nestri_ws.send_message(encoded);
}
}
})
}));
// Temporary ICE candidate buffer until remote description is set
let ice_holder: Arc<Mutex<Vec<RTCIceCandidateInit>>> = Arc::new(Mutex::new(Vec::new()));
// Register set_response_callback for ICE candidate
let pc = peer_connection.clone();
let ice_clone = ice_holder.clone();
self.nestri_ws.register_callback("ice", move |data| {
match decode_message_as::<MessageICE>(data) {
Ok(message) => {
log::info!("Received ICE message");
let candidate = RTCIceCandidateInit::from(message.candidate);
let pc = pc.clone();
let ice_clone = ice_clone.clone();
tokio::spawn(async move {
// If remote description is not set, buffer ICE candidates
if pc.remote_description().await.is_none() {
let mut ice_holder = ice_clone.lock().await;
ice_holder.push(candidate);
} else {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
} else {
// Add any held ICE candidates
let mut ice_holder = ice_clone.lock().await;
for candidate in ice_holder.drain(..) {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
}
}
}
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
// Data channel message handler
data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
let pipeline = pipeline.clone();
let pressed_keys = pressed_keys.clone();
let pressed_buttons = pressed_buttons.clone();
Box::pin({
async move {
// We don't care about string messages for now
if !msg.is_string {
// Decode the message as an MessageInput (binary encoded gzip)
match decode_message_as::<MessageInput>(msg.data.to_vec()) {
Ok(message_input) => {
// Handle the input message
if let Ok(input_msg) = from_str::<InputMessage>(&message_input.data) {
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons).await
{
let _ = pipeline.lock().await.send_event(event);
}
}
}
Err(e) => {
log::error!("Failed to decode input message: {:?}", e);
}
}
}
}
})
}));
log::info!("Creating offer...");
// Create an offer to send to the browser
let offer = peer_connection.create_offer(None).await?;
log::info!("Setting local description...");
// Sets the LocalDescription, and starts our UDP listeners
peer_connection.set_local_description(offer).await?;
log::info!("Local description set...");
if let Some(local_description) = peer_connection.local_description().await {
// Wait until we have gathered all ICE candidates
while peer_connection.ice_gathering_state() != RTCIceGatheringState::Complete {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
// Register set_response_callback for SDP answer
let pc = peer_connection.clone();
self.nestri_ws.register_callback("sdp", move |data| {
match decode_message_as::<MessageSDP>(data) {
Ok(message) => {
log::info!("Received SDP message");
let sdp = message.sdp;
let pc = pc.clone();
tokio::spawn(async move {
if let Err(e) = pc.set_remote_description(sdp).await {
log::error!("Failed to set remote description: {}", e);
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
log::info!("Sending local description to remote...");
// Encode and send the local description via WebSocket
let sdp_msg = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
},
sdp: local_description,
};
let encoded = encode_message(&sdp_msg)?;
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("generate local_description failed!");
cancel_token.cancel();
return Err("generate local_description failed!".into());
};
// Send video and audio data
let audio_track = audio_track.clone();
tokio::spawn(async move {
let mut audio_sink = audio_sink.lock().await;
while let Some(sample) = audio_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = audio_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write audio track: {}", e);
}
}
}
}
}
});
let video_track = video_track.clone();
tokio::spawn(async move {
let mut video_sink = video_sink.lock().await;
while let Some(sample) = video_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = video_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write video track: {}", e);
}
}
}
}
}
});
// Block until closed or error
tokio::select! {
_ = pc_recv.recv() => {
log::info!("Peer connection closed with state: {:?}", peer_connection.connection_state());
}
}
cancel_token.cancel();
// Make double-sure to close the peer connection
if let Err(e) = peer_connection.close().await {
log::error!("Failed to close peer connection: {}", e);
}
Ok(())
}
}
async fn handle_input_message(
input_msg: InputMessage,
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
match input_msg {
InputMessage::MouseMove { x, y } => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseMoveAbs { x, y } => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().await;
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
}
keys.insert(key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().await;
// Remove the key from the pressed state when released
keys.remove(&key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::Wheel { x, y } => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", x as f64)
.field("y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseDown { key } => {
let mut buttons = pressed_buttons.lock().await;
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&key) {
return None;
}
buttons.insert(key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseUp { key } => {
let mut buttons = pressed_buttons.lock().await;
// Remove the button from the pressed state when released
buttons.remove(&key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
}
}

View File

@@ -1,18 +1,17 @@
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
use futures_util::sink::SinkExt; use futures_util::sink::SinkExt;
use futures_util::stream::{SplitSink, SplitStream}; use futures_util::stream::{SplitSink, SplitStream};
use futures_util::StreamExt; use futures_util::StreamExt;
use log::{Level, Log, Metadata, Record}; use log::{Level, Log, Metadata, Record};
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::sync::Arc; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::{mpsc, Mutex, Notify};
use tokio::sync::RwLock;
use tokio::time::sleep; use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>; type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>;
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>; type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
@@ -21,28 +20,36 @@ type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
#[derive(Clone)] #[derive(Clone)]
pub struct NestriWebSocket { pub struct NestriWebSocket {
ws_url: String, ws_url: String,
reader: Arc<Mutex<WSRead>>, reader: Arc<Mutex<Option<WSRead>>>,
writer: Arc<Mutex<WSWrite>>, writer: Arc<Mutex<Option<WSWrite>>>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>, callbacks: Arc<RwLock<HashMap<String, Callback>>>,
message_tx: mpsc::UnboundedSender<Vec<u8>>,
reconnected_notify: Arc<Notify>,
} }
impl NestriWebSocket { impl NestriWebSocket {
pub async fn new( pub async fn new(ws_url: String) -> Result<NestriWebSocket, Box<dyn Error>> {
ws_url: String,
) -> Result<NestriWebSocket, Box<dyn Error>> {
// Attempt to connect to the WebSocket // Attempt to connect to the WebSocket
let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap(); let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap();
// If the connection is successful, split the stream // Split the stream into read and write halves
let (write, read) = ws_stream.split(); let (write, read) = ws_stream.split();
let mut ws = NestriWebSocket {
// Create the message channel
let (message_tx, message_rx) = mpsc::unbounded_channel();
let ws = NestriWebSocket {
ws_url, ws_url,
reader: Arc::new(Mutex::new(read)), reader: Arc::new(Mutex::new(Some(read))),
writer: Arc::new(Mutex::new(write)), writer: Arc::new(Mutex::new(Some(write))),
callbacks: Arc::new(RwLock::new(HashMap::new())), callbacks: Arc::new(RwLock::new(HashMap::new())),
message_tx: message_tx.clone(),
reconnected_notify: Arc::new(Notify::new()),
}; };
// Spawn the read loop // Spawn the read loop
ws.spawn_read_loop(); ws.spawn_read_loop();
// Spawn the write loop
ws.spawn_write_loop(message_rx);
Ok(ws) Ok(ws)
} }
@@ -57,24 +64,37 @@ impl NestriWebSocket {
} }
Err(e) => { Err(e) => {
eprintln!("Failed to connect to WebSocket, retrying: {:?}", e); eprintln!("Failed to connect to WebSocket, retrying: {:?}", e);
sleep(Duration::from_secs(1)).await; // Wait before retrying sleep(Duration::from_secs(3)).await; // Wait before retrying
} }
} }
} }
} }
// Handles message -> callback calls and reconnects on error/disconnect // Handles message -> callback calls and reconnects on error/disconnect
fn spawn_read_loop(&mut self) { fn spawn_read_loop(&self) {
let reader = self.reader.clone(); let reader = self.reader.clone();
let callbacks = self.callbacks.clone(); let callbacks = self.callbacks.clone();
let self_clone = self.clone();
let mut self_clone = self.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let message = reader.lock().await.next().await; // Lock the reader to get the WSRead, then drop the lock
match message { let ws_read_option = {
Some(Ok(message)) => { let mut reader_lock = reader.lock().await;
reader_lock.take()
};
let mut ws_read = match ws_read_option {
Some(ws_read) => ws_read,
None => {
eprintln!("Reader is None, cannot proceed");
return;
}
};
while let Some(message_result) = ws_read.next().await {
match message_result {
Ok(message) => {
let data = message.into_data(); let data = message.into_data();
let base_message = match decode_message(&data) { let base_message = match decode_message(&data) {
Ok(base_message) => base_message, Ok(base_message) => base_message,
@@ -84,62 +104,120 @@ impl NestriWebSocket {
} }
}; };
let callbacks_lock = callbacks.read().await; let callbacks_lock = callbacks.read().unwrap();
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) { if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
let data = data.clone(); let data = data.clone();
callback(data); callback(data);
} }
} }
Some(Err(e)) => { Err(e) => {
eprintln!("Error receiving message: {:?}", e); eprintln!("Error receiving message: {:?}, reconnecting in 3 seconds...", e);
sleep(Duration::from_secs(3)).await;
self_clone.reconnect().await.unwrap(); self_clone.reconnect().await.unwrap();
break; // Break the inner loop to get a new ws_read
} }
None => {
eprintln!("WebSocket connection closed, reconnecting...");
self_clone.reconnect().await.unwrap();
} }
} }
// After reconnection, the loop continues, and we acquire a new ws_read
}
});
}
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<Vec<u8>>) {
let writer = self.writer.clone();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
// Wait for a message from the channel
if let Some(message) = message_rx.recv().await {
loop {
// Acquire the writer lock
let mut writer_lock = writer.lock().await;
if let Some(writer) = writer_lock.as_mut() {
// Try to send the message over the WebSocket
match writer.send(Message::Binary(message.clone())).await {
Ok(_) => {
// Message sent successfully
break;
}
Err(e) => {
eprintln!("Error sending message: {:?}", e);
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
eprintln!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
eprintln!("Writer is None, cannot send message");
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
eprintln!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
break;
}
} }
}); });
} }
async fn reconnect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> { async fn reconnect(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
// Keep trying to reconnect until successful
loop { loop {
match NestriWebSocket::do_connect(&self.ws_url).await { match NestriWebSocket::do_connect(&self.ws_url).await {
Ok(ws_stream) => { Ok(ws_stream) => {
let (write, read) = ws_stream.split(); let (write, read) = ws_stream.split();
*self.reader.lock().await = read; {
*self.writer.lock().await = write; let mut writer_lock = self.writer.lock().await;
*writer_lock = Some(write);
}
{
let mut reader_lock = self.reader.lock().await;
*reader_lock = Some(read);
}
// Notify subscribers of successful reconnection
self.reconnected_notify.notify_waiters();
return Ok(()); return Ok(());
} }
Err(e) => { Err(e) => {
eprintln!("Failed to reconnect to WebSocket: {:?}", e); eprintln!("Failed to reconnect to WebSocket: {:?}", e);
sleep(Duration::from_secs(2)).await; // Wait before retrying sleep(Duration::from_secs(3)).await; // Wait before retrying
} }
} }
} }
} }
pub async fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> { /// Send a message through the WebSocket
let mut writer_lock = self.writer.lock().await; pub fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
writer_lock self.message_tx
.send(Message::Binary(message)) .send(message)
.await .map_err(|e| format!("Failed to send message: {:?}", e).into())
.map_err(|e| format!("Error sending message: {:?}", e).into())
} }
pub async fn register_callback<F>(&self, response_type: &str, callback: F) /// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where where
F: Fn(Vec<u8>) + Send + Sync + 'static, F: Fn(Vec<u8>) + Send + Sync + 'static,
{ {
let mut callbacks_lock = self.callbacks.write().await; let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Box::new(callback)); callbacks_lock.insert(response_type.to_string(), Box::new(callback));
} }
/// Subscribe to event for reconnection
pub fn subscribe_reconnected(&self) -> Arc<Notify> {
self.reconnected_notify.clone()
}
} }
impl Log for NestriWebSocket { impl Log for NestriWebSocket {
fn enabled(&self, metadata: &Metadata) -> bool { fn enabled(&self, metadata: &Metadata) -> bool {
// Filter logs by level
metadata.level() <= Level::Info metadata.level() <= Level::Info
} }
@@ -162,7 +240,9 @@ impl Log for NestriWebSocket {
time, time,
}; };
if let Ok(encoded_message) = encode_message(&log_message) { if let Ok(encoded_message) = encode_message(&log_message) {
let _ = self.send_message(encoded_message); if let Err(e) = self.send_message(encoded_message) {
eprintln!("Failed to send log message: {:?}", e);
}
} }
} }
} }
@@ -171,4 +251,3 @@ impl Log for NestriWebSocket {
// No-op for this logger // No-op for this logger
} }
} }