From ce7ff49ee4a22c89a0731feb1250125cf0605a21 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Sat, 18 Nov 2023 18:08:12 +0100 Subject: [PATCH] checkpoint --- src/api/auth/login.rs | 5 ++- src/api/auth/mod.rs | 4 +- src/gateway/default/gateway.rs | 24 ++++++------ src/gateway/default/handle.rs | 25 ++++++++++++- src/gateway/default/heartbeat.rs | 39 +++----------------- src/gateway/default/mod.rs | 21 +++++++++++ src/gateway/mod.rs | 63 ++++++++++++++++++-------------- 7 files changed, 105 insertions(+), 76 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 71d594f..b1138c0 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -4,7 +4,7 @@ use reqwest::Client; use serde_json::to_string; use crate::errors::ChorusResult; -use crate::gateway::{GatewayCapable, GatewayHandleCapable}; +use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable}; use crate::instance::{ChorusUser, Instance}; use crate::ratelimiter::ChorusRequest; use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema}; @@ -37,7 +37,8 @@ impl Instance { self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); } let mut identify = GatewayIdentifyPayload::common(); - let gateway = Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); + let gateway: DefaultGatewayHandle = + Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); identify.token = login_result.token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 402af5c..7d03b95 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; pub use login::*; pub use register::*; -use crate::gateway::{GatewayCapable, GatewayHandleCapable}; +use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable}; use crate::{ errors::ChorusResult, gateway::DefaultGateway, @@ -26,7 +26,7 @@ impl Instance { .await .unwrap(); let mut identify = GatewayIdentifyPayload::common(); - let gateway = DefaultGateway::get_handle(self.urls.wss.clone()) + let gateway: DefaultGatewayHandle = DefaultGateway::get_handle(self.urls.wss.clone()) .await .unwrap(); identify.token = token.clone(); diff --git a/src/gateway/default/gateway.rs b/src/gateway/default/gateway.rs index 8ec825d..918d9d9 100644 --- a/src/gateway/default/gateway.rs +++ b/src/gateway/default/gateway.rs @@ -1,4 +1,5 @@ use futures_util::StreamExt; +use tokio_tungstenite::tungstenite::Message; use super::events::Events; use super::*; @@ -25,18 +26,19 @@ pub struct DefaultGateway { #[async_trait] impl GatewayCapable< + tokio_tungstenite::tungstenite::Message, WebSocketStream>, - WebSocketStream>, - DefaultGatewayHandle, - HeartbeatHandler, > for DefaultGateway { fn get_heartbeat_handler(&self) -> &HeartbeatHandler { &self.heartbeat_handler } - #[allow(clippy::new_ret_no_self)] - async fn get_handle(websocket_url: String) -> Result { + async fn get_handle< + G: GatewayHandleCapable>>, + >( + websocket_url: String, + ) -> Result { let mut roots = rustls::RootCertStore::empty(); for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { @@ -112,13 +114,13 @@ impl gateway.gateway_listen_task().await; }); - Ok(DefaultGatewayHandle { - url: websocket_url.clone(), - events: shared_events, - websocket_send: shared_websocket_send.clone(), - kill_send: kill_send.clone(), + Ok(G::new( + websocket_url.clone(), + shared_events, + shared_websocket_send.clone(), + kill_send.clone(), store, - }) + )) } /// Closes the websocket connection and stops all tasks diff --git a/src/gateway/default/handle.rs b/src/gateway/default/handle.rs index e76c0f1..dc3fd4c 100644 --- a/src/gateway/default/handle.rs +++ b/src/gateway/default/handle.rs @@ -4,7 +4,7 @@ use crate::types::{self, Composite}; #[async_trait(?Send)] impl GatewayHandleCapable< - WebSocketStream>, + tokio_tungstenite::tungstenite::Message, WebSocketStream>, > for DefaultGatewayHandle { @@ -23,6 +23,29 @@ impl self.kill_send.send(()).unwrap(); self.websocket_send.lock().await.close().await.unwrap(); } + + fn new( + url: String, + events: Arc>, + websocket_send: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + kill_send: tokio::sync::broadcast::Sender<()>, + store: GatewayStore, + ) -> Self { + Self { + url, + events, + websocket_send, + kill_send, + store, + } + } } /// Represents a handle to a Gateway connection. A Gateway connection will create observable diff --git a/src/gateway/default/heartbeat.rs b/src/gateway/default/heartbeat.rs index ed39d8a..e7b7129 100644 --- a/src/gateway/default/heartbeat.rs +++ b/src/gateway/default/heartbeat.rs @@ -1,39 +1,12 @@ use super::*; #[async_trait] -impl HeartbeatHandlerCapable>> for HeartbeatHandler { - fn new( - heartbeat_interval: Duration, - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> HeartbeatHandler { - let (send, receive) = tokio::sync::mpsc::channel(32); - let kill_receive = kill_rc.resubscribe(); - - let handle: JoinHandle<()> = task::spawn(async move { - HeartbeatHandler::heartbeat_task( - websocket_tx, - heartbeat_interval, - receive, - kill_receive, - ) - .await; - }); - - Self { - heartbeat_interval, - send, - handle, - } - } - +impl + HeartbeatHandlerCapable< + tokio_tungstenite::tungstenite::Message, + WebSocketStream>, + > for HeartbeatHandler +{ fn get_send(&self) -> &Sender { &self.send } diff --git a/src/gateway/default/mod.rs b/src/gateway/default/mod.rs index 14e5c1e..f90f97d 100644 --- a/src/gateway/default/mod.rs +++ b/src/gateway/default/mod.rs @@ -26,6 +26,27 @@ use tokio::task::JoinHandle; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; +impl crate::gateway::MessageCapable for tokio_tungstenite::tungstenite::Message { + fn as_string(&self) -> Option { + match self { + Message::Text(text) => Some(text.clone()), + _ => None, + } + } + + fn is_empty(&self) -> bool { + todo!() + } + + fn is_error(&self) -> bool { + todo!() + } + + fn as_bytes(&self) -> Option> { + todo!() + } +} + #[cfg(test)] mod test { use crate::types; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e2e3552..071e52f 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -29,11 +29,10 @@ use std::time::{self, Duration, Instant}; use async_trait::async_trait; use futures_util::stream::SplitSink; use futures_util::Sink; -use futures_util::{SinkExt, Stream}; +use futures_util::SinkExt; use log::{info, trace, warn}; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; -use tokio_tungstenite::tungstenite::Message; pub type GatewayStore = Arc>>>>; @@ -88,6 +87,13 @@ const GATEWAY_CALL_SYNC: u8 = 13; /// See [types::LazyRequest] const GATEWAY_LAZY_REQUEST: u8 = 14; +pub trait MessageCapable { + fn as_string(&self) -> Option; + fn as_bytes(&self) -> Option>; + fn is_empty(&self) -> bool; + fn is_error(&self) -> bool; +} + pub type ObservableObject = dyn Send + Sync + Any; /// Used for communications between the heartbeat and gateway thread. @@ -153,22 +159,22 @@ impl GatewayEvent { #[allow(clippy::type_complexity)] #[async_trait] -pub trait GatewayCapable +pub trait GatewayCapable where - R: Stream, - S: Sink + Send + 'static, - G: GatewayHandleCapable, - H: HeartbeatHandlerCapable + Send + Sync, + T: MessageCapable + Send + 'static, + S: Sink + Send, { fn get_events(&self) -> Arc>; - fn get_websocket_send(&self) -> Arc>>; + fn get_websocket_send(&self) -> Arc>>; fn get_store(&self) -> GatewayStore; fn get_url(&self) -> String; - fn get_heartbeat_handler(&self) -> &H; + fn get_heartbeat_handler(&self) -> &HeartbeatHandler; /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] /// /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation - async fn get_handle(websocket_url: String) -> Result; + async fn get_handle>( + websocket_url: String, + ) -> Result; async fn close(&mut self); /// This handles a message as a websocket event and updates its events along with the events' observers async fn handle_message(&mut self, msg: GatewayMessage) { @@ -456,27 +462,35 @@ pub struct HeartbeatHandler { } #[async_trait(?Send)] -pub trait GatewayHandleCapable +pub trait GatewayHandleCapable where - R: Stream, - S: Sink, + T: MessageCapable + Send + 'static, + S: Sink, { + fn new( + url: String, + events: Arc>, + websocket_send: Arc>>, + kill_send: tokio::sync::broadcast::Sender<()>, + store: GatewayStore, + ) -> Self; + /// Sends json to the gateway with an opcode async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value); /// Observes an Item ``, which will update itself, if new information about this /// item arrives on the corresponding Gateway Thread - async fn observe + Send + Sync>( + async fn observe + Send + Sync>( &self, - object: Arc>, - ) -> Arc>; + object: Arc>, + ) -> Arc>; /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` /// with all of its observable fields being observed. - async fn observe_and_into_inner>( + async fn observe_and_into_inner>( &self, - object: Arc>, - ) -> T { + object: Arc>, + ) -> U { let channel = self.observe(object.clone()).await; let object = channel.read().unwrap().clone(); object @@ -556,17 +570,12 @@ where } #[async_trait] -pub trait HeartbeatHandlerCapable + Send + 'static> { - fn new( - heartbeat_interval: Duration, - websocket_tx: Arc>>, - kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> Self; - +pub trait HeartbeatHandlerCapable> { fn get_send(&self) -> &Sender; + fn as_arc_mutex(&self) -> Arc>; fn get_heartbeat_interval(&self) -> Duration; async fn heartbeat_task( - websocket_tx: Arc>>, + websocket_tx: Arc>>, heartbeat_interval: Duration, mut receive: tokio::sync::mpsc::Receiver, mut kill_receive: tokio::sync::broadcast::Receiver<()>,