feat: Migrate from WebSocket to libp2p for peer-to-peer connectivity (#286)

## Description
Whew, some stuff is still not re-implemented, but it's working!

Rabbit's gonna explode with the amount of changes I reckon 😅



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a peer-to-peer relay system using libp2p with enhanced
stream forwarding, room state synchronization, and mDNS peer discovery.
- Added decentralized room and participant management, metrics
publishing, and safe, size-limited, concurrent message streaming with
robust framing and callback dispatching.
- Implemented asynchronous, callback-driven message handling over custom
libp2p streams replacing WebSocket signaling.
- **Improvements**
- Migrated signaling and stream protocols from WebSocket to libp2p,
improving reliability and scalability.
- Simplified configuration and environment variables, removing
deprecated flags and adding persistent data support.
- Enhanced logging, error handling, and connection management for better
observability and robustness.
- Refined RTP header extension registration and NAT IP handling for
improved WebRTC performance.
- **Bug Fixes**
- Improved ICE candidate buffering and SDP negotiation in WebRTC
connections.
  - Fixed NAT IP and UDP port range configuration issues.
- **Refactor**
- Modularized codebase, reorganized relay and server logic, and removed
deprecated WebSocket-based components.
- Streamlined message structures, removed obsolete enums and message
types, and simplified SafeMap concurrency.
- Replaced WebSocket signaling with libp2p stream protocols in server
and relay components.
- **Chores**
- Updated and cleaned dependencies across Go, Rust, and JavaScript
packages.
  - Added `.gitignore` for persistent data directory in relay package.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com>
Co-authored-by: Philipp Neumann <3daquawolf@gmail.com>
This commit is contained in:
Kristian Ollikainen
2025-06-06 16:48:49 +03:00
committed by GitHub
parent e67a8d2b32
commit 6e82eff9e2
48 changed files with 4741 additions and 2787 deletions

View File

@@ -4,14 +4,14 @@ mod gpu;
mod latency;
mod messages;
mod nestrisink;
mod p2p;
mod proto;
mod websocket;
use crate::args::encoding_args;
use crate::enc_helper::EncoderType;
use crate::gpu::GPUVendor;
use crate::nestrisink::NestriSignaller;
use crate::websocket::NestriWebSocket;
use crate::p2p::p2p::NestriP2P;
use futures_util::StreamExt;
use gst::prelude::*;
use gstrswebrtc::signaller::Signallable;
@@ -19,6 +19,8 @@ use gstrswebrtc::webrtcsink::BaseWebRTCSink;
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
// Handles gathering GPU information and selecting the most suitable GPU
fn handle_gpus(args: &args::Args) -> Result<gpu::GPUInfo, Box<dyn Error>> {
@@ -165,32 +167,29 @@ fn handle_encoder_audio(args: &args::Args) -> String {
async fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let mut args = args::Args::new();
if args.app.verbose {
// Make sure tracing has INFO level
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.init();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env()?,
)
.init();
if args.app.verbose {
args.debug_print();
} else {
tracing_subscriber::fmt::init();
}
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install ring crypto provider");
// Begin connection attempt to the relay WebSocket endpoint
// replace any http/https with ws/wss
let replaced_relay_url = args
.app
.relay_url
.replace("http://", "ws://")
.replace("https://", "wss://");
let ws_url = format!("{}/api/ws/{}", replaced_relay_url, args.app.room,);
// Get relay URL from arguments
let relay_url = args.app.relay_url.trim();
// Setup our websocket
let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
// Initialize libp2p (logically the sink should handle the connection to be independent)
let nestri_p2p = Arc::new(NestriP2P::new().await?);
let p2p_conn = nestri_p2p.connect(relay_url).await?;
gst::init()?;
gstrswebrtc::plugin_register_static()?;
@@ -328,7 +327,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
/* Output */
// WebRTC sink Element
let signaller = NestriSignaller::new(nestri_ws.clone(), video_source.clone());
let signaller =
NestriSignaller::new(args.app.room, p2p_conn.clone(), video_source.clone()).await?;
let webrtcsink = BaseWebRTCSink::with_signaller(Signallable::from(signaller.clone()));
webrtcsink.set_property_from_str("stun-server", "stun://stun.l.google.com:19302");
webrtcsink.set_property_from_str("congestion-control", "disabled");

View File

@@ -1,8 +1,5 @@
use crate::latency::LatencyTracker;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::{FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};
use std::error::Error;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
@@ -12,6 +9,13 @@ pub struct MessageBase {
pub latency: Option<LatencyTracker>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageRaw {
#[serde(flatten)]
pub base: MessageBase,
pub data: serde_json::Value,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageLog {
#[serde(flatten)]
@@ -44,76 +48,3 @@ pub struct MessageSDP {
pub base: MessageBase,
pub sdp: RTCSessionDescription,
}
#[repr(i32)]
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
#[serde(try_from = "i32", into = "i32")]
pub enum JoinerType {
JoinerNode = 0,
JoinerClient = 1,
}
impl TryFrom<i32> for JoinerType {
type Error = &'static str;
fn try_from(value: i32) -> Result<Self, Self::Error> {
JoinerType::from_i32(value).ok_or("Invalid value for JoinerType")
}
}
impl From<JoinerType> for i32 {
fn from(joiner_type: JoinerType) -> Self {
joiner_type.to_i32().unwrap()
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageJoin {
#[serde(flatten)]
pub base: MessageBase,
pub joiner_type: JoinerType,
}
#[repr(i32)]
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
#[serde(try_from = "i32", into = "i32")]
pub enum AnswerType {
AnswerOffline = 0,
AnswerInUse = 1,
AnswerOK = 2,
}
impl TryFrom<i32> for AnswerType {
type Error = &'static str;
fn try_from(value: i32) -> Result<Self, Self::Error> {
AnswerType::from_i32(value).ok_or("Invalid value for AnswerType")
}
}
impl From<AnswerType> for i32 {
fn from(answer_type: AnswerType) -> Self {
answer_type.to_i32().unwrap()
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MessageAnswer {
#[serde(flatten)]
pub base: MessageBase,
pub answer_type: AnswerType,
}
pub fn encode_message<T: Serialize>(message: &T) -> Result<String, Box<dyn Error>> {
// Serialize the message to JSON
let json = serde_json::to_string(message)?;
Ok(json)
}
pub fn decode_message(data: String) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
let base_message: MessageBase = serde_json::from_str(&data)?;
Ok(base_message)
}
pub fn decode_message_as<T: for<'de> Deserialize<'de>>(
data: String,
) -> Result<T, Box<dyn Error + Send + Sync>> {
let message: T = serde_json::from_str(&data)?;
Ok(message)
}

View File

@@ -1,12 +1,10 @@
use crate::messages::{
AnswerType, JoinerType, MessageAnswer, MessageBase, MessageICE, MessageJoin, MessageSDP,
decode_message_as, encode_message,
};
use crate::messages::{MessageBase, MessageICE, MessageRaw, MessageSDP};
use crate::p2p::p2p::NestriConnection;
use crate::p2p::p2p_protocol_stream::NestriStreamProtocol;
use crate::proto::proto::proto_input::InputType::{
KeyDown, KeyUp, MouseKeyDown, MouseKeyUp, MouseMove, MouseMoveAbs, MouseWheel,
};
use crate::proto::proto::{ProtoInput, ProtoMessageInput};
use crate::websocket::NestriWebSocket;
use atomic_refcell::AtomicRefCell;
use glib::subclass::prelude::*;
use gst::glib;
@@ -20,22 +18,37 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub struct Signaller {
nestri_ws: PLRwLock<Option<Arc<NestriWebSocket>>>,
stream_room: PLRwLock<Option<String>>,
stream_protocol: PLRwLock<Option<Arc<NestriStreamProtocol>>>,
wayland_src: PLRwLock<Option<Arc<gst::Element>>>,
data_channel: AtomicRefCell<Option<gst_webrtc::WebRTCDataChannel>>,
}
impl Default for Signaller {
fn default() -> Self {
Self {
nestri_ws: PLRwLock::new(None),
stream_room: PLRwLock::new(None),
stream_protocol: PLRwLock::new(None),
wayland_src: PLRwLock::new(None),
data_channel: AtomicRefCell::new(None),
}
}
}
impl Signaller {
pub fn set_nestri_ws(&self, nestri_ws: Arc<NestriWebSocket>) {
*self.nestri_ws.write() = Some(nestri_ws);
pub async fn set_nestri_connection(
&self,
nestri_conn: NestriConnection,
) -> Result<(), Box<dyn std::error::Error>> {
let stream_protocol = NestriStreamProtocol::new(nestri_conn).await?;
*self.stream_protocol.write() = Some(Arc::new(stream_protocol));
Ok(())
}
pub fn set_stream_room(&self, room: String) {
*self.stream_room.write() = Some(room);
}
fn get_stream_protocol(&self) -> Option<Arc<NestriStreamProtocol>> {
self.stream_protocol.read().clone()
}
pub fn set_wayland_src(&self, wayland_src: Arc<gst::Element>) {
@@ -58,16 +71,14 @@ impl Signaller {
/// Helper method to clean things up
fn register_callbacks(&self) {
let nestri_ws = {
self.nestri_ws
.read()
.clone()
.expect("NestriWebSocket not set")
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
return;
};
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("sdp", move |data| {
if let Ok(message) = decode_message_as::<MessageSDP>(data) {
stream_protocol.register_callback("answer", move |data| {
if let Ok(message) = serde_json::from_slice::<MessageSDP>(&data) {
let sdp =
gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes()).unwrap();
let answer = WebRTCSessionDescription::new(WebRTCSDPType::Answer, sdp);
@@ -82,12 +93,11 @@ impl Signaller {
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("ice", move |data| {
if let Ok(message) = decode_message_as::<MessageICE>(data) {
stream_protocol.register_callback("ice-candidate", move |data| {
if let Ok(message) = serde_json::from_slice::<MessageICE>(&data) {
let candidate = message.candidate;
let sdp_m_line_index = candidate.sdp_mline_index.unwrap_or(0) as u32;
let sdp_mid = candidate.sdp_mid;
self_obj.emit_by_name::<()>(
"handle-ice",
&[
@@ -104,29 +114,28 @@ impl Signaller {
}
{
let self_obj = self.obj().clone();
let _ = nestri_ws.register_callback("answer", move |data| {
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
gst::info!(gst::CAT_DEFAULT, "Received answer: {:?}", answer);
match answer.answer_type {
AnswerType::AnswerOK => {
gst::info!(gst::CAT_DEFAULT, "Received OK answer");
// Send our SDP offer
self_obj.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
}
AnswerType::AnswerInUse => {
gst::error!(gst::CAT_DEFAULT, "Room is in use by another node");
}
AnswerType::AnswerOffline => {
gst::warning!(gst::CAT_DEFAULT, "Room is offline");
}
stream_protocol.register_callback("push-stream-ok", move |data| {
if let Ok(answer) = serde_json::from_slice::<MessageRaw>(&data) {
// Decode room name string
if let Some(room_name) = answer.data.as_str() {
gst::info!(
gst::CAT_DEFAULT,
"Received OK answer for room: {}",
room_name
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode room name from answer");
}
// Send our SDP offer
self_obj.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to decode answer");
}
@@ -177,89 +186,32 @@ impl SignallableImpl for Signaller {
fn start(&self) {
gst::info!(gst::CAT_DEFAULT, "Signaller started");
// Get WebSocket connection
let nestri_ws = {
self.nestri_ws
.read()
.clone()
.expect("NestriWebSocket not set")
};
// Register message callbacks
self.register_callbacks();
// Subscribe to reconnection notifications
let reconnected_notify = nestri_ws.subscribe_reconnected();
// TODO: Re-implement reconnection handling
// Clone necessary references
let self_clone = self.obj().clone();
let nestri_ws_clone = nestri_ws.clone();
let Some(stream_room) = self.stream_room.read().clone() else {
gst::error!(gst::CAT_DEFAULT, "Stream room not set");
return;
};
// Spawn a task to handle actions upon reconnection
tokio::spawn(async move {
loop {
// Wait for a reconnection notification
reconnected_notify.notified().await;
tracing::warn!("Reconnected to relay, re-negotiating...");
gst::warning!(gst::CAT_DEFAULT, "Reconnected to relay, re-negotiating...");
// Emit "session-ended" first to make sure the element is cleaned up
self_clone.emit_by_name::<bool>("session-ended", &[&"unique-session-id"]);
// Send a new join message
let join_msg = MessageJoin {
base: MessageBase {
payload_type: "join".to_string(),
latency: None,
},
joiner_type: JoinerType::JoinerNode,
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws_clone.send_message(encoded) {
gst::error!(
gst::CAT_DEFAULT,
"Failed to send join message after reconnection: {:?}",
e
);
}
} else {
gst::error!(
gst::CAT_DEFAULT,
"Failed to encode join message after reconnection"
);
}
// If we need to interact with GStreamer or GLib, schedule it on the main thread
let self_clone_for_main = self_clone.clone();
glib::MainContext::default().invoke(move || {
// Emit the "session-requested" signal
self_clone_for_main.emit_by_name::<()>(
"session-requested",
&[
&"unique-session-id",
&"consumer-identifier",
&None::<WebRTCSessionDescription>,
],
);
});
}
});
let join_msg = MessageJoin {
let push_msg = MessageRaw {
base: MessageBase {
payload_type: "join".to_string(),
payload_type: "push-stream-room".to_string(),
latency: None,
},
joiner_type: JoinerType::JoinerNode,
data: serde_json::Value::from(stream_room),
};
if let Ok(encoded) = encode_message(&join_msg) {
if let Err(e) = nestri_ws.send_message(encoded) {
tracing::error!("Failed to send join message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send join message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode join message");
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
return;
};
if let Err(e) = stream_protocol.send_message(&push_msg) {
tracing::error!("Failed to send push stream room message: {:?}", e);
}
}
@@ -268,26 +220,21 @@ impl SignallableImpl for Signaller {
}
fn send_sdp(&self, _session_id: &str, sdp: &WebRTCSessionDescription) {
let nestri_ws = {
self.nestri_ws
.read()
.clone()
.expect("NestriWebSocket not set")
};
let sdp_message = MessageSDP {
base: MessageBase {
payload_type: "sdp".to_string(),
payload_type: "offer".to_string(),
latency: None,
},
sdp: RTCSessionDescription::offer(sdp.sdp().as_text().unwrap()).unwrap(),
};
if let Ok(encoded) = encode_message(&sdp_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
tracing::error!("Failed to send SDP message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send SDP message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode SDP message");
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
return;
};
if let Err(e) = stream_protocol.send_message(&sdp_message) {
tracing::error!("Failed to send SDP message: {:?}", e);
}
}
@@ -298,12 +245,6 @@ impl SignallableImpl for Signaller {
sdp_m_line_index: u32,
sdp_mid: Option<String>,
) {
let nestri_ws = {
self.nestri_ws
.read()
.clone()
.expect("NestriWebSocket not set")
};
let candidate_init = RTCIceCandidateInit {
candidate: candidate.to_string(),
sdp_mid,
@@ -312,18 +253,19 @@ impl SignallableImpl for Signaller {
};
let ice_message = MessageICE {
base: MessageBase {
payload_type: "ice".to_string(),
payload_type: "ice-candidate".to_string(),
latency: None,
},
candidate: candidate_init,
};
if let Ok(encoded) = encode_message(&ice_message) {
if let Err(e) = nestri_ws.send_message(encoded) {
tracing::error!("Failed to send ICE message: {:?}", e);
gst::error!(gst::CAT_DEFAULT, "Failed to send ICE message: {:?}", e);
}
} else {
gst::error!(gst::CAT_DEFAULT, "Failed to encode ICE message");
let Some(stream_protocol) = self.get_stream_protocol() else {
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
return;
};
if let Err(e) = stream_protocol.send_message(&ice_message) {
tracing::error!("Failed to send ICE candidate message: {:?}", e);
}
}

View File

@@ -1,4 +1,4 @@
use crate::websocket::NestriWebSocket;
use crate::p2p::p2p::NestriConnection;
use gst::glib;
use gst::subclass::prelude::*;
use gstrswebrtc::signaller::Signallable;
@@ -11,15 +11,20 @@ glib::wrapper! {
}
impl NestriSignaller {
pub fn new(nestri_ws: Arc<NestriWebSocket>, wayland_src: Arc<gst::Element>) -> Self {
pub async fn new(
room: String,
nestri_conn: NestriConnection,
wayland_src: Arc<gst::Element>,
) -> Result<Self, Box<dyn std::error::Error>> {
let obj: Self = glib::Object::new();
obj.imp().set_nestri_ws(nestri_ws);
obj.imp().set_stream_room(room);
obj.imp().set_nestri_connection(nestri_conn).await?;
obj.imp().set_wayland_src(wayland_src);
obj
Ok(obj)
}
}
impl Default for NestriSignaller {
fn default() -> Self {
panic!("Cannot create NestriSignaller without NestriWebSocket");
panic!("Cannot create NestriSignaller without NestriConnection and WaylandSrc");
}
}

View File

@@ -0,0 +1,3 @@
pub mod p2p;
pub mod p2p_safestream;
pub mod p2p_protocol_stream;

View File

@@ -0,0 +1,131 @@
use futures_util::StreamExt;
use libp2p::multiaddr::Protocol;
use libp2p::{
Multiaddr, PeerId, Swarm, identify, noise, ping,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::error::Error;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Clone)]
pub struct NestriConnection {
pub peer_id: PeerId,
pub control: libp2p_stream::Control,
}
#[derive(NetworkBehaviour)]
struct NestriBehaviour {
identify: identify::Behaviour,
ping: ping::Behaviour,
stream: libp2p_stream::Behaviour,
}
pub struct NestriP2P {
swarm: Arc<Mutex<Swarm<NestriBehaviour>>>,
}
impl NestriP2P {
pub async fn new() -> Result<Self, Box<dyn Error>> {
let swarm = Arc::new(Mutex::new(
libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_dns()?
.with_behaviour(|key| {
let identify_behaviour = identify::Behaviour::new(identify::Config::new(
"/ipfs/id/1.0.0".to_string(),
key.public(),
));
let ping_behaviour = ping::Behaviour::default();
let stream_behaviour = libp2p_stream::Behaviour::default();
Ok(NestriBehaviour {
identify: identify_behaviour,
ping: ping_behaviour,
stream: stream_behaviour,
})
})?
.build(),
));
// Spawn the swarm event loop
let swarm_clone = swarm.clone();
tokio::spawn(swarm_loop(swarm_clone));
{
let mut swarm_lock = swarm.lock().await;
swarm_lock.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; // IPv4 - TCP Raw
swarm_lock.listen_on("/ip6/::/tcp/0".parse()?)?; // IPv6 - TCP Raw
}
Ok(NestriP2P { swarm })
}
pub async fn connect(&self, conn_url: &str) -> Result<NestriConnection, Box<dyn Error>> {
let conn_addr: Multiaddr = conn_url.parse()?;
let mut swarm_lock = self.swarm.lock().await;
swarm_lock.dial(conn_addr.clone())?;
let Some(Protocol::P2p(peer_id)) = conn_addr.clone().iter().last() else {
return Err("Invalid connection URL: missing peer ID".into());
};
Ok(NestriConnection {
peer_id,
control: swarm_lock.behaviour().stream.new_control(),
})
}
}
async fn swarm_loop(swarm: Arc<Mutex<Swarm<NestriBehaviour>>>) {
loop {
let event = {
let mut swarm_lock = swarm.lock().await;
swarm_lock.select_next_some().await
};
match event {
SwarmEvent::NewListenAddr { address, .. } => {
tracing::info!("Listening on: '{}'", address);
}
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
tracing::info!("Connection established with peer: {}", peer_id);
}
SwarmEvent::ConnectionClosed { peer_id, cause, .. } => {
if let Some(err) = cause {
tracing::error!(
"Connection with peer {} closed due to error: {}",
peer_id,
err
);
} else {
tracing::info!("Connection with peer {} closed", peer_id);
}
}
SwarmEvent::IncomingConnection {
local_addr,
send_back_addr,
..
} => {
tracing::info!(
"Incoming connection from: {} (send back to: {})",
local_addr,
send_back_addr
);
}
SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => {
if let Some(peer_id) = peer_id {
tracing::error!("Failed to connect to peer {}: {}", peer_id, error);
} else {
tracing::error!("Failed to connect: {}", error);
}
}
_ => {}
}
}
}

View File

@@ -0,0 +1,149 @@
use crate::p2p::p2p::NestriConnection;
use crate::p2p::p2p_safestream::SafeStream;
use libp2p::StreamProtocol;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
// Cloneable callback type
pub type CallbackInner = dyn Fn(Vec<u8>) + Send + Sync + 'static;
pub struct Callback(Arc<CallbackInner>);
impl Callback {
pub fn new<F>(f: F) -> Self
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
Callback(Arc::new(f))
}
pub fn call(&self, data: Vec<u8>) {
self.0(data)
}
}
impl Clone for Callback {
fn clone(&self) -> Self {
Callback(Arc::clone(&self.0))
}
}
impl From<Box<CallbackInner>> for Callback {
fn from(boxed: Box<CallbackInner>) -> Self {
Callback(Arc::from(boxed))
}
}
/// NestriStreamProtocol manages the stream protocol for Nestri connections.
pub struct NestriStreamProtocol {
tx: mpsc::Sender<Vec<u8>>,
safe_stream: Arc<SafeStream>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
}
impl NestriStreamProtocol {
const NESTRI_PROTOCOL_STREAM_PUSH: StreamProtocol =
StreamProtocol::new("/nestri-relay/stream-push/1.0.0");
pub async fn new(
nestri_connection: NestriConnection,
) -> Result<Self, Box<dyn std::error::Error>> {
let mut nestri_connection = nestri_connection.clone();
let push_stream = match nestri_connection
.control
.open_stream(nestri_connection.peer_id, Self::NESTRI_PROTOCOL_STREAM_PUSH)
.await
{
Ok(stream) => stream,
Err(e) => {
return Err(Box::new(e));
}
};
let (tx, rx) = mpsc::channel(1000);
let sp = NestriStreamProtocol {
tx,
safe_stream: Arc::new(SafeStream::new(push_stream)),
callbacks: Arc::new(RwLock::new(HashMap::new())),
};
// Spawn the loops
sp.spawn_read_loop();
sp.spawn_write_loop(rx);
Ok(sp)
}
fn spawn_read_loop(&self) -> tokio::task::JoinHandle<()> {
let safe_stream = self.safe_stream.clone();
let callbacks = self.callbacks.clone();
tokio::spawn(async move {
loop {
let data = {
match safe_stream.receive_raw().await {
Ok(data) => data,
Err(e) => {
tracing::error!("Error receiving data: {}", e);
break; // Exit the loop on error
}
}
};
match serde_json::from_slice::<crate::messages::MessageBase>(&data) {
Ok(base_message) => {
let response_type = base_message.payload_type;
let callback = {
let callbacks_lock = callbacks.read().unwrap();
callbacks_lock.get(&response_type).cloned()
};
if let Some(callback) = callback {
// Call the registered callback with the raw data
callback.call(data);
} else {
tracing::warn!(
"No callback registered for response type: {}",
response_type
);
}
}
Err(e) => {
tracing::error!("Failed to decode message: {}", e);
}
}
}
})
}
fn spawn_write_loop(&self, mut rx: mpsc::Receiver<Vec<u8>>) -> tokio::task::JoinHandle<()> {
let safe_stream = self.safe_stream.clone();
tokio::spawn(async move {
loop {
// Wait for a message from the channel
if let Some(tx_data) = rx.recv().await {
if let Err(e) = safe_stream.send_raw(&tx_data).await {
tracing::error!("Error sending data: {:?}", e);
}
} else {
tracing::info!("Receiver closed, exiting write loop");
break;
}
}
})
}
pub fn send_message<M: serde::Serialize>(
&self,
message: &M,
) -> Result<(), Box<dyn std::error::Error>> {
let json_data = serde_json::to_vec(message)?;
self.tx.try_send(json_data)?;
Ok(())
}
/// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Callback::new(callback));
}
}

View File

@@ -0,0 +1,105 @@
use byteorder::{BigEndian, ByteOrder};
use futures_util::io::{ReadHalf, WriteHalf};
use futures_util::{AsyncReadExt, AsyncWriteExt};
use prost::Message;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::sync::Arc;
use tokio::sync::Mutex;
const MAX_SIZE: usize = 1024 * 1024; // 1MB
pub struct SafeStream {
stream_read: Arc<Mutex<ReadHalf<libp2p::Stream>>>,
stream_write: Arc<Mutex<WriteHalf<libp2p::Stream>>>,
}
impl SafeStream {
pub fn new(stream: libp2p::Stream) -> Self {
let (read, write) = stream.split();
SafeStream {
stream_read: Arc::new(Mutex::new(read)),
stream_write: Arc::new(Mutex::new(write)),
}
}
pub async fn send_json<T: Serialize>(
&self,
data: &T,
) -> Result<(), Box<dyn std::error::Error>> {
let json_data = serde_json::to_vec(data)?;
tracing::info!("Sending JSON");
let e = self.send_with_length_prefix(&json_data).await;
tracing::info!("Sent JSON");
e
}
pub async fn receive_json<T: DeserializeOwned>(&self) -> Result<T, Box<dyn std::error::Error>> {
let data = self.receive_with_length_prefix().await?;
let msg = serde_json::from_slice(&data)?;
Ok(msg)
}
pub async fn send_proto<M: Message>(&self, msg: &M) -> Result<(), Box<dyn std::error::Error>> {
let mut proto_data = Vec::new();
msg.encode(&mut proto_data)?;
self.send_with_length_prefix(&proto_data).await
}
pub async fn receive_proto<M: Message + Default>(
&self,
) -> Result<M, Box<dyn std::error::Error>> {
let data = self.receive_with_length_prefix().await?;
let msg = M::decode(&*data)?;
Ok(msg)
}
pub async fn send_raw(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
self.send_with_length_prefix(data).await
}
pub async fn receive_raw(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
self.receive_with_length_prefix().await
}
async fn send_with_length_prefix(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
if data.len() > MAX_SIZE {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Data exceeds maximum size",
)));
}
let mut stream_write = self.stream_write.lock().await;
// Write the 4-byte length prefix
let mut length_prefix = [0u8; 4];
BigEndian::write_u32(&mut length_prefix, data.len() as u32);
stream_write.write_all(&length_prefix).await?;
// Write the actual data
stream_write.write_all(data).await?;
stream_write.flush().await?;
Ok(())
}
async fn receive_with_length_prefix(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let mut stream_read = self.stream_read.lock().await;
// Read the 4-byte length prefix
let mut length_prefix = [0u8; 4];
stream_read.read_exact(&mut length_prefix).await?;
let length = BigEndian::read_u32(&length_prefix) as usize;
if length > MAX_SIZE {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Data exceeds maximum size",
)));
}
// Read the actual data
let mut buffer = vec![0; length];
stream_read.read_exact(&mut buffer).await?;
Ok(buffer)
}
}

View File

@@ -1,225 +0,0 @@
use crate::messages::decode_message;
use futures_util::StreamExt;
use futures_util::sink::SinkExt;
use futures_util::stream::{SplitSink, SplitStream};
use std::collections::HashMap;
use std::error::Error;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, Notify, mpsc};
use tokio::time::sleep;
use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
type Callback = Box<dyn Fn(String) + Send + Sync>;
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
#[derive(Clone)]
pub struct NestriWebSocket {
ws_url: String,
reader: Arc<Mutex<Option<WSRead>>>,
writer: Arc<Mutex<Option<WSWrite>>>,
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
message_tx: mpsc::UnboundedSender<String>,
reconnected_notify: Arc<Notify>,
}
impl NestriWebSocket {
pub async fn new(ws_url: String) -> Result<NestriWebSocket, Box<dyn Error>> {
// Attempt to connect to the WebSocket
let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap();
// Split the stream into read and write halves
let (write, read) = ws_stream.split();
// Create the message channel
let (message_tx, message_rx) = mpsc::unbounded_channel();
let ws = NestriWebSocket {
ws_url,
reader: Arc::new(Mutex::new(Some(read))),
writer: Arc::new(Mutex::new(Some(write))),
callbacks: Arc::new(RwLock::new(HashMap::new())),
message_tx: message_tx.clone(),
reconnected_notify: Arc::new(Notify::new()),
};
// Spawn the read loop
ws.spawn_read_loop();
// Spawn the write loop
ws.spawn_write_loop(message_rx);
Ok(ws)
}
async fn do_connect(
ws_url: &str,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Box<dyn Error + Send + Sync>> {
loop {
match connect_async(ws_url).await {
Ok((ws_stream, _)) => {
return Ok(ws_stream);
}
Err(e) => {
tracing::error!("Failed to connect to WebSocket, retrying: {:?}", e);
sleep(Duration::from_secs(3)).await; // Wait before retrying
}
}
}
}
// Handles message -> callback calls and reconnects on error/disconnect
fn spawn_read_loop(&self) {
let reader = self.reader.clone();
let callbacks = self.callbacks.clone();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
// Lock the reader to get the WSRead, then drop the lock
let ws_read_option = {
let mut reader_lock = reader.lock().await;
reader_lock.take()
};
let mut ws_read = match ws_read_option {
Some(ws_read) => ws_read,
None => {
tracing::error!("Reader is None, cannot proceed");
return;
}
};
while let Some(message_result) = ws_read.next().await {
match message_result {
Ok(message) => {
let data = message
.into_text()
.expect("failed to turn message into text");
let base_message = match decode_message(data.to_string()) {
Ok(base_message) => base_message,
Err(e) => {
tracing::error!("Failed to decode message: {:?}", e);
continue;
}
};
let callbacks_lock = callbacks.read().unwrap();
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
let data = data.clone();
callback(data.to_string());
}
}
Err(e) => {
tracing::error!(
"Error receiving message: {:?}, reconnecting in 3 seconds...",
e
);
sleep(Duration::from_secs(3)).await;
self_clone.reconnect().await.unwrap();
break; // Break the inner loop to get a new ws_read
}
}
}
// After reconnection, the loop continues, and we acquire a new ws_read
}
});
}
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<String>) {
let writer = self.writer.clone();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
// Wait for a message from the channel
if let Some(message) = message_rx.recv().await {
loop {
// Acquire the writer lock
let mut writer_lock = writer.lock().await;
if let Some(writer) = writer_lock.as_mut() {
// Try to send the message over the WebSocket
match writer
.send(Message::Text(Utf8Bytes::from(message.clone())))
.await
{
Ok(_) => {
// Message sent successfully
break;
}
Err(e) => {
tracing::error!("Error sending message: {:?}", e);
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
tracing::error!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
tracing::error!("Writer is None, cannot send message");
// Attempt to reconnect
if let Err(e) = self_clone.reconnect().await {
tracing::error!("Error during reconnection: {:?}", e);
// Wait before retrying
sleep(Duration::from_secs(3)).await;
continue;
}
}
}
} else {
break;
}
}
});
}
async fn reconnect(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
loop {
match NestriWebSocket::do_connect(&self.ws_url).await {
Ok(ws_stream) => {
let (write, read) = ws_stream.split();
{
let mut writer_lock = self.writer.lock().await;
*writer_lock = Some(write);
}
{
let mut reader_lock = self.reader.lock().await;
*reader_lock = Some(read);
}
// Notify subscribers of successful reconnection
self.reconnected_notify.notify_waiters();
return Ok(());
}
Err(e) => {
tracing::error!("Failed to reconnect to WebSocket: {:?}", e);
sleep(Duration::from_secs(3)).await; // Wait before retrying
}
}
}
}
/// Send a message through the WebSocket
pub fn send_message(&self, message: String) -> Result<(), Box<dyn Error>> {
self.message_tx
.send(message)
.map_err(|e| format!("Failed to send message: {:?}", e).into())
}
/// Register a callback for a specific response type
pub fn register_callback<F>(&self, response_type: &str, callback: F)
where
F: Fn(String) + Send + Sync + 'static,
{
let mut callbacks_lock = self.callbacks.write().unwrap();
callbacks_lock.insert(response_type.to_string(), Box::new(callback));
}
/// Subscribe to event for reconnection
pub fn subscribe_reconnected(&self) -> Arc<Notify> {
self.reconnected_notify.clone()
}
}