From 06d25d3e500bba73209bf1340761d89ed97860bf Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Wed, 15 Nov 2023 00:04:04 +0100 Subject: [PATCH] Impl base of GatewayCapable for Gateway --- src/api/auth/login.rs | 4 +- src/api/auth/mod.rs | 3 +- src/api/auth/register.rs | 4 +- src/gateway/gateway.rs | 100 ++++++-------- src/gateway/handle.rs | 8 ++ src/gateway/mod.rs | 275 +++++++++++++++++++++++++++++++++++++-- src/instance.rs | 4 +- 7 files changed, 322 insertions(+), 76 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 272ee04..750d01e 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -4,7 +4,7 @@ use reqwest::Client; use serde_json::to_string; use crate::errors::ChorusResult; -use crate::gateway::Gateway; +use crate::gateway::{Gateway, GatewayCapable}; use crate::instance::{ChorusUser, Instance}; use crate::ratelimiter::ChorusRequest; use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema}; @@ -36,7 +36,7 @@ impl Instance { self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); } let mut identify = GatewayIdentifyPayload::common(); - let gateway = Gateway::new(self.urls.wss.clone()).await.unwrap(); + let gateway = Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); identify.token = login_result.token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 3ad4a60..091975c 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock}; pub use login::*; pub use register::*; +use crate::gateway::GatewayCapable; use crate::{ errors::ChorusResult, gateway::Gateway, @@ -25,7 +26,7 @@ impl Instance { .await .unwrap(); let mut identify = GatewayIdentifyPayload::common(); - let gateway = Gateway::new(self.urls.wss.clone()).await.unwrap(); + let gateway = Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); identify.token = token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index d10915e..ed6fcbf 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use reqwest::Client; use serde_json::to_string; -use crate::gateway::Gateway; +use crate::gateway::{Gateway, GatewayCapable, GatewayHandle}; use crate::types::GatewayIdentifyPayload; use crate::{ errors::ChorusResult, @@ -45,7 +45,7 @@ impl Instance { let user_object = self.get_user(token.clone(), None).await.unwrap(); let settings = ChorusUser::get_settings(&token, &self.urls.api.clone(), &mut self).await?; let mut identify = GatewayIdentifyPayload::common(); - let gateway = Gateway::new(self.urls.wss.clone()).await.unwrap(); + let gateway: GatewayHandle = Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); identify.token = token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 7aaae67..15215ac 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -26,14 +26,16 @@ pub struct Gateway { url: String, } +#[async_trait] impl GatewayCapable< WebSocketStream>, WebSocketStream>, + GatewayHandle, > for Gateway { #[allow(clippy::new_ret_no_self)] - async fn new(websocket_url: String) -> Result { + async fn get_handle(websocket_url: String) -> Result { let mut roots = rustls::RootCertStore::empty(); for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { @@ -118,51 +120,14 @@ impl }) } - /// The main gateway listener task; - /// - /// Can only be stopped by closing the websocket, cannot be made to listen for kill - pub async fn gateway_listen_task(&mut self) { - loop { - let msg = self.websocket_receive.next().await; - - // This if chain can be much better but if let is unstable on stable rust - if let Some(Ok(message)) = msg { - self.handle_message(GatewayMessage::from_tungstenite_message(message)) - .await; - continue; - } - - // We couldn't receive the next message or it was an error, something is wrong with the websocket, close - warn!("GW: Websocket is broken, stopping gateway"); - break; - } - } - /// Closes the websocket connection and stops all tasks async fn close(&mut self) { self.kill_send.send(()).unwrap(); self.websocket_send.lock().await.close().await.unwrap(); } - /// Deserializes and updates a dispatched event, when we already know its type; - /// (Called for every event in handle_message) - #[allow(dead_code)] // TODO: Remove this allow annotation - async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( - data: &'a str, - event: &mut GatewayEvent, - ) -> Result<(), serde_json::Error> { - let data_deserialize_result: Result = serde_json::from_str(data); - - if data_deserialize_result.is_err() { - return Err(data_deserialize_result.err().unwrap()); - } - - event.notify(data_deserialize_result.unwrap()).await; - Ok(()) - } - /// This handles a message as a websocket event and updates its events along with the events' observers - pub async fn handle_message(&mut self, msg: GatewayMessage) { + async fn handle_message(&mut self, msg: GatewayMessage) { if msg.is_empty() { return; } @@ -436,28 +401,45 @@ impl } fn get_url(&self) -> String { - self.url + self.url.clone() + } +} + +impl Gateway { + /// The main gateway listener task; + /// + /// Can only be stopped by closing the websocket, cannot be made to listen for kill + pub async fn gateway_listen_task(&mut self) { + loop { + let msg = self.websocket_receive.next().await; + + // This if chain can be much better but if let is unstable on stable rust + if let Some(Ok(message)) = msg { + self.handle_message(GatewayMessage::from_tungstenite_message(message)); + continue; + } + + // We couldn't receive the next message or it was an error, something is wrong with the websocket, close + warn!("GW: Websocket is broken, stopping gateway"); + break; + } } - fn get_handle( - &self, - websocket_url: &'static str, - ) -> Result< - Box< - dyn GatewayHandleCapable< - Box< - dyn GatewayCapable< - WebSocketStream>, - WebSocketStream>, - >, - >, - WebSocketStream>, - WebSocketStream>, - >, - >, - GatewayError, - > { - todo!() + /// Deserializes and updates a dispatched event, when we already know its type; + /// (Called for every event in handle_message) + #[allow(dead_code)] // TODO: Remove this allow annotation + async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( + data: &'a str, + event: &mut GatewayEvent, + ) -> Result<(), serde_json::Error> { + let data_deserialize_result: Result = serde_json::from_str(data); + + if data_deserialize_result.is_err() { + return Err(data_deserialize_result.err().unwrap()); + } + + event.notify(data_deserialize_result.unwrap()).await; + Ok(()) } } diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs index a4cf20a..3a87c5a 100644 --- a/src/gateway/handle.rs +++ b/src/gateway/handle.rs @@ -169,3 +169,11 @@ impl GatewayHandle { self.websocket_send.lock().await.close().await.unwrap(); } } + +impl + GatewayHandleCapable< + WebSocketStream>, + WebSocketStream>, + > for GatewayHandle +{ +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 04043e9..2b5d414 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,7 +10,7 @@ pub use message::*; use tokio_tungstenite::tungstenite::Message; use crate::errors::GatewayError; -use crate::types::{Snowflake, WebSocketEvent}; +use crate::types::{self, Snowflake, WebSocketEvent}; use async_trait::async_trait; use std::any::Any; @@ -139,10 +139,12 @@ impl GatewayEvent { } #[allow(clippy::type_complexity)] -pub trait GatewayCapable +#[async_trait] +pub trait GatewayCapable where R: Stream, S: Sink, + G: GatewayHandleCapable, { fn get_events(&self) -> Arc>; fn get_websocket_send(&self) -> Arc>>; @@ -151,17 +153,270 @@ where /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] /// /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation - fn get_handle( - &self, - websocket_url: &'static str, - ) -> Result>, R, S>>, GatewayError>; - fn close(&mut self); - fn handle_message(&mut self, msg: GatewayMessage); + async fn get_handle(websocket_url: String) -> Result; + async fn close(&mut self); + async fn handle_message(&mut self, msg: GatewayMessage) { + if msg.is_empty() { + return; + } + + if !msg.is_error() && !msg.is_payload() { + warn!( + "Message unrecognised: {:?}, please open an issue on the chorus github", + msg.message.to_string() + ); + return; + } + + if msg.is_error() { + let error = msg.error().unwrap(); + + warn!("GW: Received error {:?}, connection will close..", error); + + self.close().await; + + self.get_events().lock().await.error.notify(error).await; + + return; + } + + let gateway_payload = msg.payload().unwrap(); + + // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes + match gateway_payload.op_code { + // An event was dispatched, we need to look at the gateway event name t + GATEWAY_DISPATCH => { + let Some(event_name) = gateway_payload.event_name else { + warn!("Gateway dispatch op without event_name"); + return; + }; + + trace!("Gateway: Received {event_name}"); + + macro_rules! handle { + ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { + match event_name.as_str() { + $($name => { + let event = &mut self.get_events().lock().await.$($path).+; + let json = gateway_payload.event_data.unwrap().get(); + match serde_json::from_str(json) { + Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), + Ok(message) => { + $( + let mut message: $message_type = message; + let store = self.get_store().lock().await; + let id = if message.id().is_some() { + message.id().unwrap() + } else { + event.notify(message).await; + return; + }; + if let Some(to_update) = store.get(&id) { + let object = to_update.clone(); + let inner_object = object.read().unwrap(); + if let Some(_) = inner_object.downcast_ref::<$update_type>() { + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; + drop(inner_object); + message.set_json(json.to_string()); + message.set_source_url(self.get_url().clone()); + message.update(downcasted.clone()); + } else { + warn!("Received {} for {}, but it has been observed to be a different type!", $name, id) + } + } + )? + event.notify(message).await; + } + } + },)* + "RESUMED" => (), + "SESSIONS_REPLACE" => { + let result: Result, serde_json::Error> = + serde_json::from_str(gateway_payload.event_data.unwrap().get()); + match result { + Err(err) => { + warn!( + "Failed to parse gateway event {} ({})", + event_name, + err + ); + return; + } + Ok(sessions) => { + self.events.lock().await.session.replace.notify( + types::SessionsReplace {sessions} + ).await; + } + } + }, + _ => { + warn!("Received unrecognized gateway event ({event_name})! Please open an issue on the chorus github so we can implement it"); + } + } + }; + } + + // See https://discord.com/developers/docs/topics/gateway-events#receive-events + // "Some" of these are undocumented + handle!( + "READY" => session.ready, + "READY_SUPPLEMENTAL" => session.ready_supplemental, + "APPLICATION_COMMAND_PERMISSIONS_UPDATE" => application.command_permissions_update, + "AUTO_MODERATION_RULE_CREATE" =>auto_moderation.rule_create, + "AUTO_MODERATION_RULE_UPDATE" =>auto_moderation.rule_update AutoModerationRuleUpdate: AutoModerationRule, + "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, + "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, + "CHANNEL_CREATE" => channel.create ChannelCreate: Guild, + "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, + "CHANNEL_UNREAD_UPDATE" => channel.unread_update, + "CHANNEL_DELETE" => channel.delete ChannelDelete: Guild, + "CHANNEL_PINS_UPDATE" => channel.pins_update, + "CALL_CREATE" => call.create, + "CALL_UPDATE" => call.update, + "CALL_DELETE" => call.delete, + "THREAD_CREATE" => thread.create, // TODO + "THREAD_UPDATE" => thread.update ThreadUpdate: Channel, + "THREAD_DELETE" => thread.delete, // TODO + "THREAD_LIST_SYNC" => thread.list_sync, // TODO + "THREAD_MEMBER_UPDATE" => thread.member_update, // TODO + "THREAD_MEMBERS_UPDATE" => thread.members_update, // TODO + "GUILD_CREATE" => guild.create, // TODO + "GUILD_UPDATE" => guild.update, // TODO + "GUILD_DELETE" => guild.delete, // TODO + "GUILD_AUDIT_LOG_ENTRY_CREATE" => guild.audit_log_entry_create, + "GUILD_BAN_ADD" => guild.ban_add, // TODO + "GUILD_BAN_REMOVE" => guild.ban_remove, // TODO + "GUILD_EMOJIS_UPDATE" => guild.emojis_update, // TODO + "GUILD_STICKERS_UPDATE" => guild.stickers_update, // TODO + "GUILD_INTEGRATIONS_UPDATE" => guild.integrations_update, + "GUILD_MEMBER_ADD" => guild.member_add, + "GUILD_MEMBER_REMOVE" => guild.member_remove, + "GUILD_MEMBER_UPDATE" => guild.member_update, // TODO + "GUILD_MEMBERS_CHUNK" => guild.members_chunk, // TODO + "GUILD_ROLE_CREATE" => guild.role_create GuildRoleCreate: Guild, + "GUILD_ROLE_UPDATE" => guild.role_update GuildRoleUpdate: RoleObject, + "GUILD_ROLE_DELETE" => guild.role_delete, // TODO + "GUILD_SCHEDULED_EVENT_CREATE" => guild.role_scheduled_event_create, // TODO + "GUILD_SCHEDULED_EVENT_UPDATE" => guild.role_scheduled_event_update, // TODO + "GUILD_SCHEDULED_EVENT_DELETE" => guild.role_scheduled_event_delete, // TODO + "GUILD_SCHEDULED_EVENT_USER_ADD" => guild.role_scheduled_event_user_add, + "GUILD_SCHEDULED_EVENT_USER_REMOVE" => guild.role_scheduled_event_user_remove, + "PASSIVE_UPDATE_V1" => guild.passive_update_v1, // TODO + "INTEGRATION_CREATE" => integration.create, // TODO + "INTEGRATION_UPDATE" => integration.update, // TODO + "INTEGRATION_DELETE" => integration.delete, // TODO + "INTERACTION_CREATE" => interaction.create, // TODO + "INVITE_CREATE" => invite.create, // TODO + "INVITE_DELETE" => invite.delete, // TODO + "MESSAGE_CREATE" => message.create, + "MESSAGE_UPDATE" => message.update, // TODO + "MESSAGE_DELETE" => message.delete, + "MESSAGE_DELETE_BULK" => message.delete_bulk, + "MESSAGE_REACTION_ADD" => message.reaction_add, // TODO + "MESSAGE_REACTION_REMOVE" => message.reaction_remove, // TODO + "MESSAGE_REACTION_REMOVE_ALL" => message.reaction_remove_all, // TODO + "MESSAGE_REACTION_REMOVE_EMOJI" => message.reaction_remove_emoji, // TODO + "MESSAGE_ACK" => message.ack, + "PRESENCE_UPDATE" => user.presence_update, // TODO + "RELATIONSHIP_ADD" => relationship.add, + "RELATIONSHIP_REMOVE" => relationship.remove, + "STAGE_INSTANCE_CREATE" => stage_instance.create, + "STAGE_INSTANCE_UPDATE" => stage_instance.update, // TODO + "STAGE_INSTANCE_DELETE" => stage_instance.delete, + "TYPING_START" => user.typing_start, + "USER_UPDATE" => user.update, // TODO + "USER_GUILD_SETTINGS_UPDATE" => user.guild_settings_update, + "VOICE_STATE_UPDATE" => voice.state_update, // TODO + "VOICE_SERVER_UPDATE" => voice.server_update, + "WEBHOOKS_UPDATE" => webhooks.update + ); + } + // We received a heartbeat from the server + // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." + GATEWAY_HEARTBEAT => { + trace!("GW: Received Heartbeat // Heartbeat Request"); + + // Tell the heartbeat handler it should send a heartbeat right away + + let heartbeat_communication = HeartbeatThreadCommunication { + sequence_number: gateway_payload.sequence_number, + op_code: Some(GATEWAY_HEARTBEAT), + }; + + self.heartbeat_handler + .send + .send(heartbeat_communication) + .await + .unwrap(); + } + GATEWAY_RECONNECT => { + todo!() + } + GATEWAY_INVALID_SESSION => { + todo!() + } + // Starts our heartbeat + // We should have already handled this in gateway init + GATEWAY_HELLO => { + warn!("Received hello when it was unexpected"); + } + GATEWAY_HEARTBEAT_ACK => { + trace!("GW: Received Heartbeat ACK"); + + // Tell the heartbeat handler we received an ack + + let heartbeat_communication = HeartbeatThreadCommunication { + sequence_number: gateway_payload.sequence_number, + op_code: Some(GATEWAY_HEARTBEAT_ACK), + }; + + self.heartbeat_handler + .send + .send(heartbeat_communication) + .await + .unwrap(); + } + GATEWAY_IDENTIFY + | GATEWAY_UPDATE_PRESENCE + | GATEWAY_UPDATE_VOICE_STATE + | GATEWAY_RESUME + | GATEWAY_REQUEST_GUILD_MEMBERS + | GATEWAY_CALL_SYNC + | GATEWAY_LAZY_REQUEST => { + info!( + "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", + gateway_payload.op_code + ); + } + _ => { + warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); + } + } + + // If we we received a seq number we should let it know + if let Some(seq_num) = gateway_payload.sequence_number { + let heartbeat_communication = HeartbeatThreadCommunication { + sequence_number: Some(seq_num), + // Op code is irrelevant here + op_code: None, + }; + + self.heartbeat_handler + .send + .send(heartbeat_communication) + .await + .unwrap(); + } + } } -pub trait GatewayHandleCapable +pub trait GatewayHandleCapable where - T: GatewayCapable, R: Stream, S: Sink, { diff --git a/src/instance.rs b/src/instance.rs index 72bf350..8dc9e2b 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -9,7 +9,7 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use crate::errors::ChorusResult; -use crate::gateway::{Gateway, GatewayHandle}; +use crate::gateway::{Gateway, GatewayCapable, GatewayHandle}; use crate::ratelimiter::ChorusRequest; use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::{GeneralConfiguration, Limit, LimitType, User, UserSettings}; @@ -138,7 +138,7 @@ impl ChorusUser { let object = Arc::new(RwLock::new(User::default())); let wss_url = instance.read().unwrap().urls.wss.clone(); // Dummy gateway object - let gateway = Gateway::new(wss_url).await.unwrap(); + let gateway = Gateway::get_handle(wss_url).await.unwrap(); ChorusUser { token, belongs_to: instance.clone(),