feat: Custom gst webrtc signaller, runtime GPU driver package install and more (#140)

🔥 🔥

Yes lots of commits because rebasing and all.. thankfully I know Git
just enough to have backups 😅

---------

Co-authored-by: Wanjohi <elviswanjohi47@gmail.com>
Co-authored-by: Kristian Ollikainen <DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Wanjohi <71614375+wanjohiryan@users.noreply.github.com>
Co-authored-by: AquaWolf <3daquawolf@gmail.com>
This commit is contained in:
Kristian Ollikainen
2024-12-08 15:37:36 +02:00
committed by GitHub
parent 20d5ff511e
commit b6196b1c69
27 changed files with 3402 additions and 1349 deletions

2
.gitignore vendored
View File

@@ -48,5 +48,5 @@ bun.lockb
#tests
id_*
# Rust
#Rust
target

2036
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,5 +12,6 @@ edition = "2021"
rust-version = "1.80"
[workspace.dependencies]
gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", version = "0.24.0" }
gst-app = { package = "gstreamer-app", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", version = "0.24.0" }
gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_24"] }
gst-webrtc = { package = "gstreamer-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", branch = "main", features = ["v1_24"] }
gstrswebrtc = { package = "gst-plugin-webrtc", git = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs", branch = "main", features = ["v1_22"] }

206
Containerfile.master Normal file
View File

@@ -0,0 +1,206 @@
#! Runs the docker server that handles everything else
#******************************************************************************
# base
#******************************************************************************
FROM archlinux:base-20241027.0.273886 AS base
# How to run - docker run -it --rm --device /dev/dri nestri /bin/bash - DO NOT forget the ports
# TODO: Migrate XDG_RUNTIME_DIR to /run/user/1000
# TODO: Add nestri-server to pulseaudio.conf
# TODO: Add our own entrypoint, with our very own zombie ripper 🧟🏾‍♀️
# FIXME: Add user root to `pulse-access` group as well :D
# TODO: Test the whole damn thing
# Update the pacman repo
RUN \
pacman -Syu --noconfirm
#******************************************************************************
# builder
#******************************************************************************
FROM base AS builder
RUN \
pacman -Su --noconfirm \
base-devel \
git \
sudo \
vim
WORKDIR /scratch
# Allow nobody user to invoke pacman to install packages (as part of makepkg) and modify the system.
# This should never exist in a running image, just used by *-build Docker stages.
RUN \
echo "nobody ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers;
ENV ARTIFACTS=/artifacts \
CARGO_TARGET_DIR=/build
RUN \
mkdir -p /artifacts \
&& mkdir -p /build
RUN \
chgrp nobody /scratch /artifacts /build \
&& chmod g+ws /scratch /artifacts /build
#******************************************************************************
# rust-builder
#******************************************************************************
FROM builder AS rust-builder
RUN \
pacman -Su --noconfirm \
rustup
RUN \
rustup default stable
#******************************************************************************
# nestri-server-builder
#******************************************************************************
# Builds nestri server binary
FROM rust-builder AS nestri-server-builder
RUN \
pacman -Su --noconfirm \
wayland \
vpl-gpu-rt \
gstreamer \
gst-plugin-va \
gst-plugins-base \
gst-plugins-good \
mesa-utils \
weston \
xorg-xwayland
#******************************************************************************
# nestri-server-build
#******************************************************************************
FROM nestri-server-builder AS nestri-server-build
#Allow makepkg to be run as nobody.
RUN chgrp -R nobody /scratch && chmod -R g+ws /scratch
# USER nobody
# Perform the server build.
WORKDIR /scratch/server
RUN \
git clone https://github.com/nestriness/nestri
WORKDIR /scratch/server/nestri
RUN \
git checkout feat/stream \
&& cargo build -j$(nproc) --release
# COPY packages/server/build/ /scratch/server/
# RUN makepkg && cp *.zst "$ARTIFACTS"
#******************************************************************************
# runtime_base_pkgs
#******************************************************************************
FROM base AS runtime_base_pkgs
COPY --from=nestri-server-build /build/release/nestri-server /usr/bin/
#******************************************************************************
# runtime_base
#******************************************************************************
FROM runtime_base_pkgs AS runtime_base
RUN \
pacman -Su --noconfirm \
weston \
sudo \
xorg-xwayland \
gstreamer \
gst-plugins-base \
gst-plugins-good \
gst-plugin-qsv \
gst-plugin-va \
gst-plugin-fmp4 \
mesa \
# Grab GPU encoding packages
# Intel (modern VPL + VA-API)
vpl-gpu-rt \
intel-media-driver \
# AMD/ATI (VA-API)
libva-mesa-driver \
# NVIDIA (proprietary)
nvidia-utils \
# Audio
pulseaudio \
# Supervisor
supervisor
RUN \
# Set up our non-root user $(nestri)
groupadd -g 1000 nestri \
&& useradd -ms /bin/bash nestri -u 1000 -g 1000 \
&& passwd -d nestri \
# Setup Pulseaudio
&& useradd -d /var/run/pulse -s /usr/bin/nologin -G audio pulse \
&& groupadd pulse-access \
&& usermod -aG audio,input,render,video,pulse-access nestri \
&& echo "nestri ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers \
&& echo "Users created" \
# Create an empty machine-id file
&& touch /etc/machine-id
ENV \
XDG_RUNTIME_DIR=/tmp
#******************************************************************************
# runtime
#******************************************************************************
FROM runtime_base AS runtime
# Setup supervisor #
RUN <<-EOF
echo -e "
[supervisord]
user=root
nodaemon=true
loglevel=info
logfile=/tmp/supervisord.log
pidfile=/tmp/supervisord.pid
[program:dbus]
user=root
command=dbus-daemon --system --nofork
logfile=/tmp/dbus.log
pidfile=/tmp/dbus.pid
stopsignal=INT
autostart=true
autorestart=true
priority=1
[program:pulseaudio]
user=root
command=pulseaudio --daemonize=no --system --disallow-module-loading --disallow-exit --exit-idle-time=-1
logfile=/tmp/pulseaudio.log
pidfile=/tmp/pulseaudio.pid
stopsignal=INT
autostart=true
autorestart=true
priority=10
" | tee /etc/supervisord.conf
EOF
RUN \
chown -R nestri:nestri /tmp /etc/supervisord.conf
ENV USER=nestri
USER 1000
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisord.conf"]
# Debug - pactl list

20
Containerfile.relay Normal file
View File

@@ -0,0 +1,20 @@
FROM docker.io/golang:1.23-alpine AS go-build
WORKDIR /builder
COPY packages/relay/ /builder/
RUN go build
FROM docker.io/golang:1.23-alpine
COPY --from=go-build /builder/relay /relay/relay
WORKDIR /relay
# ENV flags
ENV VERBOSE=false
ENV ENDPOINT_PORT=8088
ENV WEBRTC_UDP_START=10000
ENV WEBRTC_UDP_END=20000
ENV STUN_SERVER="stun.l.google.com:19302"
EXPOSE $ENDPOINT_PORT
EXPOSE $WEBRTC_UDP_START-$WEBRTC_UDP_END/udp
ENTRYPOINT ["/relay/relay"]

119
Containerfile.runner Normal file
View File

@@ -0,0 +1,119 @@
# Container build arguments #
ARG BASE_IMAGE=docker.io/cachyos/cachyos-v3:latest
#******************************************************************************
# gst-builder
#******************************************************************************
FROM ${BASE_IMAGE} AS gst-builder
WORKDIR /builder/
# Grab build and rust packages #
RUN pacman -Syu --noconfirm meson pkgconf cmake git gcc make rustup \
gstreamer gst-plugins-base gst-plugins-good gst-plugin-rswebrtc
# Setup stable rust toolchain #
RUN rustup default stable
# Clone nestri source #
RUN git clone -b feat/stream https://github.com/DatCaptainHorse/nestri.git
# Build nestri #
RUN cd nestri/packages/server/ && \
cargo build --release
#******************************************************************************
# gstwayland-builder
#******************************************************************************
FROM ${BASE_IMAGE} AS gstwayland-builder
WORKDIR /builder/
# Grab build and rust packages #
RUN pacman -Syu --noconfirm meson pkgconf cmake git gcc make rustup \
libxkbcommon wayland gstreamer gst-plugins-base gst-plugins-good libinput
# Setup stable rust toolchain #
RUN rustup default stable
# Build required cargo-c package #
RUN cargo install cargo-c
# Clone gst plugin source #
RUN git clone https://github.com/games-on-whales/gst-wayland-display.git
# Build gst plugin #
RUN mkdir plugin && \
cd gst-wayland-display && \
cargo cinstall --prefix=/builder/plugin/
#******************************************************************************
# runtime
#******************************************************************************
FROM ${BASE_IMAGE} AS runtime
## Install Graphics, Media, and Audio packages ##
RUN pacman -Syu --noconfirm --needed \
# Graphics packages
sudo xorg-xwayland labwc wlr-randr mangohud \
# GStreamer and plugins
gstreamer gst-plugins-base gst-plugins-good \
gst-plugins-bad gst-plugin-pipewire \
gst-plugin-rswebrtc gst-plugin-rsrtp \
# Audio packages
pipewire pipewire-pulse pipewire-alsa wireplumber \
# Other requirements
supervisor jq chwd lshw pacman-contrib && \
# Clean up pacman cache
paccache -rk1
## User ##
# Create and setup user #
ENV USER="nestri" \
UID=99 \
GID=100 \
USER_PASSWORD="nestri1234"
RUN mkdir -p /home/${USER} && \
groupadd -g ${GID} ${USER} && \
useradd -d /home/${USER} -u ${UID} -g ${GID} -s /bin/bash ${USER} && \
chown -R ${USER}:${USER} /home/${USER} && \
echo "${USER} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers && \
echo "${USER}:${USER_PASSWORD}" | chpasswd
# Run directory #
RUN mkdir -p /run/user/${UID} && \
chown ${USER}:${USER} /run/user/${UID}
# Groups #
RUN usermod -aG input root && usermod -aG input ${USER} && \
usermod -aG video root && usermod -aG video ${USER} && \
usermod -aG render root && usermod -aG render ${USER}
## Copy files from builders ##
# this is done here at end to not trigger full rebuild on changes to builder
# nestri
COPY --from=gst-builder /builder/nestri/target/release/nestri-server /usr/bin/nestri-server
# gstwayland
COPY --from=gstwayland-builder /builder/plugin/include/libgstwaylanddisplay /usr/include/
COPY --from=gstwayland-builder /builder/plugin/lib/*libgstwayland* /usr/lib/
COPY --from=gstwayland-builder /builder/plugin/lib/gstreamer-1.0/libgstwayland* /usr/lib/gstreamer-1.0/
COPY --from=gstwayland-builder /builder/plugin/lib/pkgconfig/gstwayland* /usr/lib/pkgconfig/
COPY --from=gstwayland-builder /builder/plugin/lib/pkgconfig/libgstwayland* /usr/lib/pkgconfig/
## Copy scripts ##
COPY packages/scripts/ /etc/nestri/
# Set scripts as executable #
RUN chmod +x /etc/nestri/envs.sh /etc/nestri/entrypoint.sh /etc/nestri/entrypoint_nestri.sh
## Set runtime envs ##
ENV XDG_RUNTIME_DIR=/run/user/${UID} \
HOME=/home/${USER}
# Required for NVIDIA.. they want to be special like that #
ENV NVIDIA_DRIVER_CAPABILITIES=all
# Wireplumber disable suspend #
# Remove suspend node
RUN sed -z -i 's/{[[:space:]]*name = node\/suspend-node\.lua,[[:space:]]*type = script\/lua[[:space:]]*provides = hooks\.node\.suspend[[:space:]]*}[[:space:]]*//g' /usr/share/wireplumber/wireplumber.conf
# Remove "hooks.node.suspend" want
RUN sed -i '/wants = \[/{s/hooks\.node\.suspend\s*//; s/,\s*\]/]/}' /usr/share/wireplumber/wireplumber.conf
ENTRYPOINT ["supervisord", "-c", "/etc/nestri/supervisord.conf"]

View File

@@ -4,8 +4,20 @@
Nestri is in a **very early-beta phase**, so errors and bugs may occur.
::
### Step 0: Construct Your Docker Image
Checkout your branch with the latest version of nestri and build the image `<your-nestri-image>` within git root folder:
```bash
docker buildx build -t <your-nestri-image>:latest -f Containerfile.runner .
```
## Step 1: Navigate to Your Game Directory
::alert{type="info"}
You can right now also pull the docker image from DatHorse GitHub Containter Registry with:
```bash
docker pull ghcr.io/datcaptainhorse/nestri-cachyos:latest
```
::
### Step 1: Navigate to Your Game Directory
First, change your directory to the location of your `.exe` file. For Steam games, this typically means:
```bash
cd $HOME/.steam/steam/steamapps
@@ -18,28 +30,49 @@ echo "$(head /dev/urandom | LC_ALL=C tr -dc 'a-zA-Z0-9' | head -c 16)"
```
This command generates a random 16-character string. Be sure to note this string carefully, as you'll need it for the next step.
### Step 3: Launch the Nestri Server
With your SESSION_ID ready, insert it into the command below, replacing `<paste here>` with your actual session ID. Then, run the command to start the Nestri server:
With your SESSION_ID ready, insert it into the command below, replacing `<your_session_id>` with your actual session ID, also replace `<relay_url>` with your relay URL and `<your-nestri-image>` with your build nestri image or nestri remote image. Then run the command to start the Nestri server:
```bash
docker run --rm -it --shm-size=1g --gpus all -e NVIDIA_DRIVER_CAPABILITIES=all --runtime=nvidia -e RELAY_URL='<relay_url>' -e NESTRI_ROOM=<your_session_id> -e RESOLUTION=1920x1080 -e FRAMERATE=60 -e NESTRI_PARAMS='--verbose=true --video-codec=h264 --video-bitrate=4000 --video-bitrate-max=6000'--name nestri -d -v "$(pwd)":/mnt/game/ <your-nestri-image>:latest
```
docker run --gpus all --device=/dev/dri --name nestri -it --entrypoint /bin/bash -e SESSION_ID=<paste here> -v "$(pwd)":/game -p 8080:8080/udp --cap-add=SYS_NICE --cap-add=SYS_ADMIN ghcr.io/nestriness/nestri/server:nightly
### Step 4: Get Into your container
Get into your container to start your game:
```bash
sudo docker exec -it nestri bash
```
> \[!TIP]
>
> Ensure UDP port 8080 is accessible from the internet. Use `ufw allow 8080/udp` or adjust your cloud provider's security group settings accordingly.
### Step 4: Configure the Game within the Container
After executing the previous command, you'll be in a new shell within the container (example: `nestri@3f199ee68c01:~$`). Perform the following checks:
1. Verify the game is mounted by executing `ls -la /game`. If not, exit and ensure you've correctly mounted the game directory as a volume.
2. Then, start the Nestri server by running `/etc/startup.sh > /dev/null &`.
### Step 5: Installing a Launcher
For most games that are not DRM free you need a launcher. In this case use the umu launcher and optional mangohud:
```bash
pacman -S --overwrite="*" umu-launcher mangohud
```
### Step 5: Running Your Game
Wait for the `.X11-unix` directory to appear in `/tmp` (check with `ls -la /tmp`). Once it appears, you're ready to launch your game.
- With Proton-GE: `nestri-proton -pr <game>.exe`
- With Wine: `nestri-proton -wr <game>.exe`
You have to execute your game now with nestri user. If you have a linux game just execute it with the nestri user
```bash
su nestri
source /etc/nestri/envs.sh
GAMEID=0 PROTONPATH=GE-Proton mangohud umu-run /mnt/game/<your-game.exe>
```
### Step 6: Begin Playing
Finally, construct the play URL with your session ID:
```
echo "https://nestri.io/play/$SESSION_ID"
```
`https://nestri.io/play/<your_session_id>`
Navigate to this URL in your browser, click on the page to capture your mouse pointer, and start playing!
::alert{type="info"}
You can also use other relays/frontends depending on your choosen `<relay_url>`
For testing you can use DatHorse Relay and Frontend:
| **Placeholder** | **URL** |
| ---------------------------- | ---------- |
| `<relay_url>` | `https://relay.dathorse.com/` |
| `<frontend_url>` | `https://nestritest.dathorse.com/play/<your_session_id>` |
::
<!--
Nestri Node is easy to install using the provided installation script. Follow the steps below to get started.

View File

@@ -17,7 +17,7 @@ export default component$(() => {
video = document.createElement("video");
video.id = "stream-video-player";
video.style.visibility = "hidden";
const webrtc = new WebRTCStream("https://relay.dathorse.com", id, (mediaStream) => {
const webrtc = new WebRTCStream("http://localhost:8088", id, (mediaStream) => {
if (video && mediaStream && (video as HTMLVideoElement).srcObject === null) {
console.log("Setting mediastream");
(video as HTMLVideoElement).srcObject = mediaStream;
@@ -26,6 +26,8 @@ export default component$(() => {
window.hasstream = true;
// @ts-ignore
window.roomOfflineElement?.remove();
// @ts-ignore
window.playbtnelement?.remove();
const playbtn = document.createElement("button");
playbtn.style.position = "absolute";
@@ -81,8 +83,14 @@ export default component$(() => {
});
};
document.body.append(playbtn);
// @ts-ignore
window.playbtnelement = playbtn;
} else if (mediaStream === null) {
console.log("MediaStream is null, Room is offline");
// @ts-ignore
window.playbtnelement?.remove();
// @ts-ignore
window.roomOfflineElement?.remove();
// Add a message to the screen
const offline = document.createElement("div");
offline.style.position = "absolute";
@@ -104,6 +112,30 @@ export default component$(() => {
const ctx = canvas.value.getContext("2d");
if (ctx) ctx.clearRect(0, 0, canvas.value.width, canvas.value.height);
}
} else if ((video as HTMLVideoElement).srcObject !== null) {
console.log("Setting new mediastream");
(video as HTMLVideoElement).srcObject = mediaStream;
// @ts-ignore
window.hasstream = true;
// Start video rendering
(video as HTMLVideoElement).play().then(() => {
// @ts-ignore
window.roomOfflineElement?.remove();
if (canvas.value) {
canvas.value.width = (video as HTMLVideoElement).videoWidth;
canvas.value.height = (video as HTMLVideoElement).videoHeight;
const ctx = canvas.value.getContext("2d");
const renderer = () => {
// @ts-ignore
if (ctx && window.hasstream) {
ctx.drawImage((video as HTMLVideoElement), 0, 0);
(video as HTMLVideoElement).requestVideoFrameCallback(renderer);
}
}
(video as HTMLVideoElement).requestVideoFrameCallback(renderer);
}
});
}
});
}

View File

@@ -25,51 +25,16 @@ export class WebRTCStream {
}
this._onConnected = connectedCallback;
this._setup(serverURL, roomName);
}
private _setup(serverURL: string, roomName: string) {
console.log("Setting up WebSocket");
// Replace http/https with ws/wss
const wsURL = serverURL.replace(/^http/, "ws");
this._ws = new WebSocket(`${wsURL}/api/ws/${roomName}`);
this._ws.onopen = async () => {
console.log("WebSocket opened");
console.log("Setting up PeerConnection");
this._pc = new RTCPeerConnection({
iceServers: [
{
urls: "stun:stun.l.google.com:19302"
}
],
});
this._pc.ontrack = (e) => {
console.log("Track received: ", e.track);
this._mediaStream = e.streams[e.streams.length - 1];
};
this._pc.onconnectionstatechange = () => {
console.log("Connection state: ", this._pc!.connectionState);
if (this._pc!.connectionState === "connected") {
if (this._onConnected && this._mediaStream)
this._onConnected(this._mediaStream);
}
};
this._pc.onicecandidate = (e) => {
if (e.candidate) {
const message: MessageICE = {
payload_type: "ice",
candidate: e.candidate
};
this._ws!.send(encodeMessage(message));
}
}
this._pc.ondatachannel = (e) => {
this._dataChannel = e.channel;
this._setupDataChannelEvents();
}
// Send join message
const joinMessage: MessageJoin = {
payload_type: "join",
@@ -87,6 +52,11 @@ export class WebRTCStream {
const message = await decodeMessage<MessageBase>(e.data);
switch (message.payload_type) {
case "sdp":
if (!this._pc) {
// Setup peer connection now
this._setupPeerConnection();
}
console.log("Received SDP: ", (message as MessageSDP).sdp);
await this._pc!.setRemoteDescription((message as MessageSDP).sdp);
// Create our answer
const answer = await this._pc!.createAnswer();
@@ -99,12 +69,13 @@ export class WebRTCStream {
}));
break;
case "ice":
if (!this._pc) break;
// If remote description is not set yet, hold the ICE candidates
if (this._pc!.remoteDescription) {
await this._pc!.addIceCandidate((message as MessageICE).candidate);
if (this._pc.remoteDescription) {
await this._pc.addIceCandidate((message as MessageICE).candidate);
// Add held ICE candidates
for (const ice of iceHolder) {
await this._pc!.addIceCandidate(ice);
await this._pc.addIceCandidate(ice);
}
iceHolder = [];
} else {
@@ -134,7 +105,19 @@ export class WebRTCStream {
}
this._ws.onclose = () => {
console.log("WebSocket closed");
console.log("WebSocket closed, reconnecting in 3 seconds");
if (this._onConnected)
this._onConnected(null);
// Clear PeerConnection
if (this._pc) {
this._pc.close();
this._pc = undefined;
}
setTimeout(() => {
this._setup(serverURL, roomName);
}, 3000);
}
this._ws.onerror = (e) => {
@@ -142,6 +125,45 @@ export class WebRTCStream {
}
}
private _setupPeerConnection() {
console.log("Setting up PeerConnection");
this._pc = new RTCPeerConnection({
iceServers: [
{
urls: "stun:stun.l.google.com:19302"
}
],
});
this._pc.ontrack = (e) => {
console.log("Track received: ", e.track);
this._mediaStream = e.streams[e.streams.length - 1];
};
this._pc.onconnectionstatechange = () => {
console.log("Connection state: ", this._pc!.connectionState);
if (this._pc!.connectionState === "connected") {
if (this._onConnected && this._mediaStream)
this._onConnected(this._mediaStream);
}
};
this._pc.onicecandidate = (e) => {
if (e.candidate) {
const message: MessageICE = {
payload_type: "ice",
candidate: e.candidate
};
this._ws!.send(encodeMessage(message));
}
}
this._pc.ondatachannel = (e) => {
this._dataChannel = e.channel;
this._setupDataChannelEvents();
}
}
// Forces opus to stereo in Chromium browsers, because of course
private forceOpusStereo(SDP: string): string {
// Look for "minptime=10;useinbandfec=1" and replace with "minptime=10;useinbandfec=1;stereo=1;sprop-stereo=1;"

View File

@@ -200,11 +200,15 @@ func ingestHandler(room *Room) {
})
room.WebSocket.RegisterOnClose(func() {
// If PeerConnection is not open or does not exist, delete room
if (room.PeerConnection != nil && room.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateConnected) ||
room.PeerConnection == nil {
DeleteRoomIfEmpty(room)
// If PeerConnection is still open, close it
if room.PeerConnection != nil {
if err = room.PeerConnection.Close(); err != nil {
log.Printf("Failed to close PeerConnection for room: '%s' - reason: %s\n", room.Name, err)
}
room.PeerConnection = nil
}
room.Online = false
DeleteRoomIfEmpty(room)
})
log.Printf("Room: '%s' is ready, sending an OK\n", room.Name)

View File

@@ -33,7 +33,7 @@ func (vw *Participant) addTrack(trackLocal *webrtc.TrackLocal) error {
rtcpBuffer := make([]byte, 1400)
for {
if _, _, rtcpErr := rtpSender.Read(rtcpBuffer); rtcpErr != nil {
return
break
}
}
}()

View File

@@ -10,6 +10,7 @@ import (
type SafeWebSocket struct {
*websocket.Conn
sync.Mutex
closeCallback func() // OnClose callback
binaryCallbacks map[string]OnMessageCallback // MessageBase type -> callback
}
@@ -17,6 +18,7 @@ type SafeWebSocket struct {
func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
ws := &SafeWebSocket{
Conn: conn,
closeCallback: nil,
binaryCallbacks: make(map[string]OnMessageCallback),
}
@@ -32,10 +34,6 @@ func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
}
break
} else if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
// If closing, just break
if GetFlags().Verbose {
log.Printf("WebSocket closing\n")
}
break
} else if err != nil {
log.Printf("Failed to read WebSocket message, reason: %s\n", err)
@@ -62,6 +60,11 @@ func NewSafeWebSocket(conn *websocket.Conn) *SafeWebSocket {
log.Printf("Unknown WebSocket message type: %d\n", kind)
}
}
// Call close callback
if ws.closeCallback != nil {
ws.closeCallback()
}
}()
return ws
@@ -102,13 +105,12 @@ func (ws *SafeWebSocket) UnregisterMessageCallback(msgType string) {
// RegisterOnClose sets the callback for websocket closing
func (ws *SafeWebSocket) RegisterOnClose(callback func()) {
ws.SetCloseHandler(func(code int, text string) error {
ws.closeCallback = func() {
// Clear our callbacks
ws.Lock()
ws.binaryCallbacks = nil
ws.Unlock()
// Call the callback
callback()
return nil
})
}
}

View File

@@ -0,0 +1,65 @@
#!/bin/bash
set -euo pipefail
# Wait for dbus socket to be ready
echo "Waiting for DBus system bus socket..."
DBUS_SOCKET="/run/dbus/system_bus_socket"
for _ in {1..10}; do # Wait up to 10 seconds
if [ -e "$DBUS_SOCKET" ]; then
echo "DBus system bus socket is ready."
break
fi
sleep 1
done
if [ ! -e "$DBUS_SOCKET" ]; then
echo "Error: DBus system bus socket did not appear. Exiting."
exit 1
fi
# Wait for PipeWire to be ready
echo "Waiting for PipeWire socket..."
PIPEWIRE_SOCKET="/run/user/${UID}/pipewire-0"
for _ in {1..10}; do # Wait up to 10 seconds
if [ -e "$PIPEWIRE_SOCKET" ]; then
echo "PipeWire socket is ready."
break
fi
sleep 1
done
if [ ! -e "$PIPEWIRE_SOCKET" ]; then
echo "Error: PipeWire socket did not appear. Exiting."
exit 1
fi
# Update system packages before proceeding
echo "Upgrading system packages..."
pacman -Syu --noconfirm
echo "Detecting GPU vendor and installing necessary GStreamer plugins..."
source /etc/nestri/gpu_helpers.sh
get_gpu_info
# Identify vendor
if [[ "${vendor_full_map[0],,}" =~ "intel" ]]; then
echo "Intel GPU detected, installing required packages..."
chwd -a
pacman -Syu --noconfirm gstreamer-vaapi gst-plugin-va gst-plugin-qsv
# chwd missed a thing
pacman -Syu --noconfirm vpl-gpu-rt
elif [[ "${vendor_full_map[0],,}" =~ "amd" ]]; then
echo "AMD GPU detected, installing required packages..."
chwd -a
pacman -Syu --noconfirm gstreamer-vaapi gst-plugin-va
elif [[ "${vendor_full_map[0],,}" =~ "nvidia" ]]; then
echo "NVIDIA GPU detected. Assuming drivers are linked"
else
echo "Unknown GPU vendor. No additional packages will be installed"
fi
# Clean up remainders
echo "Cleaning up old package cache..."
paccache -rk1
echo "Switching to nestri user for application startup..."
exec sudo -E -u nestri /etc/nestri/entrypoint_nestri.sh

View File

@@ -0,0 +1,156 @@
#!/bin/bash
set -euo pipefail
# Source environment variables from envs.sh
if [ -f /etc/nestri/envs.sh ]; then
echo "Sourcing environment variables from envs.sh..."
source /etc/nestri/envs.sh
else
echo "envs.sh not found! Ensure it exists at /etc/nestri/envs.sh."
exit 1
fi
# Configuration
MAX_RETRIES=3
# Helper function to restart the chain
restart_chain() {
echo "Restarting nestri-server, labwc, and wlr-randr..."
# Kill all child processes safely (if running)
if [[ -n "${NESTRI_PID:-}" ]] && kill -0 "${NESTRI_PID}" 2 >/dev/null; then
kill "${NESTRI_PID}"
fi
if [[ -n "${LABWC_PID:-}" ]] && kill -0 "${LABWC_PID}" 2 >/dev/null; then
kill "${LABWC_PID}"
fi
wait || true
# Start nestri-server
start_nestri_server
RETRY_COUNT=0
}
# Function to start nestri-server
start_nestri_server() {
echo "Starting nestri-server..."
nestri-server $(echo $NESTRI_PARAMS) &
NESTRI_PID=$!
# Wait for Wayland display (wayland-1) to be ready
echo "Waiting for Wayland display 'wayland-1' to be ready..."
WAYLAND_SOCKET="/run/user/${UID}/wayland-1"
for _ in {1..15}; do # Wait up to 15 seconds
if [ -e "$WAYLAND_SOCKET" ]; then
echo "Wayland display 'wayland-1' is ready."
start_labwc
return
fi
sleep 1
done
echo "Error: Wayland display 'wayland-1' did not appear. Incrementing retry count..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for nestri-server. Exiting."
exit 1
fi
restart_chain
}
# Function to start labwc
start_labwc() {
echo "Starting labwc..."
rm -rf /tmp/.X11-unix && mkdir -p /tmp/.X11-unix && chown nestri:nestri /tmp/.X11-unix
WAYLAND_DISPLAY=wayland-1 WLR_BACKENDS=wayland labwc &
LABWC_PID=$!
# Wait for labwc to initialize (using `wlr-randr` as an indicator)
echo "Waiting for labwc to initialize..."
for _ in {1..15}; do
if wlr-randr --json | jq -e '.[] | select(.enabled == true)' >/dev/null; then
echo "labwc is initialized and Wayland outputs are ready."
start_wlr_randr
return
fi
sleep 1
done
echo "Error: labwc did not initialize correctly. Incrementing retry count..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for labwc. Exiting."
exit 1
fi
restart_chain
}
# Function to run wlr-randr
start_wlr_randr() {
echo "Configuring resolution with wlr-randr..."
OUTPUT_NAME=$(wlr-randr --json | jq -r '.[] | select(.enabled == true) | .name' | head -n 1)
if [ -z "$OUTPUT_NAME" ]; then
echo "Error: No enabled outputs detected. Skipping wlr-randr."
return
fi
# Retry logic for wlr-randr
local WLR_RETRIES=0
while ! WAYLAND_DISPLAY=wayland-0 wlr-randr --output "$OUTPUT_NAME" --custom-mode "$RESOLUTION"; do
echo "Error: Failed to configure wlr-randr. Retrying..."
((WLR_RETRIES++))
if [ "$WLR_RETRIES" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for wlr-randr. Moving on without resolution setup."
return
fi
sleep 2
done
echo "wlr-randr configuration successful."
}
# Main loop to monitor processes
main_loop() {
trap 'echo "Terminating...";
if [[ -n "${NESTRI_PID:-}" ]] && kill -0 "${NESTRI_PID}" 2>/dev/null; then
kill "${NESTRI_PID}"
fi
if [[ -n "${LABWC_PID:-}" ]] && kill -0 "${LABWC_PID}" 2>/dev/null; then
kill "${LABWC_PID}"
fi
exit 0' SIGINT SIGTERM
while true; do
# Wait for any child process to exit
wait -n
# Check which process exited
if ! kill -0 ${NESTRI_PID:-} 2 >/dev/null; then
echo "nestri-server crashed. Restarting chain..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for nestri-server. Exiting."
exit 1
fi
restart_chain
elif ! kill -0 ${LABWC_PID:-} 2 >/dev/null; then
echo "labwc crashed. Restarting labwc and wlr-randr..."
((RETRY_COUNT++))
if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then
echo "Max retries reached for labwc. Exiting."
exit 1
fi
start_labwc
fi
done
}
# Initialize retry counter
RETRY_COUNT=0
# Start the initial chain
restart_chain
# Enter monitoring loop
main_loop

View File

@@ -1,4 +1,5 @@
#!/bin/bash -e
#!/bin/bash
set -euo pipefail
export XDG_RUNTIME_DIR=/run/user/${UID}/
export WAYLAND_DISPLAY=wayland-0
@@ -9,4 +10,4 @@ export $(dbus-launch)
export PROTON_NO_FSYNC=1
# Our preferred prefix
export WINEPREFIX=${USER_HOME}/.nestripfx/
export WINEPREFIX=/home/${USER}/.nestripfx/

View File

@@ -1,277 +1,52 @@
#!/bin/bash -e
#!/bin/bash
set -euo pipefail
# Various helper functions for handling available GPUs
declare -a vendor_full_map=()
declare -a vendor_id_map=()
declare -A vendor_index_map=()
declare -ga gpu_map
declare -gA gpu_bus_map
declare -gA gpu_card_map
declare -gA gpu_product_map
declare -gA vendor_index_map
declare -gA vendor_full_map
function get_gpu_info {
# Initialize arrays/maps to avoid unbound variable errors
vendor_full_map=()
vendor_id_map=()
vendor_index_map=()
# Map to help get shorter vendor identifiers
declare -A vendor_keywords=(
["advanced micro devices"]='amd'
["ati"]='amd'
["amd"]='amd'
["radeon"]='amd'
["nvidia"]='nvidia'
["intel"]='intel'
)
# Use lspci to detect GPU info
gpu_info=$(lspci | grep -i 'vga\|3d\|display')
get_gpu_info() {
# Clear out previous data
gpu_map=()
gpu_bus_map=()
gpu_card_map=()
gpu_product_map=()
vendor_index_map=()
vendor_full_map=()
# Parse each line of GPU info
while IFS= read -r line; do
# Extract vendor name and ID from lspci output
vendor=$(echo "$line" | awk -F: '{print $3}' | sed -E 's/^[[:space:]]+//g' | tr '[:upper:]' '[:lower:]')
id=$(echo "$line" | awk '{print $1}')
local vendor=""
local product=""
local bus_info=""
local vendor_full=""
while read -r line; do
line="${line##*( )}"
if [[ "${line,,}" =~ "vendor:" ]]; then
vendor=""
vendor_full=$(echo "$line" | awk '{$1=""; print $0}' | xargs)
# Look for short vendor keyword in line
for keyword in "${!vendor_keywords[@]}"; do
if [[ "${line,,}" == *"$keyword"* ]]; then
vendor="${vendor_keywords[$keyword]}"
break
# Normalize vendor name
if [[ $vendor =~ .*nvidia.* ]]; then
vendor="nvidia"
elif [[ $vendor =~ .*intel.* ]]; then
vendor="intel"
elif [[ $vendor =~ .*advanced[[:space:]]micro[[:space:]]devices.* ]]; then
vendor="amd"
elif [[ $vendor =~ .*ati.* ]]; then
vendor="amd"
else
vendor="unknown"
fi
done
# If no vendor keywords match, use first word
if [[ -z "$vendor" ]]; then
vendor=$(echo "$vendor_full" | awk '{print tolower($1)}')
fi
elif [[ "${line,,}" =~ "product:" ]]; then
product=$(echo "$line" | awk '{$1=""; print $0}' | xargs)
elif [[ "${line,,}" =~ "bus info:" ]]; then
bus_info=$(echo "$line" | awk '{print $3}')
fi
if [[ -n "$vendor" && -n "$product" && -n "$bus_info" && ! "${line,,}" =~ \*"-display" ]]; then
# We have gathered all GPU info necessary, store it
# Check if vendor index is being tracked
if [[ -z "${vendor_index_map[$vendor]}" ]]; then
# Start new vendor index tracking
vendor_index_map[$vendor]=0
else
# Another GPU of same vendor, increment index
vendor_index_map[$vendor]="$((vendor_index_map[$vendor] + 1))"
fi
# Resolved GPU index
local gpu_index="${vendor_index_map[$vendor]}"
local gpu_key="$vendor:$gpu_index"
# Get /dev/dri/cardN of GPU
local gpu_card=$({ ls -1d /sys/bus/pci/devices/*${bus_info#pci@}/drm/*; } 2>&1 | grep card* | grep -oP '(?<=card)\d+')
# Store info in maps
gpu_map+=("$gpu_key")
gpu_bus_map["$gpu_key"]="$bus_info"
gpu_product_map["$gpu_key"]="$product"
vendor_full_map["$gpu_key"]="$vendor_full"
if [[ -n "$gpu_card" ]]; then
gpu_card_map["$gpu_key"]="$gpu_card"
fi
# Clear values for additional GPUs
vendor=""
product=""
bus_info=""
vendor_full=""
fi
if [[ "${line,,}" =~ \*"-display" ]]; then
# New GPU found before storing, clear incomplete values to prevent mixing
vendor=""
product=""
bus_info=""
vendor_full=""
fi
done < <(sudo lshw -c video)
# Add to arrays/maps if unique
if ! [[ "${vendor_index_map[$vendor]:-}" ]]; then
vendor_index_map[$vendor]="${#vendor_full_map[@]}"
vendor_full_map+=("$vendor")
fi
vendor_id_map+=("$id")
done <<< "$gpu_info"
}
check_and_populate_gpus() {
if [[ "${#gpu_map[@]}" -eq 0 ]]; then
get_gpu_info # Gather info incase info not gathered yet
if [[ "${#gpu_map[@]}" -eq 0 ]]; then
echo "No GPUs found on this system" >&2
return 1
fi
fi
}
check_selected_gpu() {
local selected_gpu="${1,,}"
if [[ ! " ${gpu_map[*]} " =~ " $selected_gpu " ]]; then
echo "No such GPU: '$selected_gpu'" >&2
return 1
fi
echo "$selected_gpu"
}
list_available_gpus() {
if ! check_and_populate_gpus; then
return 1
fi
echo "Available GPUs:" >&2
for gpu in "${gpu_map[@]}"; do
echo " [$gpu] \"${gpu_product_map[$gpu]}\" @[${gpu_bus_map[$gpu]}]"
done
}
convert_bus_id_to_xorg() {
local bus_info="$1"
IFS=":." read -ra bus_parts <<< "${bus_info#pci@????:}" # Remove "pci@" and the following 4 characters (domain)
# Check if bus_info has the correct format (at least 3 parts after removing domain)
if [[ "${#bus_parts[@]}" -lt 3 ]]; then
echo "Invalid bus info format: $bus_info" >&2
return 1
fi
# Convert each part from hexadecimal to decimal
bus_info_xorg="PCI:"
for part in "${bus_parts[@]}"; do
bus_info_xorg+="$((16#$part)):"
done
bus_info_xorg="${bus_info_xorg%:}" # Remove the trailing colon
echo "$bus_info_xorg"
}
print_gpu_info() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "[$selected_gpu]"
echo " Vendor: ${vendor_full_map[$selected_gpu]}"
echo " Product: ${gpu_product_map[$selected_gpu]}"
echo " Bus: ${gpu_bus_map[$selected_gpu]}"
# Check if card path was found
if [[ "${gpu_card_map[$selected_gpu]}" ]]; then
echo " Card: /dev/dri/card${gpu_card_map[$selected_gpu]}"
fi
echo
}
get_gpu_vendor() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${selected_gpu%%:*}"
}
get_gpu_vendor_full() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${vendor_full_map[$selected_gpu]}"
}
get_gpu_index() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${selected_gpu#*:}"
}
get_gpu_product() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${gpu_product_map[$selected_gpu]}"
}
get_gpu_bus() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo "${gpu_bus_map[$selected_gpu]}"
}
get_gpu_bus_xorg() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
echo $(convert_bus_id_to_xorg "${gpu_bus_map[$selected_gpu]}")
}
get_gpu_card() {
if ! check_and_populate_gpus; then
return 1
fi
local selected_gpu
if ! selected_gpu=$(check_selected_gpu "$1"); then
return 1
fi
# Check if card path was found
if [[ -z "${gpu_card_map[$selected_gpu]}" ]]; then
echo "No card device found for GPU: $selected_gpu" >&2
return 1
fi
echo "/dev/dri/card${gpu_card_map[$selected_gpu]}"
function debug_gpu_info {
echo "Vendor Full Map: ${vendor_full_map[*]}"
echo "Vendor ID Map: ${vendor_id_map[*]}"
echo "Vendor Index Map:"
for key in "${!vendor_index_map[@]}"; do
echo " $key: ${vendor_index_map[$key]}"
done
}

View File

@@ -0,0 +1,53 @@
[supervisord]
user=root
nodaemon=true
loglevel=info
logfile=/tmp/supervisord.log
[program:dbus]
user=root
command=dbus-daemon --system --nofork --nopidfile
autorestart=true
autostart=true
startretries=3
priority=1
[program:seatd]
user=root
command=seatd
autorestart=true
autostart=true
startretries=3
priority=2
[program:pipewire]
user=nestri
command=dbus-launch pipewire
autorestart=true
autostart=true
startretries=3
priority=3
[program:pipewire-pulse]
user=nestri
command=dbus-launch pipewire-pulse
autorestart=true
autostart=true
startretries=3
priority=4
[program:wireplumber]
user=nestri
command=dbus-launch wireplumber
autorestart=true
autostart=true
startretries=3
priority=5
[program:entrypoint]
user=root
command=/etc/nestri/entrypoint.sh
autorestart=false
autostart=true
startretries=0
priority=10

View File

@@ -9,7 +9,8 @@ path = "src/main.rs"
[dependencies]
gst.workspace = true
gst-app.workspace = true
gst-webrtc.workspace = true
gstrswebrtc.workspace = true
serde = {version = "1.0.214", features = ["derive"] }
tokio = { version = "1.41.0", features = ["full"] }
clap = { version = "4.5.20", features = ["env"] }

View File

@@ -42,7 +42,7 @@ impl Args {
.short('u')
.long("relay-url")
.env("RELAY_URL")
.help("Nestri relay URL")
.help("Nestri relay URL"),
)
.arg(
Arg::new("resolution")
@@ -64,7 +64,7 @@ impl Args {
Arg::new("room")
.long("room")
.env("NESTRI_ROOM")
.help("Nestri room name/identifier")
.help("Nestri room name/identifier"),
)
.arg(
Arg::new("gpu-vendor")
@@ -110,7 +110,7 @@ impl Args {
Arg::new("video-encoder")
.long("video-encoder")
.env("VIDEO_ENCODER")
.help("Override video encoder (e.g. 'vah264enc')")
.help("Override video encoder (e.g. 'vah264enc')"),
)
.arg(
Arg::new("video-rate-control")
@@ -165,7 +165,7 @@ impl Args {
Arg::new("audio-encoder")
.long("audio-encoder")
.env("AUDIO_ENCODER")
.help("Override audio encoder (e.g. 'opusenc')")
.help("Override audio encoder (e.g. 'opusenc')"),
)
.arg(
Arg::new("audio-rate-control")
@@ -188,6 +188,13 @@ impl Args {
.help("Maximum bitrate in kbps")
.default_value("192"),
)
.arg(
Arg::new("dma-buf")
.long("dma-buf")
.env("DMA_BUF")
.help("Use DMA-BUF for pipeline")
.default_value("false"),
)
.get_matches();
Self {

View File

@@ -15,6 +15,9 @@ pub struct AppArgs {
pub relay_url: String,
/// Nestri room name/identifier
pub room: String,
/// Experimental DMA-BUF support
pub dma_buf: bool,
}
impl AppArgs {
pub fn from_matches(matches: &clap::ArgMatches) -> Self {
@@ -26,23 +29,33 @@ impl AppArgs {
debug_latency: matches.get_one::<String>("debug-latency").unwrap() == "true"
|| matches.get_one::<String>("debug-latency").unwrap() == "1",
resolution: {
let res = matches.get_one::<String>("resolution").unwrap().clone();
let res = matches
.get_one::<String>("resolution")
.unwrap_or(&"1280x720".to_string())
.clone();
let parts: Vec<&str> = res.split('x').collect();
(
parts[0].parse::<u32>().unwrap(),
parts[1].parse::<u32>().unwrap(),
)
if parts.len() >= 2 {
(
parts[0].parse::<u32>().unwrap_or(1280),
parts[1].parse::<u32>().unwrap_or(720),
)
} else {
(1280, 720)
}
},
framerate: matches
.get_one::<String>("framerate")
.unwrap()
.parse::<u32>()
.unwrap(),
.unwrap_or(60),
relay_url: matches.get_one::<String>("relay-url").unwrap().clone(),
// Generate random room name if not provided
room: matches.get_one::<String>("room")
room: matches
.get_one::<String>("room")
.unwrap_or(&rand::random::<u32>().to_string())
.clone(),
dma_buf: matches.get_one::<String>("dma-buf").unwrap() == "true"
|| matches.get_one::<String>("dma-buf").unwrap() == "1",
}
}
@@ -55,5 +68,6 @@ impl AppArgs {
println!("> framerate: {}", self.framerate);
println!("> relay_url: {}", self.relay_url);
println!("> room: {}", self.room);
println!("> dma_buf: {}", self.dma_buf);
}
}

View File

@@ -197,6 +197,11 @@ fn is_encoder_supported(encoder: &String) -> bool {
gst::ElementFactory::find(encoder.as_str()).is_some()
}
fn set_element_property(element: &gst::Element, property: &str, value: &dyn ToValue) {
element.set_property(property, value.to_value());
}
/// Helper to set CQP value of known encoder
/// # Arguments
/// * `encoder` - Information about the encoder.
@@ -316,7 +321,7 @@ pub fn encoder_gop_params(encoder: &VideoEncoderInfo, gop_size: u32) -> VideoEnc
let prop_name = prop.name();
// Look for known keys
if prop_name.to_lowercase().contains("gop")
if prop_name.to_lowercase().contains("gop-size")
|| prop_name.to_lowercase().contains("int-max")
|| prop_name.to_lowercase().contains("max-dist")
|| prop_name.to_lowercase().contains("intra-period-length")

View File

@@ -1,21 +1,21 @@
mod args;
mod enc_helper;
mod gpu;
mod room;
mod websocket;
mod latency;
mod messages;
mod nestrisink;
mod websocket;
use crate::args::encoding_args;
use crate::nestrisink::NestriSignaller;
use crate::websocket::NestriWebSocket;
use futures_util::StreamExt;
use gst::prelude::*;
use gst_app::AppSink;
use gstrswebrtc::signaller::Signallable;
use gstrswebrtc::webrtcsink::BaseWebRTCSink;
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;
use futures_util::StreamExt;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{Mutex};
use crate::websocket::{NestriWebSocket};
// Handles gathering GPU information and selecting the most suitable GPU
fn handle_gpus(args: &args::Args) -> Option<gpu::GPUInfo> {
@@ -161,20 +161,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Begin connection attempt to the relay WebSocket endpoint
// replace any http/https with ws/wss
let replaced_relay_url
= args.app.relay_url.replace("http://", "ws://").replace("https://", "wss://");
let ws_url = format!(
"{}/api/ws/{}",
replaced_relay_url,
args.app.room,
);
let replaced_relay_url = args
.app
.relay_url
.replace("http://", "ws://")
.replace("https://", "wss://");
let ws_url = format!("{}/api/ws/{}", replaced_relay_url, args.app.room,);
// Setup our websocket
let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
log::set_max_level(log::LevelFilter::Info);
log::set_boxed_logger(Box::new(nestri_ws.clone())).unwrap();
let _ = gst::init();
gst::init()?;
gstrswebrtc::plugin_register_static()?;
// Handle GPU selection
let gpu = handle_gpus(&args);
@@ -197,12 +197,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Handle audio encoder selection
let audio_encoder = handle_encoder_audio(&args);
/*** ROOM SETUP ***/
let room = Arc::new(Mutex::new(
room::Room::new(nestri_ws.clone()).await?,
));
/*** PIPELINE CREATION ***/
// Create the pipeline
let pipeline = Arc::new(gst::Pipeline::new());
/* Audio */
// Audio Source Element
let audio_source = match args.encoding.audio.capture_method {
@@ -220,7 +218,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Audio Rate Element
let audio_rate = gst::ElementFactory::make("audiorate").build()?;
// Required to fix gstreamer opus issue, where quality sounds off (due to wrong sample rate)
let audio_capsfilter = gst::ElementFactory::make("capsfilter").build()?;
let audio_caps = gst::Caps::from_str("audio/x-raw,rate=48000,channels=2").unwrap();
@@ -237,9 +235,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
},
);
// Audio RTP Payloader Element
let audio_rtp_payloader = gst::ElementFactory::make("rtpopuspay").build()?;
/* Video */
// Video Source Element
let video_source = gst::ElementFactory::make("waylanddisplaysrc").build()?;
@@ -248,13 +243,26 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Caps Filter Element (resolution, fps)
let caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let caps = gst::Caps::from_str(&format!(
"video/x-raw,width={},height={},framerate={}/1,format=RGBx",
args.app.resolution.0, args.app.resolution.1, args.app.framerate
"{},width={},height={},framerate={}/1{}",
if args.app.dma_buf {
"video/x-raw(memory:DMABuf)"
} else {
"video/x-raw"
},
args.app.resolution.0,
args.app.resolution.1,
args.app.framerate,
if args.app.dma_buf { "" } else { ",format=RGBx" }
))?;
caps_filter.set_property("caps", &caps);
// Video Tee Element
let video_tee = gst::ElementFactory::make("tee").build()?;
// GL Upload Element
let glupload = gst::ElementFactory::make("glupload").build()?;
// GL upload caps filter
let gl_caps_filter = gst::ElementFactory::make("capsfilter").build()?;
let gl_caps = gst::Caps::from_str("video/x-raw(memory:VAMemory)")?;
gl_caps_filter.set_property("caps", &gl_caps);
// Video Converter Element
let video_converter = gst::ElementFactory::make("videoconvert").build()?;
@@ -263,82 +271,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
let video_encoder = gst::ElementFactory::make(video_encoder_info.name.as_str()).build()?;
video_encoder_info.apply_parameters(&video_encoder, &args.app.verbose);
// Required for AV1 - av1parse
let av1_parse = gst::ElementFactory::make("av1parse").build()?;
// Video RTP Payloader Element
let video_rtp_payloader = gst::ElementFactory::make(
format!("rtp{}pay", video_encoder_info.codec.to_gst_str()).as_str(),
)
.build()?;
/* Output */
// Audio AppSink Element
let audio_appsink = gst::ElementFactory::make("appsink").build()?;
audio_appsink.set_property("emit-signals", &true);
let audio_appsink = audio_appsink.downcast_ref::<AppSink>().unwrap();
// Video AppSink Element
let video_appsink = gst::ElementFactory::make("appsink").build()?;
video_appsink.set_property("emit-signals", &true);
let video_appsink = video_appsink.downcast_ref::<AppSink>().unwrap();
/* Debug */
// Debug Feed Element
let debug_latency = gst::ElementFactory::make("timeoverlay").build()?;
debug_latency.set_property_from_str("halignment", &"right");
debug_latency.set_property_from_str("valignment", &"bottom");
// Debug Sink Element
let debug_sink = gst::ElementFactory::make("ximagesink").build()?;
// Debug video converter
let debug_video_converter = gst::ElementFactory::make("videoconvert").build()?;
// Queues with max 2ms latency
let debug_queue = gst::ElementFactory::make("queue2").build()?;
debug_queue.set_property("max-size-time", &1000000u64);
let main_video_queue = gst::ElementFactory::make("queue2").build()?;
main_video_queue.set_property("max-size-time", &1000000u64);
let main_audio_queue = gst::ElementFactory::make("queue2").build()?;
main_audio_queue.set_property("max-size-time", &1000000u64);
// Create the pipeline
let pipeline = gst::Pipeline::new();
// WebRTC sink Element
let signaller = NestriSignaller::new(nestri_ws.clone(), pipeline.clone());
let webrtcsink = BaseWebRTCSink::with_signaller(Signallable::from(signaller.clone()));
webrtcsink.set_property_from_str("stun-server", "stun://stun.l.google.com:19302");
webrtcsink.set_property_from_str("congestion-control", "disabled");
// Add elements to the pipeline
pipeline.add_many(&[
&video_appsink.upcast_ref(),
&video_rtp_payloader,
webrtcsink.upcast_ref(),
&video_encoder,
&video_converter,
&video_tee,
&caps_filter,
&video_source,
&audio_appsink.upcast_ref(),
&audio_rtp_payloader,
&audio_encoder,
&audio_capsfilter,
&audio_rate,
&audio_converter,
&audio_source,
&main_video_queue,
&main_audio_queue,
])?;
// Add debug elements if debug is enabled
if args.app.debug_feed {
pipeline.add_many(&[&debug_sink, &debug_queue, &debug_video_converter])?;
}
// Add debug latency element if debug latency is enabled
if args.app.debug_latency {
pipeline.add(&debug_latency)?;
}
// Add AV1 parse element if AV1 is selected
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
pipeline.add(&av1_parse)?;
// If DMA-BUF is enabled, add glupload and gl caps filter
if args.app.dma_buf {
pipeline.add_many(&[&glupload, &gl_caps_filter])?;
}
// Link main audio branch
@@ -348,47 +304,29 @@ async fn main() -> Result<(), Box<dyn Error>> {
&audio_rate,
&audio_capsfilter,
&audio_encoder,
&audio_rtp_payloader,
&main_audio_queue,
&audio_appsink.upcast_ref(),
webrtcsink.upcast_ref(),
])?;
// If debug latency, add time overlay before tee
if args.app.debug_latency {
gst::Element::link_many(&[&video_source, &caps_filter, &debug_latency, &video_tee])?;
} else {
gst::Element::link_many(&[&video_source, &caps_filter, &video_tee])?;
}
// Link debug branch if debug is enabled
if args.app.debug_feed {
// With DMA-BUF, also link glupload and it's caps
if args.app.dma_buf {
// Link video source to caps_filter, glupload, gl_caps_filter, video_converter, video_encoder, webrtcsink
gst::Element::link_many(&[
&video_tee,
&debug_video_converter,
&debug_queue,
&debug_sink,
])?;
}
// Link main video branch, if AV1, add av1_parse
if video_encoder_info.codec == enc_helper::VideoCodec::AV1 {
gst::Element::link_many(&[
&video_tee,
&video_source,
&caps_filter,
&glupload,
&gl_caps_filter,
&video_converter,
&video_encoder,
&av1_parse,
&video_rtp_payloader,
&main_video_queue,
&video_appsink.upcast_ref(),
webrtcsink.upcast_ref(),
])?;
} else {
// Link video source to caps_filter, video_converter, video_encoder, webrtcsink
gst::Element::link_many(&[
&video_tee,
&video_source,
&caps_filter,
&video_converter,
&video_encoder,
&video_rtp_payloader,
&main_video_queue,
&video_appsink.upcast_ref(),
webrtcsink.upcast_ref(),
])?;
}
@@ -397,21 +335,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
audio_source.set_property("do-timestamp", &true);
pipeline.set_property("latency", &0u64);
// Wrap the pipeline in Arc<Mutex> to safely share it
let pipeline = Arc::new(Mutex::new(pipeline));
// Run both pipeline and websocket tasks concurrently
let result = tokio::try_join!(
run_room(
room.clone(),
"audio/opus",
video_encoder_info.codec.to_mime_str(),
pipeline.clone(),
Arc::new(Mutex::new(audio_appsink.stream())),
Arc::new(Mutex::new(video_appsink.stream()))
),
run_pipeline(pipeline.clone())
);
let result = run_pipeline(pipeline.clone()).await;
match result {
Ok(_) => log::info!("All tasks completed successfully"),
@@ -424,53 +349,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
Ok(())
}
async fn run_room(
room: Arc<Mutex<room::Room>>,
audio_codec: &str,
video_codec: &str,
pipeline: Arc<Mutex<gst::Pipeline>>,
audio_stream: Arc<Mutex<AppSinkStream>>,
video_stream: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
// Run loop, with recovery on error
loop {
let mut room = room.lock().await;
tokio::select! {
_ = tokio::signal::ctrl_c() => {
log::info!("Room interrupted via Ctrl+C");
return Ok(());
}
result = room.run(
audio_codec,
video_codec,
pipeline.clone(),
audio_stream.clone(),
video_stream.clone(),
) => {
if let Err(e) = result {
log::error!("Room error: {}", e);
// Sleep for a while before retrying
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
} else {
return Ok(());
}
}
}
}
}
async fn run_pipeline(
pipeline: Arc<Mutex<gst::Pipeline>>,
) -> Result<(), Box<dyn Error>> {
// Take ownership of the bus without holding the lock
let bus = {
let pipeline = pipeline.lock().await;
pipeline.bus().ok_or("Pipeline has no bus")?
};
async fn run_pipeline(pipeline: Arc<gst::Pipeline>) -> Result<(), Box<dyn Error>> {
let bus = { pipeline.bus().ok_or("Pipeline has no bus")? };
{
// Temporarily lock the pipeline to change state
let pipeline = pipeline.lock().await;
if let Err(e) = pipeline.set_state(gst::State::Playing) {
log::error!("Failed to start pipeline: {}", e);
return Err("Failed to start pipeline".into());
@@ -491,8 +373,6 @@ async fn run_pipeline(
}
{
// Temporarily lock the pipeline to reset state
let pipeline = pipeline.lock().await;
pipeline.set_state(gst::State::Null)?;
}

View File

@@ -10,6 +10,31 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::latency::LatencyTracker;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageBase {
pub payload_type: String,

View File

@@ -0,0 +1,466 @@
use crate::messages::{
decode_message_as, encode_message, AnswerType, InputMessage, JoinerType, MessageAnswer,
MessageBase, MessageICE, MessageInput, MessageJoin, MessageSDP,
};
use crate::websocket::NestriWebSocket;
use glib::subclass::prelude::*;
use gst::glib;
use gst::prelude::*;
use gst_webrtc::{gst_sdp, WebRTCSDPType, WebRTCSessionDescription};
use gstrswebrtc::signaller::{Signallable, SignallableImpl};
use std::collections::HashSet;
use std::sync::{Arc, LazyLock};
use std::sync::{Mutex, RwLock};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub struct Signaller {
nestri_ws: RwLock<Option<Arc<NestriWebSocket>>>,
pipeline: RwLock<Option<Arc<gst::Pipeline>>>,
data_channel: RwLock<Option<gst_webrtc::WebRTCDataChannel>>,
}
impl Default for Signaller {
fn default() -> Self {
Self {
nestri_ws: RwLock::new(None),
pipeline: RwLock::new(None),
data_channel: RwLock::new(None),
}
}
}
impl Signaller {
pub fn set_nestri_ws(&self, nestri_ws: Arc<NestriWebSocket>) {
*self.nestri_ws.write().unwrap() = Some(nestri_ws);
}
pub fn set_pipeline(&self, pipeline: Arc<gst::Pipeline>) {
*self.pipeline.write().unwrap() = Some(pipeline);
}
pub fn get_pipeline(&self) -> Option<Arc<gst::Pipeline>> {
self.pipeline.read().unwrap().clone()
}
pub fn set_data_channel(&self, data_channel: gst_webrtc::WebRTCDataChannel) {
*self.data_channel.write().unwrap() = Some(data_channel);
}
/// Helper method to clean things up
fn register_callbacks(&self) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("sdp", move |data| {
if let Ok(message) = decode_message_as::<MessageSDP>(data) {
let sdp =
gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes()).unwrap();
let answer = WebRTCSessionDescription::new(WebRTCSDPType::Answer, sdp);
self_obj.emit_by_name::<()>(
"session-description",
&[&"unique-session-id", &answer],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode SDP message");
}
});
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("ice", move |data| {
if let Ok(message) = decode_message_as::<MessageICE>(data) {
let candidate = message.candidate;
let sdp_m_line_index = candidate.sdp_mline_index.unwrap_or(0) as u32;
let sdp_mid = candidate.sdp_mid;
self_obj.emit_by_name::<()>(
"handle-ice",
&[
&"unique-session-id",
&sdp_m_line_index,
&sdp_mid,
&candidate.candidate,
],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode ICE message");
}
});
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("answer", move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
gst::info!(gst::CAT_DEFAULT, "Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOK => {
gst::info!(gst::CAT_DEFAULT, "Received OK answer");
// Send our SDP offer
self_obj.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
}
AnswerType::AnswerInUse => {
gst::error!(gst::CAT_DEFAULT, "Room is in use by another node");
}
AnswerType::AnswerOffline => {
gst::warning!(gst::CAT_DEFAULT, "Room is offline");
}
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode answer");
}
});
}
{
let self_obj = self.obj().clone();
// After creating webrtcsink
self_obj.connect_closure(
"webrtcbin-ready",
false,
glib::closure!(move |signaller: &super::NestriSignaller,
_consumer_identifier: &str,
webrtcbin: &gst::Element| {
gst::info!(gst::CAT_DEFAULT, "Adding data channels");
// Create data channels on webrtcbin
let data_channel = Some(
webrtcbin.emit_by_name::<gst_webrtc::WebRTCDataChannel>(
"create-data-channel",
&[
&"nestri-data-channel",
&gst::Structure::builder("config")
.field("ordered", &false)
.field("max-retransmits", &0u32)
.build(),
],
),
);
if let Some(data_channel) = data_channel {
gst::info!(gst::CAT_DEFAULT, "Data channel created");
if let Some(pipeline) = signaller.imp().get_pipeline() {
setup_data_channel(&data_channel, &pipeline);
signaller.imp().set_data_channel(data_channel);
} else {
gst::error!(gst::CAT_DEFAULT, "Wayland display source not set");
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to create data channel");
}
}),
);
}
}
}
impl SignallableImpl for Signaller {
fn start(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller started");
// Get WebSocket connection
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
// Register message callbacks
self.register_callbacks();
// Subscribe to reconnection notifications
let reconnected_notify = nestri_ws.subscribe_reconnected();
// Clone necessary references
let self_clone = self.obj().clone();
let nestri_ws_clone = nestri_ws.clone();
// Spawn a task to handle actions upon reconnection
tokio::spawn(async move {
loop {
// Wait for a reconnection notification
reconnected_notify.notified().await;
println!("Reconnected to relay, re-negotiating...");
gst::warning!(gst::CAT_DEFAULT, "Reconnected to relay, re-negotiating...");
// Emit "session-ended" first to make sure the element is cleaned up
self_clone.emit_by_name::<bool>("session-ended", &[&"unique-session-id"]);
// Send a new join message
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws_clone.send_message(encoded) {
gst::error!(
gst::CAT_DEFAULT,
"Failed to send join message after reconnection: {:?}",
e
);
}
} else {
gst::error!(
gst::CAT_DEFAULT,
"Failed to encode join message after reconnection"
);
}
// If we need to interact with GStreamer or GLib, schedule it on the main thread
let self_clone_for_main = self_clone.clone();
glib::MainContext::default().invoke(move || {
// Emit the "session-requested" signal
self_clone_for_main.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
});
}
});
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send join message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send join message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode join message");
}
}
fn stop(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller stopped");
}
fn send_sdp(&self, _session_id: &str, sdp: &WebRTCSessionDescription) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
let sdp_message = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
},
sdp: RTCSessionDescription::offer(sdp.sdp().as_text().unwrap()).unwrap(),
};
if let Ok(encoded) = encode_message(&sdp_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send SDP message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send SDP message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode SDP message");
}
}
fn add_ice(
&self,
_session_id: &str,
candidate: &str,
sdp_m_line_index: u32,
sdp_mid: Option<String>,
) {
let nestri_ws = {
self.nestri_ws
.read()
.unwrap()
.clone()
.expect("NestriWebSocket not set")
};
let candidate_init = RTCIceCandidateInit {
candidate: candidate.to_string(),
sdp_mid,
sdp_mline_index: Some(sdp_m_line_index as u16),
..Default::default()
};
let ice_message = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
},
candidate: candidate_init,
};
if let Ok(encoded) = encode_message(&ice_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
eprintln!("Failed to send ICE message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send ICE message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode ICE message");
}
}
fn end_session(&self, session_id: &str) {
gst::info!(gst::CAT_DEFAULT, "Ending session: {}", session_id);
}
}
#[glib::object_subclass]
impl ObjectSubclass for Signaller {
const NAME: &'static str = "NestriSignaller";
type Type = super::NestriSignaller;
type ParentType = glib::Object;
type Interfaces = (Signallable,);
}
impl ObjectImpl for Signaller {
fn properties() -> &'static [glib::ParamSpec] {
static PROPS: LazyLock<Vec<glib::ParamSpec>> = LazyLock::new(|| {
vec![glib::ParamSpecBoolean::builder("manual-sdp-munging")
.nick("Manual SDP munging")
.blurb("Whether the signaller manages SDP munging itself")
.default_value(false)
.read_only()
.build()]
});
PROPS.as_ref()
}
fn property(&self, _id: usize, pspec: &glib::ParamSpec) -> glib::Value {
match pspec.name() {
"manual-sdp-munging" => false.to_value(),
_ => unimplemented!(),
}
}
}
fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, pipeline: &gst::Pipeline) {
let pipeline = pipeline.clone();
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
data_channel.connect_on_message_data(move |_data_channel, data| {
if let Some(data) = data {
match decode_message_as::<MessageInput>(data.to_vec()) {
Ok(message_input) => {
// Deserialize the input message data
if let Ok(input_msg) = serde_json::from_str::<InputMessage>(&message_input.data)
{
// Process the input message and create an event
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons)
{
// Send the event to pipeline, result bool is ignored
let _ = pipeline.send_event(event);
}
} else {
eprintln!("Failed to parse InputMessage");
}
}
Err(e) => {
eprintln!("Failed to decode MessageInput: {:?}", e);
}
}
}
});
}
fn handle_input_message(
input_msg: InputMessage,
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
match input_msg {
InputMessage::MouseMove { x, y } => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseMoveAbs { x, y } => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().unwrap();
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
}
keys.insert(key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().unwrap();
// Remove the key from the pressed state when released
keys.remove(&key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::Wheel { x, y } => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", x as f64)
.field("y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseDown { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&key) {
return None;
}
buttons.insert(key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseUp { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// Remove the button from the pressed state when released
buttons.remove(&key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
}
}

View File

@@ -0,0 +1,25 @@
use std::sync::Arc;
use gst::glib;
use gst::subclass::prelude::*;
use gstrswebrtc::signaller::Signallable;
use crate::websocket::NestriWebSocket;
mod imp;
glib::wrapper! {
pub struct NestriSignaller(ObjectSubclass<imp::Signaller>) @implements Signallable;
}
impl NestriSignaller {
pub fn new(nestri_ws: Arc<NestriWebSocket>, pipeline: Arc<gst::Pipeline>) -> Self {
let obj: Self = glib::Object::new();
obj.imp().set_nestri_ws(nestri_ws);
obj.imp().set_pipeline(pipeline);
obj
}
}
impl Default for NestriSignaller {
fn default() -> Self {
panic!("Cannot create NestriSignaller without NestriWebSocket");
}
}

View File

@@ -1,540 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::collections::{HashSet};
use std::error::Error;
use std::sync::Arc;
use futures_util::StreamExt;
use gst::prelude::ElementExtManual;
use gst_app::app_sink::AppSinkStream;
use tokio::sync::{oneshot, Mutex};
use tokio::sync::{mpsc};
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::ice_transport::ice_gathering_state::RTCIceGatheringState;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::TrackLocalWriter;
use crate::messages::*;
use crate::websocket::NestriWebSocket;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
pub struct Room {
nestri_ws: Arc<NestriWebSocket>,
webrtc_api: webrtc::api::API,
webrtc_config: RTCConfiguration,
}
impl Room {
pub async fn new(
nestri_ws: Arc<NestriWebSocket>,
) -> Result<Room, Box<dyn Error>> {
// Create media engine and register default codecs
let mut media_engine = MediaEngine::default();
media_engine.register_default_codecs()?;
// Registry
let mut registry = Registry::new();
registry = register_default_interceptors(registry, &mut media_engine)?;
// Create the API object with the MediaEngine
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
// Prepare the configuration
let config = RTCConfiguration {
ice_servers: vec![RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".to_owned()],
..Default::default()
}],
..Default::default()
};
Ok(Self {
nestri_ws,
webrtc_api: api,
webrtc_config: config,
})
}
pub async fn run(
&mut self,
audio_codec: &str,
video_codec: &str,
pipeline: Arc<Mutex<gst::Pipeline>>,
audio_sink: Arc<Mutex<AppSinkStream>>,
video_sink: Arc<Mutex<AppSinkStream>>,
) -> Result<(), Box<dyn Error>> {
let (tx, rx) = oneshot::channel();
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
self.nestri_ws
.register_callback("answer", {
let tx = tx.clone();
move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
log::info!("Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOffline => {
log::warn!("Room is offline, we shouldn't be receiving this");
}
AnswerType::AnswerInUse => {
log::error!("Room is in use by another node!");
}
AnswerType::AnswerOK => {
// Notify that we got an OK answer
if let Some(tx) = tx.lock().unwrap().take() {
if let Err(_) = tx.send(()) {
log::error!("Failed to send OK answer signal");
}
}
}
}
} else {
log::error!("Failed to decode answer");
}
}
})
.await;
// Send a request to join the room
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("Failed to encode join message");
return Err("Failed to encode join message".into());
}
// Wait for the signal indicating that we have received an OK answer
match rx.await {
Ok(()) => {
log::info!("Received OK answer, proceeding...");
}
Err(_) => {
log::error!("Oneshot channel closed unexpectedly");
return Err("Unexpected error while waiting for OK answer".into());
}
}
// Create a new RTCPeerConnection
let config = self.webrtc_config.clone();
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
// Create audio track
let audio_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: audio_codec.to_owned(),
..Default::default()
},
"audio".to_owned(),
"audio-nestri-server".to_owned(),
));
// Create video track
let video_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: video_codec.to_owned(),
..Default::default()
},
"video".to_owned(),
"video-nestri-server".to_owned(),
));
// Cancellation token to stop spawned tasks after peer connection is closed
let cancel_token = tokio_util::sync::CancellationToken::new();
// Add audio track to peer connection
let audio_sender = peer_connection.add_track(audio_track.clone()).await?;
let audio_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = audio_sender_token.cancelled() => {
break;
}
_ = audio_sender.read(&mut rtcp_buf) => {}
}
}
});
// Add video track to peer connection
let video_sender = peer_connection.add_track(video_track.clone()).await?;
let video_sender_token = cancel_token.child_token();
tokio::spawn(async move {
loop {
let mut rtcp_buf = vec![0u8; 1500];
tokio::select! {
_ = video_sender_token.cancelled() => {
break;
}
_ = video_sender.read(&mut rtcp_buf) => {}
}
}
});
// Create a datachannel with label 'input'
let data_channel_opts = Some(RTCDataChannelInit {
ordered: Some(false),
max_retransmits: Some(0),
..Default::default()
});
let data_channel
= peer_connection.create_data_channel("input", data_channel_opts).await?;
// PeerConnection state change tracker
let (pc_sndr, mut pc_recv) = mpsc::channel(1);
// Peer connection state change handler
peer_connection.on_peer_connection_state_change(Box::new(
move |s: RTCPeerConnectionState| {
let pc_sndr = pc_sndr.clone();
Box::pin(async move {
log::info!("PeerConnection State has changed: {s}");
if s == RTCPeerConnectionState::Failed
|| s == RTCPeerConnectionState::Disconnected
|| s == RTCPeerConnectionState::Closed
{
// Notify pc_state that the peer connection has closed
if let Err(e) = pc_sndr.send(s).await {
log::error!("Failed to send PeerConnection state: {}", e);
}
}
})
},
));
peer_connection.on_ice_gathering_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Gathering State has changed: {s}");
})
}));
peer_connection.on_ice_connection_state_change(Box::new(move |s| {
Box::pin(async move {
log::info!("ICE Connection State has changed: {s}");
})
}));
// Trickle ICE over WebSocket
let ws = self.nestri_ws.clone();
peer_connection.on_ice_candidate(Box::new(move |c| {
let nestri_ws = ws.clone();
Box::pin(async move {
if let Some(candidate) = c {
let candidate_json = candidate.to_json().unwrap();
let ice_msg = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
},
candidate: candidate_json,
};
if let Ok(encoded) = encode_message(&ice_msg) {
let _ = nestri_ws.send_message(encoded);
}
}
})
}));
// Temporary ICE candidate buffer until remote description is set
let ice_holder: Arc<Mutex<Vec<RTCIceCandidateInit>>> = Arc::new(Mutex::new(Vec::new()));
// Register set_response_callback for ICE candidate
let pc = peer_connection.clone();
let ice_clone = ice_holder.clone();
self.nestri_ws.register_callback("ice", move |data| {
match decode_message_as::<MessageICE>(data) {
Ok(message) => {
log::info!("Received ICE message");
let candidate = RTCIceCandidateInit::from(message.candidate);
let pc = pc.clone();
let ice_clone = ice_clone.clone();
tokio::spawn(async move {
// If remote description is not set, buffer ICE candidates
if pc.remote_description().await.is_none() {
let mut ice_holder = ice_clone.lock().await;
ice_holder.push(candidate);
} else {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
} else {
// Add any held ICE candidates
let mut ice_holder = ice_clone.lock().await;
for candidate in ice_holder.drain(..) {
if let Err(e) = pc.add_ice_candidate(candidate).await {
log::error!("Failed to add ICE candidate: {}", e);
}
}
}
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
// A shared state to track currently pressed keys
let pressed_keys = Arc::new(Mutex::new(HashSet::new()));
let pressed_buttons = Arc::new(Mutex::new(HashSet::new()));
// Data channel message handler
data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
let pipeline = pipeline.clone();
let pressed_keys = pressed_keys.clone();
let pressed_buttons = pressed_buttons.clone();
Box::pin({
async move {
// We don't care about string messages for now
if !msg.is_string {
// Decode the message as an MessageInput (binary encoded gzip)
match decode_message_as::<MessageInput>(msg.data.to_vec()) {
Ok(message_input) => {
// Handle the input message
if let Ok(input_msg) = from_str::<InputMessage>(&message_input.data) {
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons).await
{
let _ = pipeline.lock().await.send_event(event);
}
}
}
Err(e) => {
log::error!("Failed to decode input message: {:?}", e);
}
}
}
}
})
}));
log::info!("Creating offer...");
// Create an offer to send to the browser
let offer = peer_connection.create_offer(None).await?;
log::info!("Setting local description...");
// Sets the LocalDescription, and starts our UDP listeners
peer_connection.set_local_description(offer).await?;
log::info!("Local description set...");
if let Some(local_description) = peer_connection.local_description().await {
// Wait until we have gathered all ICE candidates
while peer_connection.ice_gathering_state() != RTCIceGatheringState::Complete {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
// Register set_response_callback for SDP answer
let pc = peer_connection.clone();
self.nestri_ws.register_callback("sdp", move |data| {
match decode_message_as::<MessageSDP>(data) {
Ok(message) => {
log::info!("Received SDP message");
let sdp = message.sdp;
let pc = pc.clone();
tokio::spawn(async move {
if let Err(e) = pc.set_remote_description(sdp).await {
log::error!("Failed to set remote description: {}", e);
}
});
}
Err(e) => eprintln!("Failed to decode callback message: {:?}", e),
}
}).await;
log::info!("Sending local description to remote...");
// Encode and send the local description via WebSocket
let sdp_msg = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
},
sdp: local_description,
};
let encoded = encode_message(&sdp_msg)?;
self.nestri_ws.send_message(encoded).await?;
} else {
log::error!("generate local_description failed!");
cancel_token.cancel();
return Err("generate local_description failed!".into());
};
// Send video and audio data
let audio_track = audio_track.clone();
tokio::spawn(async move {
let mut audio_sink = audio_sink.lock().await;
while let Some(sample) = audio_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = audio_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write audio track: {}", e);
}
}
}
}
}
});
let video_track = video_track.clone();
tokio::spawn(async move {
let mut video_sink = video_sink.lock().await;
while let Some(sample) = video_sink.next().await {
if let Some(buffer) = sample.buffer() {
if let Ok(map) = buffer.map_readable() {
if let Err(e) = video_track.write(map.as_slice()).await {
if webrtc::Error::ErrClosedPipe == e {
break;
} else {
log::error!("Failed to write video track: {}", e);
}
}
}
}
}
});
// Block until closed or error
tokio::select! {
_ = pc_recv.recv() => {
log::info!("Peer connection closed with state: {:?}", peer_connection.connection_state());
}
}
cancel_token.cancel();
// Make double-sure to close the peer connection
if let Err(e) = peer_connection.close().await {
log::error!("Failed to close peer connection: {}", e);
}
Ok(())
}
}
async fn handle_input_message(
input_msg: InputMessage,
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
match input_msg {
InputMessage::MouseMove { x, y } => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseMoveAbs { x, y } => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().await;
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
}
keys.insert(key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().await;
// Remove the key from the pressed state when released
keys.remove(&key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::Wheel { x, y } => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", x as f64)
.field("y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseDown { key } => {
let mut buttons = pressed_buttons.lock().await;
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&key) {
return None;
}
buttons.insert(key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseUp { key } => {
let mut buttons = pressed_buttons.lock().await;
// Remove the button from the pressed state when released
buttons.remove(&key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
}
}

View File

@@ -1,18 +1,17 @@
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
use futures_util::sink::SinkExt;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::StreamExt;
use log::{Level, Log, Metadata, Record};
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::{mpsc, Mutex, Notify};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::messages::{decode_message, encode_message, MessageBase, MessageLog};
type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>;
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
@@ -21,28 +20,36 @@ type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
#[derive(Clone)]
pub struct NestriWebSocket {
ws_url: String,
reader: Arc<Mutex<WSRead>>,
writer: Arc<Mutex<WSWrite>>,
reader: Arc<Mutex<Option<WSRead>>>,
writer: Arc<Mutex<Option<WSWrite>>>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
message_tx: mpsc::UnboundedSender<Vec<u8>>,
reconnected_notify: Arc<Notify>,
}
impl NestriWebSocket {
pub async fn new(
ws_url: String,
) -> Result<NestriWebSocket, Box<dyn Error>> {
pub async fn new(ws_url: String) -> Result<NestriWebSocket, Box<dyn Error>> {
// Attempt to connect to the WebSocket
let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap();
// If the connection is successful, split the stream
// Split the stream into read and write halves
let (write, read) = ws_stream.split();
let mut ws = NestriWebSocket {
// Create the message channel
let (message_tx, message_rx) = mpsc::unbounded_channel();
let ws = NestriWebSocket {
ws_url,
reader: Arc::new(Mutex::new(read)),
writer: Arc::new(Mutex::new(write)),
reader: Arc::new(Mutex::new(Some(read))),
writer: Arc::new(Mutex::new(Some(write))),
callbacks: Arc::new(RwLock::new(HashMap::new())),
message_tx: message_tx.clone(),
reconnected_notify: Arc::new(Notify::new()),
};
// Spawn the read loop
ws.spawn_read_loop();
// Spawn the write loop
ws.spawn_write_loop(message_rx);
Ok(ws)
}
@@ -57,89 +64,160 @@ impl NestriWebSocket {
}
Err(e) => {
eprintln!("Failed to connect to WebSocket, retrying: {:?}", e);
sleep(Duration::from_secs(1)).await; // Wait before retrying
sleep(Duration::from_secs(3)).await; // Wait before retrying
}
}
}
}
// Handles message -> callback calls and reconnects on error/disconnect
fn spawn_read_loop(&mut self) {
fn spawn_read_loop(&self) {
let reader = self.reader.clone();
let callbacks = self.callbacks.clone();
let mut self_clone = self.clone();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
let message = reader.lock().await.next().await;
match message {
Some(Ok(message)) => {
let data = message.into_data();
let base_message = match decode_message(&data) {
Ok(base_message) => base_message,
Err(e) => {
eprintln!("Failed to decode message: {:?}", e);
continue;
}
};
// Lock the reader to get the WSRead, then drop the lock
let ws_read_option = {
let mut reader_lock = reader.lock().await;
reader_lock.take()
};
let callbacks_lock = callbacks.read().await;
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
let data = data.clone();
callback(data);
let mut ws_read = match ws_read_option {
Some(ws_read) => ws_read,
None => {
eprintln!("Reader is None, cannot proceed");
return;
}
};
while let Some(message_result) = ws_read.next().await {
match message_result {
Ok(message) => {
let data = message.into_data();
let base_message = match decode_message(&data) {
Ok(base_message) => base_message,
Err(e) => {
eprintln!("Failed to decode message: {:?}", e);
continue;
}
};
let callbacks_lock = callbacks.read().unwrap();
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
let data = data.clone();
callback(data);
}
}
Err(e) => {
eprintln!("Error receiving message: {:?}, reconnecting in 3 seconds...", e);
sleep(Duration::from_secs(3)).await;
self_clone.reconnect().await.unwrap();
break; // Break the inner loop to get a new ws_read
}
}
Some(Err(e)) => {
eprintln!("Error receiving message: {:?}", e);
self_clone.reconnect().await.unwrap();
}
None => {
eprintln!("WebSocket connection closed, reconnecting...");
self_clone.reconnect().await.unwrap();
}
// After reconnection, the loop continues, and we acquire a new ws_read
}
});
}
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<Vec<u8>>) {
let writer = self.writer.clone();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
// Wait for a message from the channel
if let Some(message) = message_rx.recv().await {
loop {
// Acquire the writer lock
let mut writer_lock = writer.lock().await;
if let Some(writer) = writer_lock.as_mut() {
// Try to send the message over the WebSocket
match writer.send(Message::Binary(message.clone())).await {
Ok(_) => {
// Message sent successfully
break;
}
Err(e) => {
eprintln!("Error sending message: {:?}", e);
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
eprintln!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
eprintln!("Writer is None, cannot send message");
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
eprintln!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
break;
}
}
});
}
async fn reconnect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
// Keep trying to reconnect until successful
async fn reconnect(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
loop {
match NestriWebSocket::do_connect(&self.ws_url).await {
Ok(ws_stream) => {
let (write, read) = ws_stream.split();
*self.reader.lock().await = read;
*self.writer.lock().await = write;
{
let mut writer_lock = self.writer.lock().await;
*writer_lock = Some(write);
}
{
let mut reader_lock = self.reader.lock().await;
*reader_lock = Some(read);
}
// Notify subscribers of successful reconnection
self.reconnected_notify.notify_waiters();
return Ok(());
}
Err(e) => {
eprintln!("Failed to reconnect to WebSocket: {:?}", e);
sleep(Duration::from_secs(2)).await; // Wait before retrying
sleep(Duration::from_secs(3)).await; // Wait before retrying
}
}
}
}
pub async fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
let mut writer_lock = self.writer.lock().await;
writer_lock
.send(Message::Binary(message))
.await
.map_err(|e| format!("Error sending message: {:?}", e).into())
/// Send a message through the WebSocket
pub fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
self.message_tx
.send(message)
.map_err(|e| format!("Failed to send message: {:?}", e).into())
}
pub async fn register_callback<F>(&self, response_type: &str, callback: F)
/// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
let mut callbacks_lock = self.callbacks.write().await;
let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Box::new(callback));
}
/// Subscribe to event for reconnection
pub fn subscribe_reconnected(&self) -> Arc<Notify> {
self.reconnected_notify.clone()
}
}
impl Log for NestriWebSocket {
fn enabled(&self, metadata: &Metadata) -> bool {
// Filter logs by level
metadata.level() <= Level::Info
}
@@ -162,7 +240,9 @@ impl Log for NestriWebSocket {
time,
};
if let Ok(encoded_message) = encode_message(&log_message) {
let _ = self.send_message(encoded_message);
if let Err(e) = self.send_message(encoded_message) {
eprintln!("Failed to send log message: {:?}", e);
}
}
}
}
@@ -171,4 +251,3 @@ impl Log for NestriWebSocket {
// No-op for this logger
}
}