feat: Add protobuf (#171)

This is a second attempt to add protobuf to Nestri, after the first one
failed

---------

Co-authored-by: Philipp Neumann <3daquawolf@gmail.com>
Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
This commit is contained in:
Wanjohi
2025-01-29 04:16:27 +03:00
committed by GitHub
parent be6ea11052
commit c2363b0bce
42 changed files with 3114 additions and 854 deletions

View File

@@ -48,7 +48,10 @@ impl AppArgs {
.unwrap()
.parse::<u32>()
.unwrap_or(60),
relay_url: matches.get_one::<String>("relay-url").unwrap().clone(),
relay_url: matches
.get_one::<String>("relay-url")
.expect("relay url cannot be empty")
.clone(),
// Generate random room name if not provided
room: matches
.get_one::<String>("room")

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Serialize, Deserialize, Debug)]
pub struct TimestampEntry {

View File

@@ -5,6 +5,7 @@ mod latency;
mod messages;
mod nestrisink;
mod websocket;
mod proto;
use crate::args::encoding_args;
use crate::nestrisink::NestriSignaller;

View File

@@ -1,50 +1,14 @@
use std::error::Error;
use std::io::{Read, Write};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use crate::latency::LatencyTracker;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::{FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};
use std::error::Error;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::latency::LatencyTracker;
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum InputMessage {
#[serde(rename = "mousemove")]
MouseMove { x: i32, y: i32 },
#[serde(rename = "mousemoveabs")]
MouseMoveAbs { x: i32, y: i32 },
#[serde(rename = "wheel")]
Wheel { x: f64, y: f64 },
#[serde(rename = "mousedown")]
MouseDown { key: i32 },
// Add other variants as needed
#[serde(rename = "mouseup")]
MouseUp { key: i32 },
#[serde(rename = "keydown")]
KeyDown { key: i32 },
#[serde(rename = "keyup")]
KeyUp { key: i32 },
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageBase {
pub payload_type: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageInput {
#[serde(flatten)]
pub base: MessageBase,
pub data: String,
pub latency: Option<LatencyTracker>,
}
@@ -136,34 +100,21 @@ pub struct MessageAnswer {
pub answer_type: AnswerType,
}
pub fn encode_message<T: Serialize>(message: &T) -> Result<Vec<u8>, Box<dyn Error>> {
pub fn encode_message<T: Serialize>(message: &T) -> Result<String, Box<dyn Error>> {
// Serialize the message to JSON
let json = serde_json::to_string(message)?;
// Compress the JSON using gzip
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(json.as_bytes())?;
let compressed_data = encoder.finish()?;
Ok(compressed_data)
Ok(json)
}
pub fn decode_message(data: &[u8]) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
let mut decoder = GzDecoder::new(data);
let mut decompressed_data = String::new();
decoder.read_to_string(&mut decompressed_data)?;
let base_message: MessageBase = serde_json::from_str(&decompressed_data)?;
pub fn decode_message(data: String) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
println!("Data: {}", data);
let base_message: MessageBase = serde_json::from_str(&data)?;
Ok(base_message)
}
pub fn decode_message_as<T: for<'de> Deserialize<'de>>(
data: Vec<u8>,
data: String,
) -> Result<T, Box<dyn Error + Send + Sync>> {
let mut decoder = GzDecoder::new(data.as_slice());
let mut decompressed_data = String::new();
decoder.read_to_string(&mut decompressed_data)?;
let message: T = serde_json::from_str(&decompressed_data)?;
let message: T = serde_json::from_str(&data)?;
Ok(message)
}
}

View File

@@ -1,13 +1,18 @@
use crate::messages::{
decode_message_as, encode_message, AnswerType, InputMessage, JoinerType, MessageAnswer,
MessageBase, MessageICE, MessageInput, MessageJoin, MessageSDP,
decode_message_as, encode_message, AnswerType, JoinerType, MessageAnswer, MessageBase,
MessageICE, MessageJoin, MessageSDP,
};
use crate::proto::proto::proto_input::InputType::{
KeyDown, KeyUp, MouseKeyDown, MouseKeyUp, MouseMove, MouseMoveAbs, MouseWheel,
};
use crate::proto::proto::{ProtoInput, ProtoMessageInput};
use crate::websocket::NestriWebSocket;
use glib::subclass::prelude::*;
use gst::glib;
use gst::prelude::*;
use gst_webrtc::{gst_sdp, WebRTCSDPType, WebRTCSessionDescription};
use gstrswebrtc::signaller::{Signallable, SignallableImpl};
use prost::Message;
use std::collections::HashSet;
use std::sync::{Arc, LazyLock};
use std::sync::{Mutex, RwLock};
@@ -200,6 +205,7 @@ impl SignallableImpl for Signaller {
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
latency: None,
},
joiner_type: JoinerType::JoinerNode,
};
@@ -237,6 +243,7 @@ impl SignallableImpl for Signaller {
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
latency: None,
},
joiner_type: JoinerType::JoinerNode,
};
@@ -265,6 +272,7 @@ impl SignallableImpl for Signaller {
let sdp_message = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
latency: None,
},
sdp: RTCSessionDescription::offer(sdp.sdp().as_text().unwrap()).unwrap(),
};
@@ -301,6 +309,7 @@ impl SignallableImpl for Signaller {
let ice_message = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
latency: None,
},
candidate: candidate_init,
};
@@ -354,11 +363,9 @@ fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, pipeline: &g
data_channel.connect_on_message_data(move |_data_channel, data| {
if let Some(data) = data {
match decode_message_as::<MessageInput>(data.to_vec()) {
match ProtoMessageInput::decode(data.to_vec().as_slice()) {
Ok(message_input) => {
// Deserialize the input message data
if let Ok(input_msg) = serde_json::from_str::<InputMessage>(&message_input.data)
{
if let Some(input_msg) = message_input.data {
// Process the input message and create an event
if let Some(event) =
handle_input_message(input_msg, &pressed_keys, &pressed_buttons)
@@ -379,88 +386,92 @@ fn setup_data_channel(data_channel: &gst_webrtc::WebRTCDataChannel, pipeline: &g
}
fn handle_input_message(
input_msg: InputMessage,
input_msg: ProtoInput,
pressed_keys: &Arc<Mutex<HashSet<i32>>>,
pressed_buttons: &Arc<Mutex<HashSet<i32>>>,
) -> Option<gst::Event> {
match input_msg {
InputMessage::MouseMove { x, y } => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
if let Some(input_type) = input_msg.input_type {
match input_type {
MouseMove(data) => {
let structure = gst::Structure::builder("MouseMoveRelative")
.field("pointer_x", data.x as f64)
.field("pointer_y", data.y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseMoveAbs { x, y } => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", x as f64)
.field("pointer_y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyDown { key } => {
let mut keys = pressed_keys.lock().unwrap();
// If the key is already pressed, return to prevent key lockup
if keys.contains(&key) {
return None;
Some(gst::event::CustomUpstream::new(structure))
}
keys.insert(key);
MouseMoveAbs(data) => {
let structure = gst::Structure::builder("MouseMoveAbsolute")
.field("pointer_x", data.x as f64)
.field("pointer_y", data.y as f64)
.build();
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::KeyUp { key } => {
let mut keys = pressed_keys.lock().unwrap();
// Remove the key from the pressed state when released
keys.remove(&key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::Wheel { x, y } => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", x as f64)
.field("y", y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseDown { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&key) {
return None;
Some(gst::event::CustomUpstream::new(structure))
}
buttons.insert(key);
KeyDown(data) => {
let mut keys = pressed_keys.lock().unwrap();
// If the key is already pressed, return to prevent key lockup
if keys.contains(&data.key) {
return None;
}
keys.insert(data.key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", true)
.build();
let structure = gst::Structure::builder("KeyboardKey")
.field("key", data.key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
InputMessage::MouseUp { key } => {
let mut buttons = pressed_buttons.lock().unwrap();
// Remove the button from the pressed state when released
buttons.remove(&key);
let structure = gst::Structure::builder("MouseButton")
.field("button", key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
Some(gst::event::CustomUpstream::new(structure))
}
KeyUp(data) => {
let mut keys = pressed_keys.lock().unwrap();
// Remove the key from the pressed state when released
keys.remove(&data.key);
let structure = gst::Structure::builder("KeyboardKey")
.field("key", data.key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
MouseWheel(data) => {
let structure = gst::Structure::builder("MouseAxis")
.field("x", data.x as f64)
.field("y", data.y as f64)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
MouseKeyDown(data) => {
let mut buttons = pressed_buttons.lock().unwrap();
// If the button is already pressed, return to prevent button lockup
if buttons.contains(&data.key) {
return None;
}
buttons.insert(data.key);
let structure = gst::Structure::builder("MouseButton")
.field("button", data.key as u32)
.field("pressed", true)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
MouseKeyUp(data) => {
let mut buttons = pressed_buttons.lock().unwrap();
// Remove the button from the pressed state when released
buttons.remove(&data.key);
let structure = gst::Structure::builder("MouseButton")
.field("button", data.key as u32)
.field("pressed", false)
.build();
Some(gst::event::CustomUpstream::new(structure))
}
}
} else {
None
}
}

View File

@@ -0,0 +1 @@
pub mod proto;

View File

@@ -0,0 +1,139 @@
// @generated
// This file is @generated by prost-build.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoTimestampEntry {
#[prost(string, tag="1")]
pub stage: ::prost::alloc::string::String,
#[prost(message, optional, tag="2")]
pub time: ::core::option::Option<::prost_types::Timestamp>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoLatencyTracker {
#[prost(string, tag="1")]
pub sequence_id: ::prost::alloc::string::String,
#[prost(message, repeated, tag="2")]
pub timestamps: ::prost::alloc::vec::Vec<ProtoTimestampEntry>,
}
/// MouseMove message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseMove {
/// Fixed value "MouseMove"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseMoveAbs message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseMoveAbs {
/// Fixed value "MouseMoveAbs"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseWheel message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseWheel {
/// Fixed value "MouseWheel"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub x: i32,
#[prost(int32, tag="3")]
pub y: i32,
}
/// MouseKeyDown message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseKeyDown {
/// Fixed value "MouseKeyDown"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub key: i32,
}
/// MouseKeyUp message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMouseKeyUp {
/// Fixed value "MouseKeyUp"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub key: i32,
}
/// KeyDown message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoKeyDown {
/// Fixed value "KeyDown"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub key: i32,
}
/// KeyUp message
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoKeyUp {
/// Fixed value "KeyUp"
#[prost(string, tag="1")]
pub r#type: ::prost::alloc::string::String,
#[prost(int32, tag="2")]
pub key: i32,
}
/// Union of all Input types
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoInput {
#[prost(oneof="proto_input::InputType", tags="1, 2, 3, 4, 5, 6, 7")]
pub input_type: ::core::option::Option<proto_input::InputType>,
}
/// Nested message and enum types in `ProtoInput`.
pub mod proto_input {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum InputType {
#[prost(message, tag="1")]
MouseMove(super::ProtoMouseMove),
#[prost(message, tag="2")]
MouseMoveAbs(super::ProtoMouseMoveAbs),
#[prost(message, tag="3")]
MouseWheel(super::ProtoMouseWheel),
#[prost(message, tag="4")]
MouseKeyDown(super::ProtoMouseKeyDown),
#[prost(message, tag="5")]
MouseKeyUp(super::ProtoMouseKeyUp),
#[prost(message, tag="6")]
KeyDown(super::ProtoKeyDown),
#[prost(message, tag="7")]
KeyUp(super::ProtoKeyUp),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMessageBase {
#[prost(string, tag="1")]
pub payload_type: ::prost::alloc::string::String,
#[prost(message, optional, tag="2")]
pub latency: ::core::option::Option<ProtoLatencyTracker>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ProtoMessageInput {
#[prost(message, optional, tag="1")]
pub message_base: ::core::option::Option<ProtoMessageBase>,
#[prost(message, optional, tag="2")]
pub data: ::core::option::Option<ProtoInput>,
}
// @@protoc_insertion_point(module)

View File

@@ -10,10 +10,10 @@ use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex, Notify};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
type Callback = Box<dyn Fn(Vec<u8>) + Send + Sync>;
type Callback = Box<dyn Fn(String) + Send + Sync>;
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
@@ -23,7 +23,7 @@ pub struct NestriWebSocket {
reader: Arc<Mutex<Option<WSRead>>>,
writer: Arc<Mutex<Option<WSWrite>>>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
message_tx: mpsc::UnboundedSender<Vec<u8>>,
message_tx: mpsc::UnboundedSender<String>,
reconnected_notify: Arc<Notify>,
}
impl NestriWebSocket {
@@ -95,8 +95,8 @@ impl NestriWebSocket {
while let Some(message_result) = ws_read.next().await {
match message_result {
Ok(message) => {
let data = message.into_data();
let base_message = match decode_message(&data) {
let data = message.into_text().expect("failed to turn message into text");
let base_message = match decode_message(data.to_string()) {
Ok(base_message) => base_message,
Err(e) => {
eprintln!("Failed to decode message: {:?}", e);
@@ -107,11 +107,14 @@ impl NestriWebSocket {
let callbacks_lock = callbacks.read().unwrap();
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
let data = data.clone();
callback(data);
callback(data.to_string());
}
}
Err(e) => {
eprintln!("Error receiving message: {:?}, reconnecting in 3 seconds...", e);
eprintln!(
"Error receiving message: {:?}, reconnecting in 3 seconds...",
e
);
sleep(Duration::from_secs(3)).await;
self_clone.reconnect().await.unwrap();
break; // Break the inner loop to get a new ws_read
@@ -123,7 +126,7 @@ impl NestriWebSocket {
});
}
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<Vec<u8>>) {
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<String>) {
let writer = self.writer.clone();
let self_clone = self.clone();
@@ -136,7 +139,10 @@ impl NestriWebSocket {
let mut writer_lock = writer.lock().await;
if let Some(writer) = writer_lock.as_mut() {
// Try to send the message over the WebSocket
match writer.send(Message::Binary(message.clone())).await {
match writer
.send(Message::Text(Utf8Bytes::from(message.clone())))
.await
{
Ok(_) => {
// Message sent successfully
break;
@@ -196,7 +202,7 @@ impl NestriWebSocket {
}
/// Send a message through the WebSocket
pub fn send_message(&self, message: Vec<u8>) -> Result<(), Box<dyn Error>> {
pub fn send_message(&self, message: String) -> Result<(), Box<dyn Error>> {
self.message_tx
.send(message)
.map_err(|e| format!("Failed to send message: {:?}", e).into())
@@ -205,7 +211,7 @@ impl NestriWebSocket {
/// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
F: Fn(String) + Send + Sync + 'static,
{
let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Box::new(callback));
@@ -234,6 +240,7 @@ impl Log for NestriWebSocket {
let log_message = MessageLog {
base: MessageBase {
payload_type: "log".to_string(),
latency: None,
},
level,
message,