mirror of
https://github.com/nestriness/nestri.git
synced 2025-12-12 08:45:38 +02:00
⭐ feat: Migrate from WebSocket to libp2p for peer-to-peer connectivity (#286)
## Description Whew, some stuff is still not re-implemented, but it's working! Rabbit's gonna explode with the amount of changes I reckon 😅 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a peer-to-peer relay system using libp2p with enhanced stream forwarding, room state synchronization, and mDNS peer discovery. - Added decentralized room and participant management, metrics publishing, and safe, size-limited, concurrent message streaming with robust framing and callback dispatching. - Implemented asynchronous, callback-driven message handling over custom libp2p streams replacing WebSocket signaling. - **Improvements** - Migrated signaling and stream protocols from WebSocket to libp2p, improving reliability and scalability. - Simplified configuration and environment variables, removing deprecated flags and adding persistent data support. - Enhanced logging, error handling, and connection management for better observability and robustness. - Refined RTP header extension registration and NAT IP handling for improved WebRTC performance. - **Bug Fixes** - Improved ICE candidate buffering and SDP negotiation in WebRTC connections. - Fixed NAT IP and UDP port range configuration issues. - **Refactor** - Modularized codebase, reorganized relay and server logic, and removed deprecated WebSocket-based components. - Streamlined message structures, removed obsolete enums and message types, and simplified SafeMap concurrency. - Replaced WebSocket signaling with libp2p stream protocols in server and relay components. - **Chores** - Updated and cleaned dependencies across Go, Rust, and JavaScript packages. - Added `.gitignore` for persistent data directory in relay package. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: DatCaptainHorse <DatCaptainHorse@users.noreply.github.com> Co-authored-by: Philipp Neumann <3daquawolf@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e67a8d2b32
commit
6e82eff9e2
1763
packages/server/Cargo.lock
generated
1763
packages/server/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -19,14 +19,14 @@ webrtc = "0.13"
|
||||
regex = "1.11"
|
||||
rand = "0.9"
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
tokio-tungstenite = { version = "0.26", features = ["native-tls"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
chrono = "0.4"
|
||||
futures-util = "0.3"
|
||||
num-derive = "0.4"
|
||||
num-traits = "0.2"
|
||||
prost = "0.13"
|
||||
prost-types = "0.13"
|
||||
parking_lot = "0.12"
|
||||
atomic_refcell = "0.1"
|
||||
atomic_refcell = "0.1"
|
||||
byteorder = "1.5"
|
||||
libp2p = { version = "0.55", features = ["identify", "dns", "tcp", "noise", "ping", "tokio", "serde", "yamux", "macros"] }
|
||||
libp2p-stream = "0.3.0-alpha"
|
||||
@@ -4,14 +4,14 @@ mod gpu;
|
||||
mod latency;
|
||||
mod messages;
|
||||
mod nestrisink;
|
||||
mod p2p;
|
||||
mod proto;
|
||||
mod websocket;
|
||||
|
||||
use crate::args::encoding_args;
|
||||
use crate::enc_helper::EncoderType;
|
||||
use crate::gpu::GPUVendor;
|
||||
use crate::nestrisink::NestriSignaller;
|
||||
use crate::websocket::NestriWebSocket;
|
||||
use crate::p2p::p2p::NestriP2P;
|
||||
use futures_util::StreamExt;
|
||||
use gst::prelude::*;
|
||||
use gstrswebrtc::signaller::Signallable;
|
||||
@@ -19,6 +19,8 @@ use gstrswebrtc::webrtcsink::BaseWebRTCSink;
|
||||
use std::error::Error;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
// Handles gathering GPU information and selecting the most suitable GPU
|
||||
fn handle_gpus(args: &args::Args) -> Result<gpu::GPUInfo, Box<dyn Error>> {
|
||||
@@ -165,32 +167,29 @@ fn handle_encoder_audio(args: &args::Args) -> String {
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
// Parse command line arguments
|
||||
let mut args = args::Args::new();
|
||||
if args.app.verbose {
|
||||
// Make sure tracing has INFO level
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.init();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env()?,
|
||||
)
|
||||
.init();
|
||||
|
||||
if args.app.verbose {
|
||||
args.debug_print();
|
||||
} else {
|
||||
tracing_subscriber::fmt::init();
|
||||
}
|
||||
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install ring crypto provider");
|
||||
|
||||
// Begin connection attempt to the relay WebSocket endpoint
|
||||
// replace any http/https with ws/wss
|
||||
let replaced_relay_url = args
|
||||
.app
|
||||
.relay_url
|
||||
.replace("http://", "ws://")
|
||||
.replace("https://", "wss://");
|
||||
let ws_url = format!("{}/api/ws/{}", replaced_relay_url, args.app.room,);
|
||||
// Get relay URL from arguments
|
||||
let relay_url = args.app.relay_url.trim();
|
||||
|
||||
// Setup our websocket
|
||||
let nestri_ws = Arc::new(NestriWebSocket::new(ws_url).await?);
|
||||
// Initialize libp2p (logically the sink should handle the connection to be independent)
|
||||
let nestri_p2p = Arc::new(NestriP2P::new().await?);
|
||||
let p2p_conn = nestri_p2p.connect(relay_url).await?;
|
||||
|
||||
gst::init()?;
|
||||
gstrswebrtc::plugin_register_static()?;
|
||||
@@ -328,7 +327,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
/* Output */
|
||||
// WebRTC sink Element
|
||||
let signaller = NestriSignaller::new(nestri_ws.clone(), video_source.clone());
|
||||
let signaller =
|
||||
NestriSignaller::new(args.app.room, p2p_conn.clone(), video_source.clone()).await?;
|
||||
let webrtcsink = BaseWebRTCSink::with_signaller(Signallable::from(signaller.clone()));
|
||||
webrtcsink.set_property_from_str("stun-server", "stun://stun.l.google.com:19302");
|
||||
webrtcsink.set_property_from_str("congestion-control", "disabled");
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
use crate::latency::LatencyTracker;
|
||||
use num_derive::{FromPrimitive, ToPrimitive};
|
||||
use num_traits::{FromPrimitive, ToPrimitive};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
|
||||
@@ -12,6 +9,13 @@ pub struct MessageBase {
|
||||
pub latency: Option<LatencyTracker>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageRaw {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageLog {
|
||||
#[serde(flatten)]
|
||||
@@ -44,76 +48,3 @@ pub struct MessageSDP {
|
||||
pub base: MessageBase,
|
||||
pub sdp: RTCSessionDescription,
|
||||
}
|
||||
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "i32", into = "i32")]
|
||||
pub enum JoinerType {
|
||||
JoinerNode = 0,
|
||||
JoinerClient = 1,
|
||||
}
|
||||
impl TryFrom<i32> for JoinerType {
|
||||
type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
JoinerType::from_i32(value).ok_or("Invalid value for JoinerType")
|
||||
}
|
||||
}
|
||||
impl From<JoinerType> for i32 {
|
||||
fn from(joiner_type: JoinerType) -> Self {
|
||||
joiner_type.to_i32().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageJoin {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub joiner_type: JoinerType,
|
||||
}
|
||||
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, FromPrimitive, ToPrimitive, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "i32", into = "i32")]
|
||||
pub enum AnswerType {
|
||||
AnswerOffline = 0,
|
||||
AnswerInUse = 1,
|
||||
AnswerOK = 2,
|
||||
}
|
||||
impl TryFrom<i32> for AnswerType {
|
||||
type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
AnswerType::from_i32(value).ok_or("Invalid value for AnswerType")
|
||||
}
|
||||
}
|
||||
impl From<AnswerType> for i32 {
|
||||
fn from(answer_type: AnswerType) -> Self {
|
||||
answer_type.to_i32().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct MessageAnswer {
|
||||
#[serde(flatten)]
|
||||
pub base: MessageBase,
|
||||
pub answer_type: AnswerType,
|
||||
}
|
||||
|
||||
pub fn encode_message<T: Serialize>(message: &T) -> Result<String, Box<dyn Error>> {
|
||||
// Serialize the message to JSON
|
||||
let json = serde_json::to_string(message)?;
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
pub fn decode_message(data: String) -> Result<MessageBase, Box<dyn Error + Send + Sync>> {
|
||||
let base_message: MessageBase = serde_json::from_str(&data)?;
|
||||
Ok(base_message)
|
||||
}
|
||||
|
||||
pub fn decode_message_as<T: for<'de> Deserialize<'de>>(
|
||||
data: String,
|
||||
) -> Result<T, Box<dyn Error + Send + Sync>> {
|
||||
let message: T = serde_json::from_str(&data)?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
use crate::messages::{
|
||||
AnswerType, JoinerType, MessageAnswer, MessageBase, MessageICE, MessageJoin, MessageSDP,
|
||||
decode_message_as, encode_message,
|
||||
};
|
||||
use crate::messages::{MessageBase, MessageICE, MessageRaw, MessageSDP};
|
||||
use crate::p2p::p2p::NestriConnection;
|
||||
use crate::p2p::p2p_protocol_stream::NestriStreamProtocol;
|
||||
use crate::proto::proto::proto_input::InputType::{
|
||||
KeyDown, KeyUp, MouseKeyDown, MouseKeyUp, MouseMove, MouseMoveAbs, MouseWheel,
|
||||
};
|
||||
use crate::proto::proto::{ProtoInput, ProtoMessageInput};
|
||||
use crate::websocket::NestriWebSocket;
|
||||
use atomic_refcell::AtomicRefCell;
|
||||
use glib::subclass::prelude::*;
|
||||
use gst::glib;
|
||||
@@ -20,22 +18,37 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
|
||||
pub struct Signaller {
|
||||
nestri_ws: PLRwLock<Option<Arc<NestriWebSocket>>>,
|
||||
stream_room: PLRwLock<Option<String>>,
|
||||
stream_protocol: PLRwLock<Option<Arc<NestriStreamProtocol>>>,
|
||||
wayland_src: PLRwLock<Option<Arc<gst::Element>>>,
|
||||
data_channel: AtomicRefCell<Option<gst_webrtc::WebRTCDataChannel>>,
|
||||
}
|
||||
impl Default for Signaller {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
nestri_ws: PLRwLock::new(None),
|
||||
stream_room: PLRwLock::new(None),
|
||||
stream_protocol: PLRwLock::new(None),
|
||||
wayland_src: PLRwLock::new(None),
|
||||
data_channel: AtomicRefCell::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Signaller {
|
||||
pub fn set_nestri_ws(&self, nestri_ws: Arc<NestriWebSocket>) {
|
||||
*self.nestri_ws.write() = Some(nestri_ws);
|
||||
pub async fn set_nestri_connection(
|
||||
&self,
|
||||
nestri_conn: NestriConnection,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let stream_protocol = NestriStreamProtocol::new(nestri_conn).await?;
|
||||
*self.stream_protocol.write() = Some(Arc::new(stream_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_stream_room(&self, room: String) {
|
||||
*self.stream_room.write() = Some(room);
|
||||
}
|
||||
|
||||
fn get_stream_protocol(&self) -> Option<Arc<NestriStreamProtocol>> {
|
||||
self.stream_protocol.read().clone()
|
||||
}
|
||||
|
||||
pub fn set_wayland_src(&self, wayland_src: Arc<gst::Element>) {
|
||||
@@ -58,16 +71,14 @@ impl Signaller {
|
||||
|
||||
/// Helper method to clean things up
|
||||
fn register_callbacks(&self) {
|
||||
let nestri_ws = {
|
||||
self.nestri_ws
|
||||
.read()
|
||||
.clone()
|
||||
.expect("NestriWebSocket not set")
|
||||
let Some(stream_protocol) = self.get_stream_protocol() else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
|
||||
return;
|
||||
};
|
||||
{
|
||||
let self_obj = self.obj().clone();
|
||||
let _ = nestri_ws.register_callback("sdp", move |data| {
|
||||
if let Ok(message) = decode_message_as::<MessageSDP>(data) {
|
||||
stream_protocol.register_callback("answer", move |data| {
|
||||
if let Ok(message) = serde_json::from_slice::<MessageSDP>(&data) {
|
||||
let sdp =
|
||||
gst_sdp::SDPMessage::parse_buffer(message.sdp.sdp.as_bytes()).unwrap();
|
||||
let answer = WebRTCSessionDescription::new(WebRTCSDPType::Answer, sdp);
|
||||
@@ -82,12 +93,11 @@ impl Signaller {
|
||||
}
|
||||
{
|
||||
let self_obj = self.obj().clone();
|
||||
let _ = nestri_ws.register_callback("ice", move |data| {
|
||||
if let Ok(message) = decode_message_as::<MessageICE>(data) {
|
||||
stream_protocol.register_callback("ice-candidate", move |data| {
|
||||
if let Ok(message) = serde_json::from_slice::<MessageICE>(&data) {
|
||||
let candidate = message.candidate;
|
||||
let sdp_m_line_index = candidate.sdp_mline_index.unwrap_or(0) as u32;
|
||||
let sdp_mid = candidate.sdp_mid;
|
||||
|
||||
self_obj.emit_by_name::<()>(
|
||||
"handle-ice",
|
||||
&[
|
||||
@@ -104,29 +114,28 @@ impl Signaller {
|
||||
}
|
||||
{
|
||||
let self_obj = self.obj().clone();
|
||||
let _ = nestri_ws.register_callback("answer", move |data| {
|
||||
if let Ok(answer) = decode_message_as::<MessageAnswer>(data) {
|
||||
gst::info!(gst::CAT_DEFAULT, "Received answer: {:?}", answer);
|
||||
match answer.answer_type {
|
||||
AnswerType::AnswerOK => {
|
||||
gst::info!(gst::CAT_DEFAULT, "Received OK answer");
|
||||
// Send our SDP offer
|
||||
self_obj.emit_by_name::<()>(
|
||||
"session-requested",
|
||||
&[
|
||||
&"unique-session-id",
|
||||
&"consumer-identifier",
|
||||
&None::<WebRTCSessionDescription>,
|
||||
],
|
||||
);
|
||||
}
|
||||
AnswerType::AnswerInUse => {
|
||||
gst::error!(gst::CAT_DEFAULT, "Room is in use by another node");
|
||||
}
|
||||
AnswerType::AnswerOffline => {
|
||||
gst::warning!(gst::CAT_DEFAULT, "Room is offline");
|
||||
}
|
||||
stream_protocol.register_callback("push-stream-ok", move |data| {
|
||||
if let Ok(answer) = serde_json::from_slice::<MessageRaw>(&data) {
|
||||
// Decode room name string
|
||||
if let Some(room_name) = answer.data.as_str() {
|
||||
gst::info!(
|
||||
gst::CAT_DEFAULT,
|
||||
"Received OK answer for room: {}",
|
||||
room_name
|
||||
);
|
||||
} else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to decode room name from answer");
|
||||
}
|
||||
|
||||
// Send our SDP offer
|
||||
self_obj.emit_by_name::<()>(
|
||||
"session-requested",
|
||||
&[
|
||||
&"unique-session-id",
|
||||
&"consumer-identifier",
|
||||
&None::<WebRTCSessionDescription>,
|
||||
],
|
||||
);
|
||||
} else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to decode answer");
|
||||
}
|
||||
@@ -177,89 +186,32 @@ impl SignallableImpl for Signaller {
|
||||
fn start(&self) {
|
||||
gst::info!(gst::CAT_DEFAULT, "Signaller started");
|
||||
|
||||
// Get WebSocket connection
|
||||
let nestri_ws = {
|
||||
self.nestri_ws
|
||||
.read()
|
||||
.clone()
|
||||
.expect("NestriWebSocket not set")
|
||||
};
|
||||
|
||||
// Register message callbacks
|
||||
self.register_callbacks();
|
||||
|
||||
// Subscribe to reconnection notifications
|
||||
let reconnected_notify = nestri_ws.subscribe_reconnected();
|
||||
// TODO: Re-implement reconnection handling
|
||||
|
||||
// Clone necessary references
|
||||
let self_clone = self.obj().clone();
|
||||
let nestri_ws_clone = nestri_ws.clone();
|
||||
let Some(stream_room) = self.stream_room.read().clone() else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Stream room not set");
|
||||
return;
|
||||
};
|
||||
|
||||
// Spawn a task to handle actions upon reconnection
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// Wait for a reconnection notification
|
||||
reconnected_notify.notified().await;
|
||||
|
||||
tracing::warn!("Reconnected to relay, re-negotiating...");
|
||||
gst::warning!(gst::CAT_DEFAULT, "Reconnected to relay, re-negotiating...");
|
||||
|
||||
// Emit "session-ended" first to make sure the element is cleaned up
|
||||
self_clone.emit_by_name::<bool>("session-ended", &[&"unique-session-id"]);
|
||||
|
||||
// Send a new join message
|
||||
let join_msg = MessageJoin {
|
||||
base: MessageBase {
|
||||
payload_type: "join".to_string(),
|
||||
latency: None,
|
||||
},
|
||||
joiner_type: JoinerType::JoinerNode,
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&join_msg) {
|
||||
if let Err(e) = nestri_ws_clone.send_message(encoded) {
|
||||
gst::error!(
|
||||
gst::CAT_DEFAULT,
|
||||
"Failed to send join message after reconnection: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
} else {
|
||||
gst::error!(
|
||||
gst::CAT_DEFAULT,
|
||||
"Failed to encode join message after reconnection"
|
||||
);
|
||||
}
|
||||
|
||||
// If we need to interact with GStreamer or GLib, schedule it on the main thread
|
||||
let self_clone_for_main = self_clone.clone();
|
||||
glib::MainContext::default().invoke(move || {
|
||||
// Emit the "session-requested" signal
|
||||
self_clone_for_main.emit_by_name::<()>(
|
||||
"session-requested",
|
||||
&[
|
||||
&"unique-session-id",
|
||||
&"consumer-identifier",
|
||||
&None::<WebRTCSessionDescription>,
|
||||
],
|
||||
);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let join_msg = MessageJoin {
|
||||
let push_msg = MessageRaw {
|
||||
base: MessageBase {
|
||||
payload_type: "join".to_string(),
|
||||
payload_type: "push-stream-room".to_string(),
|
||||
latency: None,
|
||||
},
|
||||
joiner_type: JoinerType::JoinerNode,
|
||||
data: serde_json::Value::from(stream_room),
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&join_msg) {
|
||||
if let Err(e) = nestri_ws.send_message(encoded) {
|
||||
tracing::error!("Failed to send join message: {:?}", e);
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to send join message: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to encode join message");
|
||||
|
||||
let Some(stream_protocol) = self.get_stream_protocol() else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(e) = stream_protocol.send_message(&push_msg) {
|
||||
tracing::error!("Failed to send push stream room message: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,26 +220,21 @@ impl SignallableImpl for Signaller {
|
||||
}
|
||||
|
||||
fn send_sdp(&self, _session_id: &str, sdp: &WebRTCSessionDescription) {
|
||||
let nestri_ws = {
|
||||
self.nestri_ws
|
||||
.read()
|
||||
.clone()
|
||||
.expect("NestriWebSocket not set")
|
||||
};
|
||||
let sdp_message = MessageSDP {
|
||||
base: MessageBase {
|
||||
payload_type: "sdp".to_string(),
|
||||
payload_type: "offer".to_string(),
|
||||
latency: None,
|
||||
},
|
||||
sdp: RTCSessionDescription::offer(sdp.sdp().as_text().unwrap()).unwrap(),
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&sdp_message) {
|
||||
if let Err(e) = nestri_ws.send_message(encoded) {
|
||||
tracing::error!("Failed to send SDP message: {:?}", e);
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to send SDP message: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to encode SDP message");
|
||||
|
||||
let Some(stream_protocol) = self.get_stream_protocol() else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(e) = stream_protocol.send_message(&sdp_message) {
|
||||
tracing::error!("Failed to send SDP message: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,12 +245,6 @@ impl SignallableImpl for Signaller {
|
||||
sdp_m_line_index: u32,
|
||||
sdp_mid: Option<String>,
|
||||
) {
|
||||
let nestri_ws = {
|
||||
self.nestri_ws
|
||||
.read()
|
||||
.clone()
|
||||
.expect("NestriWebSocket not set")
|
||||
};
|
||||
let candidate_init = RTCIceCandidateInit {
|
||||
candidate: candidate.to_string(),
|
||||
sdp_mid,
|
||||
@@ -312,18 +253,19 @@ impl SignallableImpl for Signaller {
|
||||
};
|
||||
let ice_message = MessageICE {
|
||||
base: MessageBase {
|
||||
payload_type: "ice".to_string(),
|
||||
payload_type: "ice-candidate".to_string(),
|
||||
latency: None,
|
||||
},
|
||||
candidate: candidate_init,
|
||||
};
|
||||
if let Ok(encoded) = encode_message(&ice_message) {
|
||||
if let Err(e) = nestri_ws.send_message(encoded) {
|
||||
tracing::error!("Failed to send ICE message: {:?}", e);
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to send ICE message: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Failed to encode ICE message");
|
||||
|
||||
let Some(stream_protocol) = self.get_stream_protocol() else {
|
||||
gst::error!(gst::CAT_DEFAULT, "Stream protocol not set");
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(e) = stream_protocol.send_message(&ice_message) {
|
||||
tracing::error!("Failed to send ICE candidate message: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::websocket::NestriWebSocket;
|
||||
use crate::p2p::p2p::NestriConnection;
|
||||
use gst::glib;
|
||||
use gst::subclass::prelude::*;
|
||||
use gstrswebrtc::signaller::Signallable;
|
||||
@@ -11,15 +11,20 @@ glib::wrapper! {
|
||||
}
|
||||
|
||||
impl NestriSignaller {
|
||||
pub fn new(nestri_ws: Arc<NestriWebSocket>, wayland_src: Arc<gst::Element>) -> Self {
|
||||
pub async fn new(
|
||||
room: String,
|
||||
nestri_conn: NestriConnection,
|
||||
wayland_src: Arc<gst::Element>,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let obj: Self = glib::Object::new();
|
||||
obj.imp().set_nestri_ws(nestri_ws);
|
||||
obj.imp().set_stream_room(room);
|
||||
obj.imp().set_nestri_connection(nestri_conn).await?;
|
||||
obj.imp().set_wayland_src(wayland_src);
|
||||
obj
|
||||
Ok(obj)
|
||||
}
|
||||
}
|
||||
impl Default for NestriSignaller {
|
||||
fn default() -> Self {
|
||||
panic!("Cannot create NestriSignaller without NestriWebSocket");
|
||||
panic!("Cannot create NestriSignaller without NestriConnection and WaylandSrc");
|
||||
}
|
||||
}
|
||||
|
||||
3
packages/server/src/p2p.rs
Normal file
3
packages/server/src/p2p.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod p2p;
|
||||
pub mod p2p_safestream;
|
||||
pub mod p2p_protocol_stream;
|
||||
131
packages/server/src/p2p/p2p.rs
Normal file
131
packages/server/src/p2p/p2p.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use futures_util::StreamExt;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
use libp2p::{
|
||||
Multiaddr, PeerId, Swarm, identify, noise, ping,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NestriConnection {
|
||||
pub peer_id: PeerId,
|
||||
pub control: libp2p_stream::Control,
|
||||
}
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct NestriBehaviour {
|
||||
identify: identify::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
stream: libp2p_stream::Behaviour,
|
||||
}
|
||||
|
||||
pub struct NestriP2P {
|
||||
swarm: Arc<Mutex<Swarm<NestriBehaviour>>>,
|
||||
}
|
||||
impl NestriP2P {
|
||||
pub async fn new() -> Result<Self, Box<dyn Error>> {
|
||||
let swarm = Arc::new(Mutex::new(
|
||||
libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_dns()?
|
||||
.with_behaviour(|key| {
|
||||
let identify_behaviour = identify::Behaviour::new(identify::Config::new(
|
||||
"/ipfs/id/1.0.0".to_string(),
|
||||
key.public(),
|
||||
));
|
||||
let ping_behaviour = ping::Behaviour::default();
|
||||
let stream_behaviour = libp2p_stream::Behaviour::default();
|
||||
|
||||
Ok(NestriBehaviour {
|
||||
identify: identify_behaviour,
|
||||
ping: ping_behaviour,
|
||||
stream: stream_behaviour,
|
||||
})
|
||||
})?
|
||||
.build(),
|
||||
));
|
||||
|
||||
// Spawn the swarm event loop
|
||||
let swarm_clone = swarm.clone();
|
||||
tokio::spawn(swarm_loop(swarm_clone));
|
||||
|
||||
{
|
||||
let mut swarm_lock = swarm.lock().await;
|
||||
swarm_lock.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; // IPv4 - TCP Raw
|
||||
swarm_lock.listen_on("/ip6/::/tcp/0".parse()?)?; // IPv6 - TCP Raw
|
||||
}
|
||||
|
||||
Ok(NestriP2P { swarm })
|
||||
}
|
||||
|
||||
pub async fn connect(&self, conn_url: &str) -> Result<NestriConnection, Box<dyn Error>> {
|
||||
let conn_addr: Multiaddr = conn_url.parse()?;
|
||||
|
||||
let mut swarm_lock = self.swarm.lock().await;
|
||||
swarm_lock.dial(conn_addr.clone())?;
|
||||
|
||||
let Some(Protocol::P2p(peer_id)) = conn_addr.clone().iter().last() else {
|
||||
return Err("Invalid connection URL: missing peer ID".into());
|
||||
};
|
||||
|
||||
Ok(NestriConnection {
|
||||
peer_id,
|
||||
control: swarm_lock.behaviour().stream.new_control(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn swarm_loop(swarm: Arc<Mutex<Swarm<NestriBehaviour>>>) {
|
||||
loop {
|
||||
let event = {
|
||||
let mut swarm_lock = swarm.lock().await;
|
||||
swarm_lock.select_next_some().await
|
||||
};
|
||||
match event {
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
tracing::info!("Listening on: '{}'", address);
|
||||
}
|
||||
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
|
||||
tracing::info!("Connection established with peer: {}", peer_id);
|
||||
}
|
||||
SwarmEvent::ConnectionClosed { peer_id, cause, .. } => {
|
||||
if let Some(err) = cause {
|
||||
tracing::error!(
|
||||
"Connection with peer {} closed due to error: {}",
|
||||
peer_id,
|
||||
err
|
||||
);
|
||||
} else {
|
||||
tracing::info!("Connection with peer {} closed", peer_id);
|
||||
}
|
||||
}
|
||||
SwarmEvent::IncomingConnection {
|
||||
local_addr,
|
||||
send_back_addr,
|
||||
..
|
||||
} => {
|
||||
tracing::info!(
|
||||
"Incoming connection from: {} (send back to: {})",
|
||||
local_addr,
|
||||
send_back_addr
|
||||
);
|
||||
}
|
||||
SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => {
|
||||
if let Some(peer_id) = peer_id {
|
||||
tracing::error!("Failed to connect to peer {}: {}", peer_id, error);
|
||||
} else {
|
||||
tracing::error!("Failed to connect: {}", error);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
149
packages/server/src/p2p/p2p_protocol_stream.rs
Normal file
149
packages/server/src/p2p/p2p_protocol_stream.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
use crate::p2p::p2p::NestriConnection;
|
||||
use crate::p2p::p2p_safestream::SafeStream;
|
||||
use libp2p::StreamProtocol;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
// Cloneable callback type
|
||||
pub type CallbackInner = dyn Fn(Vec<u8>) + Send + Sync + 'static;
|
||||
pub struct Callback(Arc<CallbackInner>);
|
||||
impl Callback {
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: Fn(Vec<u8>) + Send + Sync + 'static,
|
||||
{
|
||||
Callback(Arc::new(f))
|
||||
}
|
||||
|
||||
pub fn call(&self, data: Vec<u8>) {
|
||||
self.0(data)
|
||||
}
|
||||
}
|
||||
impl Clone for Callback {
|
||||
fn clone(&self) -> Self {
|
||||
Callback(Arc::clone(&self.0))
|
||||
}
|
||||
}
|
||||
impl From<Box<CallbackInner>> for Callback {
|
||||
fn from(boxed: Box<CallbackInner>) -> Self {
|
||||
Callback(Arc::from(boxed))
|
||||
}
|
||||
}
|
||||
|
||||
/// NestriStreamProtocol manages the stream protocol for Nestri connections.
|
||||
pub struct NestriStreamProtocol {
|
||||
tx: mpsc::Sender<Vec<u8>>,
|
||||
safe_stream: Arc<SafeStream>,
|
||||
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
|
||||
}
|
||||
impl NestriStreamProtocol {
|
||||
const NESTRI_PROTOCOL_STREAM_PUSH: StreamProtocol =
|
||||
StreamProtocol::new("/nestri-relay/stream-push/1.0.0");
|
||||
|
||||
pub async fn new(
|
||||
nestri_connection: NestriConnection,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut nestri_connection = nestri_connection.clone();
|
||||
let push_stream = match nestri_connection
|
||||
.control
|
||||
.open_stream(nestri_connection.peer_id, Self::NESTRI_PROTOCOL_STREAM_PUSH)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = mpsc::channel(1000);
|
||||
|
||||
let sp = NestriStreamProtocol {
|
||||
tx,
|
||||
safe_stream: Arc::new(SafeStream::new(push_stream)),
|
||||
callbacks: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
|
||||
// Spawn the loops
|
||||
sp.spawn_read_loop();
|
||||
sp.spawn_write_loop(rx);
|
||||
|
||||
Ok(sp)
|
||||
}
|
||||
|
||||
fn spawn_read_loop(&self) -> tokio::task::JoinHandle<()> {
|
||||
let safe_stream = self.safe_stream.clone();
|
||||
let callbacks = self.callbacks.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let data = {
|
||||
match safe_stream.receive_raw().await {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Error receiving data: {}", e);
|
||||
break; // Exit the loop on error
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_slice::<crate::messages::MessageBase>(&data) {
|
||||
Ok(base_message) => {
|
||||
let response_type = base_message.payload_type;
|
||||
let callback = {
|
||||
let callbacks_lock = callbacks.read().unwrap();
|
||||
callbacks_lock.get(&response_type).cloned()
|
||||
};
|
||||
|
||||
if let Some(callback) = callback {
|
||||
// Call the registered callback with the raw data
|
||||
callback.call(data);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"No callback registered for response type: {}",
|
||||
response_type
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to decode message: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn spawn_write_loop(&self, mut rx: mpsc::Receiver<Vec<u8>>) -> tokio::task::JoinHandle<()> {
|
||||
let safe_stream = self.safe_stream.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// Wait for a message from the channel
|
||||
if let Some(tx_data) = rx.recv().await {
|
||||
if let Err(e) = safe_stream.send_raw(&tx_data).await {
|
||||
tracing::error!("Error sending data: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
tracing::info!("Receiver closed, exiting write loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn send_message<M: serde::Serialize>(
|
||||
&self,
|
||||
message: &M,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let json_data = serde_json::to_vec(message)?;
|
||||
self.tx.try_send(json_data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register a callback for a specific response type
|
||||
pub fn register_callback<F>(&self, response_type: &str, callback: F)
|
||||
where
|
||||
F: Fn(Vec<u8>) + Send + Sync + 'static,
|
||||
{
|
||||
let mut callbacks_lock = self.callbacks.write().unwrap();
|
||||
callbacks_lock.insert(response_type.to_string(), Callback::new(callback));
|
||||
}
|
||||
}
|
||||
105
packages/server/src/p2p/p2p_safestream.rs
Normal file
105
packages/server/src/p2p/p2p_safestream.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use futures_util::io::{ReadHalf, WriteHalf};
|
||||
use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
use prost::Message;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const MAX_SIZE: usize = 1024 * 1024; // 1MB
|
||||
|
||||
pub struct SafeStream {
|
||||
stream_read: Arc<Mutex<ReadHalf<libp2p::Stream>>>,
|
||||
stream_write: Arc<Mutex<WriteHalf<libp2p::Stream>>>,
|
||||
}
|
||||
impl SafeStream {
|
||||
pub fn new(stream: libp2p::Stream) -> Self {
|
||||
let (read, write) = stream.split();
|
||||
SafeStream {
|
||||
stream_read: Arc::new(Mutex::new(read)),
|
||||
stream_write: Arc::new(Mutex::new(write)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_json<T: Serialize>(
|
||||
&self,
|
||||
data: &T,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let json_data = serde_json::to_vec(data)?;
|
||||
tracing::info!("Sending JSON");
|
||||
let e = self.send_with_length_prefix(&json_data).await;
|
||||
tracing::info!("Sent JSON");
|
||||
e
|
||||
}
|
||||
|
||||
pub async fn receive_json<T: DeserializeOwned>(&self) -> Result<T, Box<dyn std::error::Error>> {
|
||||
let data = self.receive_with_length_prefix().await?;
|
||||
let msg = serde_json::from_slice(&data)?;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub async fn send_proto<M: Message>(&self, msg: &M) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut proto_data = Vec::new();
|
||||
msg.encode(&mut proto_data)?;
|
||||
self.send_with_length_prefix(&proto_data).await
|
||||
}
|
||||
|
||||
pub async fn receive_proto<M: Message + Default>(
|
||||
&self,
|
||||
) -> Result<M, Box<dyn std::error::Error>> {
|
||||
let data = self.receive_with_length_prefix().await?;
|
||||
let msg = M::decode(&*data)?;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub async fn send_raw(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.send_with_length_prefix(data).await
|
||||
}
|
||||
|
||||
pub async fn receive_raw(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
self.receive_with_length_prefix().await
|
||||
}
|
||||
|
||||
async fn send_with_length_prefix(&self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if data.len() > MAX_SIZE {
|
||||
return Err(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Data exceeds maximum size",
|
||||
)));
|
||||
}
|
||||
|
||||
let mut stream_write = self.stream_write.lock().await;
|
||||
|
||||
// Write the 4-byte length prefix
|
||||
let mut length_prefix = [0u8; 4];
|
||||
BigEndian::write_u32(&mut length_prefix, data.len() as u32);
|
||||
stream_write.write_all(&length_prefix).await?;
|
||||
|
||||
// Write the actual data
|
||||
stream_write.write_all(data).await?;
|
||||
stream_write.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_with_length_prefix(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
let mut stream_read = self.stream_read.lock().await;
|
||||
|
||||
// Read the 4-byte length prefix
|
||||
let mut length_prefix = [0u8; 4];
|
||||
stream_read.read_exact(&mut length_prefix).await?;
|
||||
let length = BigEndian::read_u32(&length_prefix) as usize;
|
||||
|
||||
if length > MAX_SIZE {
|
||||
return Err(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Data exceeds maximum size",
|
||||
)));
|
||||
}
|
||||
|
||||
// Read the actual data
|
||||
let mut buffer = vec![0; length];
|
||||
stream_read.read_exact(&mut buffer).await?;
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
@@ -1,225 +0,0 @@
|
||||
use crate::messages::decode_message;
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::sink::SinkExt;
|
||||
use futures_util::stream::{SplitSink, SplitStream};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{Mutex, Notify, mpsc};
|
||||
use tokio::time::sleep;
|
||||
use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
|
||||
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
|
||||
|
||||
type Callback = Box<dyn Fn(String) + Send + Sync>;
|
||||
type WSRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
||||
type WSWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NestriWebSocket {
|
||||
ws_url: String,
|
||||
reader: Arc<Mutex<Option<WSRead>>>,
|
||||
writer: Arc<Mutex<Option<WSWrite>>>,
|
||||
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
|
||||
message_tx: mpsc::UnboundedSender<String>,
|
||||
reconnected_notify: Arc<Notify>,
|
||||
}
|
||||
impl NestriWebSocket {
|
||||
pub async fn new(ws_url: String) -> Result<NestriWebSocket, Box<dyn Error>> {
|
||||
// Attempt to connect to the WebSocket
|
||||
let ws_stream = NestriWebSocket::do_connect(&ws_url).await.unwrap();
|
||||
|
||||
// Split the stream into read and write halves
|
||||
let (write, read) = ws_stream.split();
|
||||
|
||||
// Create the message channel
|
||||
let (message_tx, message_rx) = mpsc::unbounded_channel();
|
||||
|
||||
let ws = NestriWebSocket {
|
||||
ws_url,
|
||||
reader: Arc::new(Mutex::new(Some(read))),
|
||||
writer: Arc::new(Mutex::new(Some(write))),
|
||||
callbacks: Arc::new(RwLock::new(HashMap::new())),
|
||||
message_tx: message_tx.clone(),
|
||||
reconnected_notify: Arc::new(Notify::new()),
|
||||
};
|
||||
|
||||
// Spawn the read loop
|
||||
ws.spawn_read_loop();
|
||||
// Spawn the write loop
|
||||
ws.spawn_write_loop(message_rx);
|
||||
|
||||
Ok(ws)
|
||||
}
|
||||
|
||||
async fn do_connect(
|
||||
ws_url: &str,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Box<dyn Error + Send + Sync>> {
|
||||
loop {
|
||||
match connect_async(ws_url).await {
|
||||
Ok((ws_stream, _)) => {
|
||||
return Ok(ws_stream);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect to WebSocket, retrying: {:?}", e);
|
||||
sleep(Duration::from_secs(3)).await; // Wait before retrying
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handles message -> callback calls and reconnects on error/disconnect
|
||||
fn spawn_read_loop(&self) {
|
||||
let reader = self.reader.clone();
|
||||
let callbacks = self.callbacks.clone();
|
||||
let self_clone = self.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// Lock the reader to get the WSRead, then drop the lock
|
||||
let ws_read_option = {
|
||||
let mut reader_lock = reader.lock().await;
|
||||
reader_lock.take()
|
||||
};
|
||||
|
||||
let mut ws_read = match ws_read_option {
|
||||
Some(ws_read) => ws_read,
|
||||
None => {
|
||||
tracing::error!("Reader is None, cannot proceed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(message_result) = ws_read.next().await {
|
||||
match message_result {
|
||||
Ok(message) => {
|
||||
let data = message
|
||||
.into_text()
|
||||
.expect("failed to turn message into text");
|
||||
let base_message = match decode_message(data.to_string()) {
|
||||
Ok(base_message) => base_message,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to decode message: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let callbacks_lock = callbacks.read().unwrap();
|
||||
if let Some(callback) = callbacks_lock.get(&base_message.payload_type) {
|
||||
let data = data.clone();
|
||||
callback(data.to_string());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Error receiving message: {:?}, reconnecting in 3 seconds...",
|
||||
e
|
||||
);
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
self_clone.reconnect().await.unwrap();
|
||||
break; // Break the inner loop to get a new ws_read
|
||||
}
|
||||
}
|
||||
}
|
||||
// After reconnection, the loop continues, and we acquire a new ws_read
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn spawn_write_loop(&self, mut message_rx: mpsc::UnboundedReceiver<String>) {
|
||||
let writer = self.writer.clone();
|
||||
let self_clone = self.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// Wait for a message from the channel
|
||||
if let Some(message) = message_rx.recv().await {
|
||||
loop {
|
||||
// Acquire the writer lock
|
||||
let mut writer_lock = writer.lock().await;
|
||||
if let Some(writer) = writer_lock.as_mut() {
|
||||
// Try to send the message over the WebSocket
|
||||
match writer
|
||||
.send(Message::Text(Utf8Bytes::from(message.clone())))
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
// Message sent successfully
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error sending message: {:?}", e);
|
||||
// Attempt to reconnect
|
||||
if let Err(e) = self_clone.reconnect().await {
|
||||
tracing::error!("Error during reconnection: {:?}", e);
|
||||
// Wait before retrying
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::error!("Writer is None, cannot send message");
|
||||
// Attempt to reconnect
|
||||
if let Err(e) = self_clone.reconnect().await {
|
||||
tracing::error!("Error during reconnection: {:?}", e);
|
||||
// Wait before retrying
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn reconnect(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
loop {
|
||||
match NestriWebSocket::do_connect(&self.ws_url).await {
|
||||
Ok(ws_stream) => {
|
||||
let (write, read) = ws_stream.split();
|
||||
{
|
||||
let mut writer_lock = self.writer.lock().await;
|
||||
*writer_lock = Some(write);
|
||||
}
|
||||
{
|
||||
let mut reader_lock = self.reader.lock().await;
|
||||
*reader_lock = Some(read);
|
||||
}
|
||||
// Notify subscribers of successful reconnection
|
||||
self.reconnected_notify.notify_waiters();
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to reconnect to WebSocket: {:?}", e);
|
||||
sleep(Duration::from_secs(3)).await; // Wait before retrying
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message through the WebSocket
|
||||
pub fn send_message(&self, message: String) -> Result<(), Box<dyn Error>> {
|
||||
self.message_tx
|
||||
.send(message)
|
||||
.map_err(|e| format!("Failed to send message: {:?}", e).into())
|
||||
}
|
||||
|
||||
/// Register a callback for a specific response type
|
||||
pub fn register_callback<F>(&self, response_type: &str, callback: F)
|
||||
where
|
||||
F: Fn(String) + Send + Sync + 'static,
|
||||
{
|
||||
let mut callbacks_lock = self.callbacks.write().unwrap();
|
||||
callbacks_lock.insert(response_type.to_string(), Box::new(callback));
|
||||
}
|
||||
|
||||
/// Subscribe to event for reconnection
|
||||
pub fn subscribe_reconnected(&self) -> Arc<Notify> {
|
||||
self.reconnected_notify.clone()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user