From 4f207d55d9df7c0353d35dd533863458b5196d51 Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 19 Nov 2023 17:08:53 +0100 Subject: [PATCH 1/5] prepare for platform-dependant websockets backend --- src/gateway/backend_tungstenite.rs | 69 ++++++++++++++++++++++++ src/gateway/gateway.rs | 86 ++++++++---------------------- src/gateway/handle.rs | 14 ++--- src/gateway/heartbeat.rs | 35 +++--------- src/gateway/message.rs | 36 ++----------- src/gateway/mod.rs | 12 ++--- 6 files changed, 110 insertions(+), 142 deletions(-) create mode 100644 src/gateway/backend_tungstenite.rs diff --git a/src/gateway/backend_tungstenite.rs b/src/gateway/backend_tungstenite.rs new file mode 100644 index 0000000..0522847 --- /dev/null +++ b/src/gateway/backend_tungstenite.rs @@ -0,0 +1,69 @@ +use futures_util::{ + stream::{SplitSink, SplitStream}, + StreamExt, +}; +use tokio::net::TcpStream; +use tokio_tungstenite::{ + connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream, +}; + +use super::GatewayMessage; +use crate::errors::GatewayError; + +#[derive(Debug, Clone)] +pub struct WebSocketBackend; + +// These could be made into inherent associated types when that's stabilized +pub type WsSink = SplitSink>, tungstenite::Message>; +pub type WsStream = SplitStream>>; +pub type WsMessage = tungstenite::Message; +pub type WsError = tungstenite::Error; + +impl WebSocketBackend { + // When impl_trait_in_assoc_type gets stabilized, this would just be = impl Sink + + pub async fn new( + websocket_url: &str, + ) -> Result<(WsSink, WsStream), crate::errors::GatewayError> { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + roots.add(&rustls::Certificate(cert.0)).unwrap(); + } + let (websocket_stream, _) = match connect_async_tls_with_config( + websocket_url, + None, + false, + Some(Connector::Rustls( + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth() + .into(), + )), + ) + .await + { + Ok(websocket_stream) => websocket_stream, + Err(e) => { + return Err(GatewayError::CannotConnect { + error: e.to_string(), + }) + } + }; + + Ok(websocket_stream.split()) + } +} + +impl From for tungstenite::Message { + fn from(message: GatewayMessage) -> Self { + Self::Text(message.0) + } +} + +impl From for GatewayMessage { + fn from(value: tungstenite::Message) -> Self { + Self(value.to_string()) + } +} diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 30d0610..b2ae673 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -1,4 +1,6 @@ use self::event::Events; +use super::handle::GatewayHandle; +use super::heartbeat::HeartbeatHandler; use super::*; use crate::types::{ self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, @@ -10,15 +12,8 @@ use crate::types::{ pub struct Gateway { events: Arc>, heartbeat_handler: HeartbeatHandler, - websocket_send: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - websocket_receive: SplitStream>>, + websocket_send: Arc>, + websocket_receive: WsStream, kill_send: tokio::sync::broadcast::Sender<()>, store: Arc>>>>, url: String, @@ -27,34 +22,7 @@ pub struct Gateway { impl Gateway { #[allow(clippy::new_ret_no_self)] pub async fn new(websocket_url: String) -> Result { - let mut roots = rustls::RootCertStore::empty(); - for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - roots.add(&rustls::Certificate(cert.0)).unwrap(); - } - let (websocket_stream, _) = match connect_async_tls_with_config( - &websocket_url, - None, - false, - Some(Connector::Rustls( - rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots) - .with_no_client_auth() - .into(), - )), - ) - .await - { - Ok(websocket_stream) => websocket_stream, - Err(e) => { - return Err(GatewayError::CannotConnect { - error: e.to_string(), - }) - } - }; - - let (websocket_send, mut websocket_receive) = websocket_stream.split(); + let (websocket_send, mut websocket_receive) = WebSocketBackend::new(&websocket_url).await?; let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); @@ -63,9 +31,8 @@ impl Gateway { // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread - let msg = websocket_receive.next().await.unwrap().unwrap(); - let gateway_payload: types::GatewayReceivePayload = - serde_json::from_str(msg.to_text().unwrap()).unwrap(); + let msg: GatewayMessage = websocket_receive.next().await.unwrap().unwrap().into(); + let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&msg.0).unwrap(); if gateway_payload.op_code != GATEWAY_HELLO { return Err(GatewayError::NonHelloOnInitiate { @@ -120,8 +87,7 @@ impl Gateway { // This if chain can be much better but if let is unstable on stable rust if let Some(Ok(message)) = msg { - self.handle_message(GatewayMessage::from_tungstenite_message(message)) - .await; + self.handle_message(message.into()).await; continue; } @@ -134,7 +100,7 @@ impl Gateway { /// Closes the websocket connection and stops all tasks async fn close(&mut self) { self.kill_send.send(()).unwrap(); - self.websocket_send.lock().await.close().await.unwrap(); + let _ = self.websocket_send.lock().await.close().await; } /// Deserializes and updates a dispatched event, when we already know its type; @@ -156,31 +122,23 @@ impl Gateway { /// This handles a message as a websocket event and updates its events along with the events' observers pub async fn handle_message(&mut self, msg: GatewayMessage) { - if msg.is_empty() { + if msg.0.is_empty() { return; } - if !msg.is_error() && !msg.is_payload() { - warn!( - "Message unrecognised: {:?}, please open an issue on the chorus github", - msg.message.to_string() - ); + let Ok(gateway_payload) = msg.payload() else { + if let Some(error) = msg.error() { + warn!("GW: Received error {:?}, connection will close..", error); + self.close().await; + self.events.lock().await.error.notify(error).await; + } else { + warn!( + "Message unrecognised: {:?}, please open an issue on the chorus github", + msg.0 + ); + } return; - } - - if msg.is_error() { - let error = msg.error().unwrap(); - - warn!("GW: Received error {:?}, connection will close..", error); - - self.close().await; - - self.events.lock().await.error.notify(error).await; - - return; - } - - let gateway_payload = msg.payload().unwrap(); + }; // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes match gateway_payload.op_code { diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs index 1200b30..c38d987 100644 --- a/src/gateway/handle.rs +++ b/src/gateway/handle.rs @@ -9,14 +9,7 @@ use crate::types::{self, Composite}; pub struct GatewayHandle { pub url: String, pub events: Arc>, - pub websocket_send: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, + pub websocket_send: Arc>, /// Tells gateway tasks to close pub(super) kill_send: tokio::sync::broadcast::Sender<()>, pub(crate) store: Arc>>>>, @@ -32,13 +25,12 @@ impl GatewayHandle { }; let payload_json = serde_json::to_string(&gateway_payload).unwrap(); - - let message = tokio_tungstenite::tungstenite::Message::text(payload_json); + let message = GatewayMessage(payload_json); self.websocket_send .lock() .await - .send(message) + .send(message.into()) .await .unwrap(); } diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs index dd162b7..aeb3dae 100644 --- a/src/gateway/heartbeat.rs +++ b/src/gateway/heartbeat.rs @@ -1,6 +1,5 @@ -use crate::types; - use super::*; +use crate::types; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; @@ -20,27 +19,14 @@ pub(super) struct HeartbeatHandler { impl HeartbeatHandler { pub fn new( heartbeat_interval: Duration, - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, + websocket_tx: Arc>, kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> HeartbeatHandler { + ) -> Self { let (send, receive) = tokio::sync::mpsc::channel(32); let kill_receive = kill_rc.resubscribe(); let handle: JoinHandle<()> = task::spawn(async move { - HeartbeatHandler::heartbeat_task( - websocket_tx, - heartbeat_interval, - receive, - kill_receive, - ) - .await; + Self::heartbeat_task(websocket_tx, heartbeat_interval, receive, kill_receive).await; }); Self { @@ -55,14 +41,7 @@ impl HeartbeatHandler { /// Can be killed by the kill broadcast; /// If the websocket is closed, will die out next time it tries to send a heartbeat; pub async fn heartbeat_task( - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, + websocket_tx: Arc>, heartbeat_interval: Duration, mut receive: tokio::sync::mpsc::Receiver, mut kill_receive: tokio::sync::broadcast::Receiver<()>, @@ -122,9 +101,9 @@ impl HeartbeatHandler { let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + let msg = GatewayMessage(heartbeat_json); - let send_result = websocket_tx.lock().await.send(msg).await; + let send_result = websocket_tx.lock().await.send(msg.into()).await; if send_result.is_err() { // We couldn't send, the websocket is broken warn!("GW: Couldnt send heartbeat, websocket seems broken"); diff --git a/src/gateway/message.rs b/src/gateway/message.rs index edee9dd..8a94f73 100644 --- a/src/gateway/message.rs +++ b/src/gateway/message.rs @@ -5,24 +5,14 @@ use super::*; /// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. /// This struct is used internally when handling messages. #[derive(Clone, Debug)] -pub struct GatewayMessage { - /// The message we received from the server - pub(super) message: tokio_tungstenite::tungstenite::Message, -} +pub struct GatewayMessage(pub String); impl GatewayMessage { - /// Creates self from a tungstenite message - pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { - Self { message } - } - /// Parses the message as an error; /// Returns the error if succesfully parsed, None if the message isn't an error pub fn error(&self) -> Option { - let content = self.message.to_string(); - // Some error strings have dots on the end, which we don't care about - let processed_content = content.to_lowercase().replace('.', ""); + let processed_content = self.0.to_lowercase().replace('.', ""); match processed_content.as_str() { "unknown error" | "4000" => Some(GatewayError::Unknown), @@ -45,29 +35,9 @@ impl GatewayMessage { } } - /// Returns whether or not the message is an error - pub fn is_error(&self) -> bool { - self.error().is_some() - } - /// Parses the message as a payload; /// Returns a result of deserializing pub fn payload(&self) -> Result { - return serde_json::from_str(self.message.to_text().unwrap()); - } - - /// Returns whether or not the message is a payload - pub fn is_payload(&self) -> bool { - // close messages are never payloads, payloads are only text messages - if self.message.is_close() | !self.message.is_text() { - return false; - } - - return self.payload().is_ok(); - } - - /// Returns whether or not the message is empty - pub fn is_empty(&self) -> bool { - self.message.is_empty() + return serde_json::from_str(&self.0); } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index ebd06cc..6e9d1c7 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -3,8 +3,13 @@ pub mod handle; pub mod heartbeat; pub mod message; +#[cfg(not(wasm))] +pub mod backend_tungstenite; +#[cfg(not(wasm))] +use backend_tungstenite::*; + pub use gateway::*; -pub use handle::*; +pub use handle::GatewayHandle; use heartbeat::*; pub use message::*; @@ -19,20 +24,15 @@ use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::sleep_until; -use futures_util::stream::SplitSink; -use futures_util::stream::SplitStream; use futures_util::SinkExt; use futures_util::StreamExt; use log::{info, trace, warn}; -use tokio::net::TcpStream; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; use tokio::task; use tokio::task::JoinHandle; use tokio::time; use tokio::time::Instant; -use tokio_tungstenite::MaybeTlsStream; -use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; // Gateway opcodes /// Opcode received when the server dispatches a [crate::types::WebSocketEvent] From 0f446f43b417036ba24419dcc114d7049e6b5f7b Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 19 Nov 2023 17:13:52 +0100 Subject: [PATCH 2/5] removed outdated comment --- src/gateway/backend_tungstenite.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gateway/backend_tungstenite.rs b/src/gateway/backend_tungstenite.rs index 0522847..e114e2c 100644 --- a/src/gateway/backend_tungstenite.rs +++ b/src/gateway/backend_tungstenite.rs @@ -20,8 +20,6 @@ pub type WsMessage = tungstenite::Message; pub type WsError = tungstenite::Error; impl WebSocketBackend { - // When impl_trait_in_assoc_type gets stabilized, this would just be = impl Sink - pub async fn new( websocket_url: &str, ) -> Result<(WsSink, WsStream), crate::errors::GatewayError> { From dd9945068f2ff1be12e23f6d27077617bc69c13a Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 19 Nov 2023 17:15:00 +0100 Subject: [PATCH 3/5] removed leftover type aliases --- src/gateway/backend_tungstenite.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gateway/backend_tungstenite.rs b/src/gateway/backend_tungstenite.rs index e114e2c..a81bab3 100644 --- a/src/gateway/backend_tungstenite.rs +++ b/src/gateway/backend_tungstenite.rs @@ -16,8 +16,6 @@ pub struct WebSocketBackend; // These could be made into inherent associated types when that's stabilized pub type WsSink = SplitSink>, tungstenite::Message>; pub type WsStream = SplitStream>>; -pub type WsMessage = tungstenite::Message; -pub type WsError = tungstenite::Error; impl WebSocketBackend { pub async fn new( From c0ce540da68422f65286b8680723ce748c7ddaed Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 19 Nov 2023 17:18:08 +0100 Subject: [PATCH 4/5] for got unwrap :3 --- src/gateway/gateway.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index b2ae673..f0d1882 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -100,7 +100,7 @@ impl Gateway { /// Closes the websocket connection and stops all tasks async fn close(&mut self) { self.kill_send.send(()).unwrap(); - let _ = self.websocket_send.lock().await.close().await; + self.websocket_send.lock().await.close().await.unwrap(); } /// Deserializes and updates a dispatched event, when we already know its type; From 5bd8f32a6abf3f0fc29ee27e69a3ca866b064b0c Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 19 Nov 2023 17:46:49 +0100 Subject: [PATCH 5/5] remove superfluous return --- src/gateway/backend_tungstenite.rs | 2 +- src/gateway/gateway.rs | 3 ++- src/gateway/message.rs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gateway/backend_tungstenite.rs b/src/gateway/backend_tungstenite.rs index a81bab3..53b6982 100644 --- a/src/gateway/backend_tungstenite.rs +++ b/src/gateway/backend_tungstenite.rs @@ -18,7 +18,7 @@ pub type WsSink = SplitSink>, tungsten pub type WsStream = SplitStream>>; impl WebSocketBackend { - pub async fn new( + pub async fn connect( websocket_url: &str, ) -> Result<(WsSink, WsStream), crate::errors::GatewayError> { let mut roots = rustls::RootCertStore::empty(); diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index f0d1882..38efe7c 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -22,7 +22,8 @@ pub struct Gateway { impl Gateway { #[allow(clippy::new_ret_no_self)] pub async fn new(websocket_url: String) -> Result { - let (websocket_send, mut websocket_receive) = WebSocketBackend::new(&websocket_url).await?; + let (websocket_send, mut websocket_receive) = + WebSocketBackend::connect(&websocket_url).await?; let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); diff --git a/src/gateway/message.rs b/src/gateway/message.rs index 8a94f73..2c12e48 100644 --- a/src/gateway/message.rs +++ b/src/gateway/message.rs @@ -38,6 +38,6 @@ impl GatewayMessage { /// Parses the message as a payload; /// Returns a result of deserializing pub fn payload(&self) -> Result { - return serde_json::from_str(&self.0); + serde_json::from_str(&self.0) } }