From a4d5ebb6895212f327101e0d0f07d9c91d082b00 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Sun, 19 Nov 2023 19:12:29 +0100 Subject: [PATCH] Resolve merge conflicts --- examples/gateway_observers.rs | 6 +- examples/gateway_simple.rs | 7 +- src/api/auth/login.rs | 5 +- src/api/auth/mod.rs | 6 +- src/api/auth/register.rs | 3 +- src/gateway/events.rs | 1 + src/gateway/gateway.rs | 11 +- src/gateway/handle.rs | 168 +++++++++ src/gateway/heartbeat.rs | 12 +- src/gateway/mod.rs | 575 +----------------------------- src/instance.rs | 4 +- src/lib.rs | 17 - src/types/entities/channel.rs | 5 +- src/types/entities/emoji.rs | 5 +- src/types/entities/guild.rs | 5 +- src/types/entities/mod.rs | 5 +- src/types/entities/role.rs | 5 +- src/types/entities/user.rs | 5 +- src/types/entities/voice_state.rs | 5 +- src/types/entities/webhook.rs | 5 +- tests/common/mod.rs | 4 +- tests/gateway.rs | 10 +- 22 files changed, 253 insertions(+), 616 deletions(-) create mode 100644 src/gateway/handle.rs diff --git a/examples/gateway_observers.rs b/examples/gateway_observers.rs index 7d49a1f..d4e690c 100644 --- a/examples/gateway_observers.rs +++ b/examples/gateway_observers.rs @@ -1,11 +1,9 @@ use async_trait::async_trait; -use chorus::gateway::{GatewayCapable, GatewayHandleCapable}; -use chorus::GatewayHandle; +use chorus::gateway::Gateway; use chorus::{ self, gateway::Observer, types::{GatewayIdentifyPayload, GatewayReady}, - Gateway, }; use std::{sync::Arc, time::Duration}; use tokio::{self, time::sleep}; @@ -33,7 +31,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection - let gateway: GatewayHandle = Gateway::spawn(websocket_url_spacebar).await.unwrap(); + let gateway = Gateway::spawn(websocket_url_spacebar).await.unwrap(); // Create an instance of our observer let observer = ExampleObserver {}; diff --git a/examples/gateway_simple.rs b/examples/gateway_simple.rs index 56c557a..a9c019b 100644 --- a/examples/gateway_simple.rs +++ b/examples/gateway_simple.rs @@ -1,8 +1,7 @@ use std::time::Duration; -use chorus::gateway::{GatewayCapable, GatewayHandleCapable}; -use chorus::GatewayHandle; -use chorus::{self, types::GatewayIdentifyPayload, Gateway}; +use chorus::gateway::Gateway; +use chorus::{self, types::GatewayIdentifyPayload}; use tokio::time::sleep; /// This example creates a simple gateway connection and a session with an Identify event @@ -12,7 +11,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another - let gateway: GatewayHandle = Gateway::spawn(websocket_url_spacebar).await.unwrap(); + let gateway = Gateway::spawn(websocket_url_spacebar).await.unwrap(); // At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 5eddfa3..1d9fc8a 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -4,11 +4,10 @@ use reqwest::Client; use serde_json::to_string; use crate::errors::ChorusResult; -use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable}; +use crate::gateway::Gateway; use crate::instance::{ChorusUser, Instance}; use crate::ratelimiter::ChorusRequest; use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema}; -use crate::Gateway; impl Instance { /// Logs into an existing account on the spacebar server. @@ -37,7 +36,7 @@ impl Instance { self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); } let mut identify = GatewayIdentifyPayload::common(); - let gateway: DefaultGatewayHandle = Gateway::spawn(self.urls.wss.clone()).await.unwrap(); + let gateway = Gateway::spawn(self.urls.wss.clone()).await.unwrap(); identify.token = login_result.token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index a1490fc..ae3b219 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -3,10 +3,9 @@ use std::sync::{Arc, RwLock}; pub use login::*; pub use register::*; -use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable}; +use crate::gateway::Gateway; use crate::{ errors::ChorusResult, - gateway::DefaultGateway, instance::{ChorusUser, Instance}, types::{GatewayIdentifyPayload, User}, }; @@ -26,8 +25,7 @@ impl Instance { .await .unwrap(); let mut identify = GatewayIdentifyPayload::common(); - let gateway: DefaultGatewayHandle = - DefaultGateway::spawn(self.urls.wss.clone()).await.unwrap(); + let gateway = Gateway::spawn(self.urls.wss.clone()).await.unwrap(); identify.token = token.clone(); gateway.send_identify(identify).await; let user = ChorusUser::new( diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index 2ea6e7d..2ea7d57 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use reqwest::Client; use serde_json::to_string; -use crate::gateway::{GatewayCapable, GatewayHandleCapable}; +use crate::gateway::{Gateway, GatewayHandle}; use crate::types::GatewayIdentifyPayload; use crate::{ errors::ChorusResult, @@ -12,7 +12,6 @@ use crate::{ types::LimitType, types::RegisterSchema, }; -use crate::{Gateway, GatewayHandle}; impl Instance { /// Registers a new user on the server. diff --git a/src/gateway/events.rs b/src/gateway/events.rs index bc8565e..fdb7b25 100644 --- a/src/gateway/events.rs +++ b/src/gateway/events.rs @@ -1,4 +1,5 @@ use super::*; +use crate::types; #[derive(Default, Debug)] pub struct Events { diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 38efe7c..9e3410c 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -1,7 +1,12 @@ +use std::time::Duration; + +use futures_util::{SinkExt, StreamExt}; +use log::*; +use tokio::task; + use self::event::Events; -use super::handle::GatewayHandle; -use super::heartbeat::HeartbeatHandler; use super::*; +use super::{WsSink, WsStream}; use crate::types::{ self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, ChannelUpdate, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, SourceUrlField, @@ -21,7 +26,7 @@ pub struct Gateway { impl Gateway { #[allow(clippy::new_ret_no_self)] - pub async fn new(websocket_url: String) -> Result { + pub async fn spawn(websocket_url: String) -> Result { let (websocket_send, mut websocket_receive) = WebSocketBackend::connect(&websocket_url).await?; diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs new file mode 100644 index 0000000..9a3c509 --- /dev/null +++ b/src/gateway/handle.rs @@ -0,0 +1,168 @@ +use futures_util::SinkExt; +use log::*; + +use std::fmt::Debug; + +use super::{event::Events, *}; +use crate::types::{self, Composite}; + +/// Represents a handle to a Gateway connection. A Gateway connection will create observable +/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently +/// implemented types with the trait [`WebSocketEvent`] +/// Using this handle you can also send Gateway Events directly. +#[derive(Debug, Clone)] +pub struct GatewayHandle { + pub url: String, + pub events: Arc>, + pub websocket_send: Arc>, + /// Tells gateway tasks to close + pub(super) kill_send: tokio::sync::broadcast::Sender<()>, + pub(crate) store: Arc>>>>, +} + +impl GatewayHandle { + /// Sends json to the gateway with an opcode + async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { + let gateway_payload = types::GatewaySendPayload { + op_code, + event_data: Some(to_send), + sequence_number: None, + }; + + let payload_json = serde_json::to_string(&gateway_payload).unwrap(); + let message = GatewayMessage(payload_json); + + self.websocket_send + .lock() + .await + .send(message.into()) + .await + .unwrap(); + } + + pub async fn observe>( + &self, + object: Arc>, + ) -> Arc> { + let mut store = self.store.lock().await; + let id = object.read().unwrap().id(); + if let Some(channel) = store.get(&id) { + let object = channel.clone(); + drop(store); + object + .read() + .unwrap() + .downcast_ref::() + .unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + id + ) + }); + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; + let object = downcasted.read().unwrap().clone(); + + let watched_object = object.watch_whole(self).await; + *downcasted.write().unwrap() = watched_object; + downcasted + } else { + let id = object.read().unwrap().id(); + let object = object.read().unwrap().clone(); + let object = object.clone().watch_whole(self).await; + let wrapped = Arc::new(RwLock::new(object)); + store.insert(id, wrapped.clone()); + wrapped + } + } + + /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` + /// with all of its observable fields being observed. + pub async fn observe_and_into_inner>( + &self, + object: Arc>, + ) -> T { + let channel = self.observe(object.clone()).await; + let object = channel.read().unwrap().clone(); + object + } + + /// Sends an identify event to the gateway + pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Identify.."); + + self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; + } + + /// Sends a resume event to the gateway + pub async fn send_resume(&self, to_send: types::GatewayResume) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Resume.."); + + self.send_json_event(GATEWAY_RESUME, to_send_value).await; + } + + /// Sends an update presence event to the gateway + pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Update Presence.."); + + self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) + .await; + } + + /// Sends a request guild members to the server + pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Request Guild Members.."); + + self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) + .await; + } + + /// Sends an update voice state to the server + pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { + let to_send_value = serde_json::to_value(to_send).unwrap(); + + trace!("GW: Sending Update Voice State.."); + + self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) + .await; + } + + /// Sends a call sync to the server + pub async fn send_call_sync(&self, to_send: types::CallSync) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Call Sync.."); + + self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; + } + + /// Sends a Lazy Request + pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Lazy Request.."); + + self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) + .await; + } + + /// Closes the websocket connection and stops all gateway tasks; + /// + /// Esentially pulls the plug on the gateway, leaving it possible to resume; + pub async fn close(&self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } +} diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs index aeb3dae..a5875a4 100644 --- a/src/gateway/heartbeat.rs +++ b/src/gateway/heartbeat.rs @@ -1,3 +1,11 @@ +use futures_util::SinkExt; +use log::*; +use std::time::{self, Duration, Instant}; +use tokio::sync::mpsc::{Receiver, Sender}; + +use safina_timer::sleep_until; +use tokio::task::{self, JoinHandle}; + use super::*; use crate::types; @@ -43,13 +51,15 @@ impl HeartbeatHandler { pub async fn heartbeat_task( websocket_tx: Arc>, heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, + mut receive: Receiver, mut kill_receive: tokio::sync::broadcast::Receiver<()>, ) { let mut last_heartbeat_timestamp: Instant = time::Instant::now(); let mut last_heartbeat_acknowledged = true; let mut last_seq_number: Option = None; + safina_timer::start_timer_thread(); + loop { if kill_receive.try_recv().is_ok() { trace!("GW: Closing heartbeat task"); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 9ab4f9d..269a274 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,38 +1,32 @@ +use async_trait::async_trait; + +pub mod backend_tungstenite; pub mod events; +pub mod gateway; +pub mod handle; +pub mod heartbeat; pub mod message; -#[cfg(not(wasm))] -pub mod backend_tungstenite; -#[cfg(not(wasm))] -use backend_tungstenite::*; - +pub use gateway::*; +pub use handle::*; +use heartbeat::*; pub use message::*; -use safina_timer::sleep_until; -use tokio::task::{self, JoinHandle}; -use self::events::Events; use crate::errors::GatewayError; -use crate::types::{ - self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, - ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, - Snowflake, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent, -}; +use crate::types::{Snowflake, WebSocketEvent}; use std::any::Any; use std::collections::HashMap; -use std::marker::PhantomData; use std::sync::{Arc, RwLock}; -use std::time::{self, Duration, Instant}; -use futures_util::SinkExt; -use log::{info, trace, warn}; -use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; -pub type GatewayStore = Arc>>>>; - -/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms -const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; +#[cfg(not(target_arch = "wasm32"))] +pub type WsSink = backend_tungstenite::WsSink; +#[cfg(not(target_arch = "wasm32"))] +pub type WsStream = backend_tungstenite::WsStream; +#[cfg(not(target_arch = "wasm32"))] +pub type WebSocketBackend = backend_tungstenite::WebSocketBackend; // Gateway opcodes /// Opcode received when the server dispatches a [crate::types::WebSocketEvent] @@ -82,25 +76,8 @@ const GATEWAY_CALL_SYNC: u8 = 13; /// See [types::LazyRequest] const GATEWAY_LAZY_REQUEST: u8 = 14; -pub trait MessageCapable { - fn as_string(&self) -> Option; - fn as_bytes(&self) -> Option>; - fn is_empty(&self) -> bool; - fn from_str(s: &str) -> Self; -} - pub type ObservableObject = dyn Send + Sync + Any; -/// Used for communications between the heartbeat and gateway thread. -/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server -#[derive(Clone, Copy, Debug)] -pub struct HeartbeatThreadCommunication { - /// The opcode for the communication we received, if relevant - pub op_code: Option, - /// The sequence number we got from discord, if any - pub sequence_number: Option, -} - /// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. pub trait Updateable: 'static + Send + Sync { fn id(&self) -> Snowflake; @@ -151,523 +128,3 @@ impl GatewayEvent { } } } - -#[async_trait] -pub trait GatewayCapable -where - T: MessageCapable + Send + 'static, - S: Sink + Send, -{ - fn get_events(&self) -> Arc>; - fn get_websocket_send(&self) -> Arc>>; - fn get_store(&self) -> GatewayStore; - fn get_url(&self) -> String; - fn get_heartbeat_handler(&self) -> &HeartbeatHandler; - /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] - /// - /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation - /// TODO: Give spawn a default trait impl, avoid code duplication - async fn spawn>(websocket_url: String) - -> Result; - async fn close(&mut self); - /// This handles a message as a websocket event and updates its events along with the events' observers - async fn handle_message(&mut self, msg: GatewayMessage) { - if msg.is_empty() { - return; - } - - if !msg.is_error() && !msg.is_payload() { - warn!( - "Message unrecognised: {:?}, please open an issue on the chorus github", - msg.message.to_string() - ); - return; - } - - if msg.is_error() { - let error = msg.error().unwrap(); - - warn!("GW: Received error {:?}, connection will close..", error); - - self.close().await; - - let events = self.get_events(); - let events = events.lock().await; - - events.error.notify(error).await; - - return; - } - - let gateway_payload = msg.payload().unwrap(); - println!("gateway payload: {:#?}", &gateway_payload); - - // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes - match gateway_payload.op_code { - // An event was dispatched, we need to look at the gateway event name t - GATEWAY_DISPATCH => { - let Some(event_name) = gateway_payload.event_name else { - warn!("Gateway dispatch op without event_name"); - return; - }; - - trace!("Gateway: Received {event_name}"); - - macro_rules! handle { - ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { - match event_name.as_str() { - $($name => { - let events = self.get_events(); - let event = &mut events.lock().await.$($path).+; - let json = gateway_payload.event_data.unwrap().get(); - match serde_json::from_str(json) { - Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), - Ok(message) => { - $( - let mut message: $message_type = message; - let store = self.get_store(); - let store = store.lock().await; - let id = if message.id().is_some() { - message.id().unwrap() - } else { - event.notify(message).await; - return; - }; - if let Some(to_update) = store.get(&id) { - let object = to_update.clone(); - let inner_object = object.read().unwrap(); - if let Some(_) = inner_object.downcast_ref::<$update_type>() { - let ptr = Arc::into_raw(object.clone()); - // SAFETY: - // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. - // - This operation doesn't read or write any shared data, and thus cannot cause a data race - // - The reference count is not being modified - let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; - drop(inner_object); - message.set_json(json.to_string()); - message.set_source_url(self.get_url().clone()); - message.update(downcasted.clone()); - } else { - warn!("Received {} for {}, but it has been observed to be a different type!", $name, id) - } - } - )? - event.notify(message).await; - } - } - },)* - "RESUMED" => (), - "SESSIONS_REPLACE" => { - let result: Result, serde_json::Error> = - serde_json::from_str(gateway_payload.event_data.unwrap().get()); - match result { - Err(err) => { - warn!( - "Failed to parse gateway event {} ({})", - event_name, - err - ); - return; - } - Ok(sessions) => { - let events = self.get_events(); - let events = events.lock().await; - events.session.replace.notify( - types::SessionsReplace {sessions} - ).await; - } - } - }, - _ => { - warn!("Received unrecognized gateway event ({event_name})! Please open an issue on the chorus github so we can implement it"); - } - } - }; - } - - // See https://discord.com/developers/docs/topics/gateway-events#receive-events - // "Some" of these are undocumented - handle!( - "READY" => session.ready, - "READY_SUPPLEMENTAL" => session.ready_supplemental, - "APPLICATION_COMMAND_PERMISSIONS_UPDATE" => application.command_permissions_update, - "AUTO_MODERATION_RULE_CREATE" =>auto_moderation.rule_create, - "AUTO_MODERATION_RULE_UPDATE" =>auto_moderation.rule_update AutoModerationRuleUpdate: AutoModerationRule, - "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, - "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, - "CHANNEL_CREATE" => channel.create ChannelCreate: Guild, - "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, - "CHANNEL_UNREAD_UPDATE" => channel.unread_update, - "CHANNEL_DELETE" => channel.delete ChannelDelete: Guild, - "CHANNEL_PINS_UPDATE" => channel.pins_update, - "CALL_CREATE" => call.create, - "CALL_UPDATE" => call.update, - "CALL_DELETE" => call.delete, - "THREAD_CREATE" => thread.create, // TODO - "THREAD_UPDATE" => thread.update ThreadUpdate: Channel, - "THREAD_DELETE" => thread.delete, // TODO - "THREAD_LIST_SYNC" => thread.list_sync, // TODO - "THREAD_MEMBER_UPDATE" => thread.member_update, // TODO - "THREAD_MEMBERS_UPDATE" => thread.members_update, // TODO - "GUILD_CREATE" => guild.create, // TODO - "GUILD_UPDATE" => guild.update, // TODO - "GUILD_DELETE" => guild.delete, // TODO - "GUILD_AUDIT_LOG_ENTRY_CREATE" => guild.audit_log_entry_create, - "GUILD_BAN_ADD" => guild.ban_add, // TODO - "GUILD_BAN_REMOVE" => guild.ban_remove, // TODO - "GUILD_EMOJIS_UPDATE" => guild.emojis_update, // TODO - "GUILD_STICKERS_UPDATE" => guild.stickers_update, // TODO - "GUILD_INTEGRATIONS_UPDATE" => guild.integrations_update, - "GUILD_MEMBER_ADD" => guild.member_add, - "GUILD_MEMBER_REMOVE" => guild.member_remove, - "GUILD_MEMBER_UPDATE" => guild.member_update, // TODO - "GUILD_MEMBERS_CHUNK" => guild.members_chunk, // TODO - "GUILD_ROLE_CREATE" => guild.role_create GuildRoleCreate: Guild, - "GUILD_ROLE_UPDATE" => guild.role_update GuildRoleUpdate: RoleObject, - "GUILD_ROLE_DELETE" => guild.role_delete, // TODO - "GUILD_SCHEDULED_EVENT_CREATE" => guild.role_scheduled_event_create, // TODO - "GUILD_SCHEDULED_EVENT_UPDATE" => guild.role_scheduled_event_update, // TODO - "GUILD_SCHEDULED_EVENT_DELETE" => guild.role_scheduled_event_delete, // TODO - "GUILD_SCHEDULED_EVENT_USER_ADD" => guild.role_scheduled_event_user_add, - "GUILD_SCHEDULED_EVENT_USER_REMOVE" => guild.role_scheduled_event_user_remove, - "PASSIVE_UPDATE_V1" => guild.passive_update_v1, // TODO - "INTEGRATION_CREATE" => integration.create, // TODO - "INTEGRATION_UPDATE" => integration.update, // TODO - "INTEGRATION_DELETE" => integration.delete, // TODO - "INTERACTION_CREATE" => interaction.create, // TODO - "INVITE_CREATE" => invite.create, // TODO - "INVITE_DELETE" => invite.delete, // TODO - "MESSAGE_CREATE" => message.create, - "MESSAGE_UPDATE" => message.update, // TODO - "MESSAGE_DELETE" => message.delete, - "MESSAGE_DELETE_BULK" => message.delete_bulk, - "MESSAGE_REACTION_ADD" => message.reaction_add, // TODO - "MESSAGE_REACTION_REMOVE" => message.reaction_remove, // TODO - "MESSAGE_REACTION_REMOVE_ALL" => message.reaction_remove_all, // TODO - "MESSAGE_REACTION_REMOVE_EMOJI" => message.reaction_remove_emoji, // TODO - "MESSAGE_ACK" => message.ack, - "PRESENCE_UPDATE" => user.presence_update, // TODO - "RELATIONSHIP_ADD" => relationship.add, - "RELATIONSHIP_REMOVE" => relationship.remove, - "STAGE_INSTANCE_CREATE" => stage_instance.create, - "STAGE_INSTANCE_UPDATE" => stage_instance.update, // TODO - "STAGE_INSTANCE_DELETE" => stage_instance.delete, - "TYPING_START" => user.typing_start, - "USER_UPDATE" => user.update, // TODO - "USER_GUILD_SETTINGS_UPDATE" => user.guild_settings_update, - "VOICE_STATE_UPDATE" => voice.state_update, // TODO - "VOICE_SERVER_UPDATE" => voice.server_update, - "WEBHOOKS_UPDATE" => webhooks.update - ); - } - // We received a heartbeat from the server - // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." - GATEWAY_HEARTBEAT => { - trace!("GW: Received Heartbeat // Heartbeat Request"); - - // Tell the heartbeat handler it should send a heartbeat right away - - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT), - }; - - let heartbeat_thread_communicator = &self.get_heartbeat_handler().send; - - heartbeat_thread_communicator - .send(heartbeat_communication) - .await - .unwrap(); - } - GATEWAY_RECONNECT => { - todo!() - } - GATEWAY_INVALID_SESSION => { - todo!() - } - // Starts our heartbeat - // We should have already handled this in gateway init - GATEWAY_HELLO => { - warn!("Received hello when it was unexpected"); - } - GATEWAY_HEARTBEAT_ACK => { - trace!("GW: Received Heartbeat ACK"); - - // Tell the heartbeat handler we received an ack - - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT_ACK), - }; - - let heartbeat_handler = self.get_heartbeat_handler(); - let heartbeat_thread_communicator = &heartbeat_handler.send; - - heartbeat_thread_communicator - .send(heartbeat_communication) - .await - .unwrap(); - } - GATEWAY_IDENTIFY - | GATEWAY_UPDATE_PRESENCE - | GATEWAY_UPDATE_VOICE_STATE - | GATEWAY_RESUME - | GATEWAY_REQUEST_GUILD_MEMBERS - | GATEWAY_CALL_SYNC - | GATEWAY_LAZY_REQUEST => { - info!( - "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", - gateway_payload.op_code - ); - } - _ => { - warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); - } - } - - // If we we received a seq number we should let it know - if let Some(seq_num) = gateway_payload.sequence_number { - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: Some(seq_num), - // Op code is irrelevant here - op_code: None, - }; - - let heartbeat_handler = self.get_heartbeat_handler(); - let heartbeat_thread_communicator = &heartbeat_handler.send; - heartbeat_thread_communicator - .send(heartbeat_communication) - .await - .unwrap(); - } - } -} - -#[async_trait(?Send)] -pub trait GatewayHandleCapable -where - T: MessageCapable + Send + 'static, - S: Sink, -{ - fn new( - url: String, - events: Arc>, - websocket_send: Arc>>, - kill_send: tokio::sync::broadcast::Sender<()>, - store: GatewayStore, - ) -> Self; - - /// Sends json to the gateway with an opcode - async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value); - - /// Observes an Item ``, which will update itself, if new information about this - /// item arrives on the corresponding Gateway Thread - async fn observe + Send + Sync>( - &self, - object: Arc>, - ) -> Arc>; - - /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` - /// with all of its observable fields being observed. - async fn observe_and_into_inner>( - &self, - object: Arc>, - ) -> U { - let channel = self.observe(object.clone()).await; - let object = channel.read().unwrap().clone(); - object - } - - /// Sends an identify event to the gateway - async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Identify.."); - - self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; - } - - /// Sends an update presence event to the gateway - async fn send_update_presence(&self, to_send: types::UpdatePresence) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Update Presence.."); - - self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) - .await; - } - - /// Sends a resume event to the gateway - async fn send_resume(&self, to_send: types::GatewayResume) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Resume.."); - - self.send_json_event(GATEWAY_RESUME, to_send_value).await; - } - - /// Sends a request guild members to the server - async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Request Guild Members.."); - - self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) - .await; - } - - /// Sends an update voice state to the server - async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { - let to_send_value = serde_json::to_value(to_send).unwrap(); - - trace!("GW: Sending Update Voice State.."); - - self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) - .await; - } - - /// Sends a call sync to the server - async fn send_call_sync(&self, to_send: types::CallSync) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Call Sync.."); - - self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; - } - - /// Sends a Lazy Request - async fn send_lazy_request(&self, to_send: types::LazyRequest) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Lazy Request.."); - - self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) - .await; - } - - /// Closes the websocket connection and stops all gateway tasks; - /// - /// Esentially pulls the plug on the gateway, leaving it possible to resume; - async fn close(&self); -} - -/// Handles sending heartbeats to the gateway in another thread -#[derive(Debug)] -pub struct HeartbeatHandler> { - /// How ofter heartbeats need to be sent at a minimum - pub heartbeat_interval: Duration, - /// The send channel for the heartbeat thread - pub send: Sender, - /// The handle of the thread - handle: JoinHandle<()>, - hb_type: (PhantomData, PhantomData), -} - -impl + Send + 'static> HeartbeatHandler { - pub async fn heartbeat_task( - websocket_tx: Arc>>, - heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, - mut kill_receive: tokio::sync::broadcast::Receiver<()>, - ) { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - safina_timer::start_timer_thread(); - - loop { - if kill_receive.try_recv().is_ok() { - trace!("GW: Closing heartbeat task"); - break; - } - - let timeout = if last_heartbeat_acknowledged { - heartbeat_interval - } else { - // If the server hasn't acknowledged our heartbeat we should resend it - Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) - }; - - let mut should_send = false; - - tokio::select! { - () = sleep_until(last_heartbeat_timestamp + timeout) => { - should_send = true; - } - Some(communication) = receive.recv() => { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; - } - - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - } - - if should_send { - trace!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx - .lock() - .await - .send(MessageCapable::from_str(msg.to_string().as_str())) - .await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - warn!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } - } - - fn new( - heartbeat_interval: Duration, - websocket_tx: Arc>>, - kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> HeartbeatHandler { - let (send, receive) = tokio::sync::mpsc::channel(32); - let kill_receive = kill_rc.resubscribe(); - - let handle: JoinHandle<()> = task::spawn(async move { - HeartbeatHandler::heartbeat_task( - websocket_tx, - heartbeat_interval, - receive, - kill_receive, - ) - .await; - }); - - Self { - heartbeat_interval, - send, - handle, - hb_type: (PhantomData::, PhantomData::), - } - } -} diff --git a/src/instance.rs b/src/instance.rs index 8c462a6..4ce4338 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -9,11 +9,11 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use crate::errors::ChorusResult; -use crate::gateway::GatewayCapable; +use crate::gateway::{Gateway, GatewayHandle}; use crate::ratelimiter::ChorusRequest; use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::{GeneralConfiguration, Limit, LimitType, User, UserSettings}; -use crate::{Gateway, GatewayHandle, UrlBundle}; +use crate::UrlBundle; #[derive(Debug, Clone, Default)] /// The [`Instance`]; what you will be using to perform all sorts of actions on the Spacebar server. diff --git a/src/lib.rs b/src/lib.rs index 510a727..47bbaab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,23 +17,6 @@ #[cfg(all(feature = "rt", feature = "rt_multi_thread"))] compile_error!("feature \"rt\" and feature \"rt_multi_thread\" cannot be enabled at the same time"); -#[cfg(all(not(target_arch = "wasm32"), feature = "client"))] -pub type Gateway = DefaultGateway; -#[cfg(all(target_arch = "wasm32", feature = "client"))] -pub type Gateway = WasmGateway; -#[cfg(all(not(target_arch = "wasm32"), feature = "client"))] -pub type GatewayHandle = DefaultGatewayHandle; -#[cfg(all(target_arch = "wasm32", feature = "client"))] -pub type GatewayHandle = WasmGatewayHandle; - -#[cfg(all(not(target_arch = "wasm32"), feature = "client"))] -use gateway::DefaultGateway; -#[cfg(all(not(target_arch = "wasm32"), feature = "client"))] -use gateway::DefaultGatewayHandle; -#[cfg(all(target_arch = "wasm32", feature = "client"))] -use gateway::WasmGateway; -#[cfg(all(target_arch = "wasm32", feature = "client"))] -use gateway::WasmGatewayHandle; use url::{ParseError, Url}; #[cfg(feature = "client")] diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 1d1c58c..5acf2ae 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -12,7 +12,10 @@ use crate::types::{ }; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::types::Composite; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; #[cfg(feature = "client")] use crate::gateway::Updateable; diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 8c0b8e6..b3916e2 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -7,7 +7,10 @@ use crate::types::entities::User; use crate::types::Snowflake; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::gateway::GatewayHandle; + +#[cfg(feature = "client")] +use crate::types::Composite; #[cfg(feature = "client")] use crate::gateway::Updateable; diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index 8638ff7..eb04322 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -16,7 +16,7 @@ use crate::types::{ use super::PublicUser; #[cfg(feature = "client")] -use crate::{gateway::Updateable, GatewayHandle}; +use crate::gateway::Updateable; #[cfg(feature = "client")] use chorus_macros::{observe_option_vec, observe_vec, Composite, Updateable}; @@ -24,6 +24,9 @@ use chorus_macros::{observe_option_vec, observe_vec, Composite, Updateable}; #[cfg(feature = "client")] use crate::types::Composite; +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; + /// See #[derive(Serialize, Deserialize, Debug, Default, Clone)] #[cfg_attr(feature = "client", derive(Updateable, Composite))] diff --git a/src/types/entities/mod.rs b/src/types/entities/mod.rs index 3371598..8343628 100644 --- a/src/types/entities/mod.rs +++ b/src/types/entities/mod.rs @@ -24,7 +24,10 @@ pub use voice_state::*; pub use webhook::*; #[cfg(feature = "client")] -use crate::{gateway::Updateable, GatewayHandle}; +use crate::gateway::Updateable; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; #[cfg(feature = "client")] use async_trait::async_trait; diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 1166431..6a8327e 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -12,7 +12,10 @@ use chorus_macros::{Composite, Updateable}; use crate::gateway::Updateable; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::types::Composite; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; #[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)] #[cfg_attr(feature = "client", derive(Updateable, Composite))] diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index e247240..a7bdf63 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -8,7 +8,10 @@ use std::fmt::Debug; use crate::gateway::Updateable; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::types::Composite; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; #[cfg(feature = "client")] use chorus_macros::{Composite, Updateable}; diff --git a/src/types/entities/voice_state.rs b/src/types/entities/voice_state.rs index 3f122d6..1a0c3b3 100644 --- a/src/types/entities/voice_state.rs +++ b/src/types/entities/voice_state.rs @@ -4,7 +4,10 @@ use std::sync::{Arc, RwLock}; use chorus_macros::Composite; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::types::Composite; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; #[cfg(feature = "client")] use crate::gateway::Updateable; diff --git a/src/types/entities/webhook.rs b/src/types/entities/webhook.rs index 3d8c687..cf5716c 100644 --- a/src/types/entities/webhook.rs +++ b/src/types/entities/webhook.rs @@ -10,7 +10,10 @@ use crate::gateway::Updateable; use chorus_macros::{Composite, Updateable}; #[cfg(feature = "client")] -use crate::{types::Composite, GatewayHandle}; +use crate::types::Composite; + +#[cfg(feature = "client")] +use crate::gateway::GatewayHandle; use crate::types::{ entities::{Guild, User}, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 76f20c7..b533fd2 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, RwLock}; -use chorus::gateway::{DefaultGateway, GatewayCapable}; +use chorus::gateway::Gateway; use chorus::{ instance::{ChorusUser, Instance}, types::{ @@ -43,7 +43,7 @@ impl TestBundle { limits: self.user.limits.clone(), settings: self.user.settings.clone(), object: self.user.object.clone(), - gateway: DefaultGateway::spawn(self.instance.urls.wss.clone()) + gateway: Gateway::spawn(self.instance.urls.wss.clone()) .await .unwrap(), } diff --git a/tests/gateway.rs b/tests/gateway.rs index 2303707..0b1e12f 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -2,17 +2,15 @@ mod common; use std::sync::{Arc, RwLock}; +use chorus::gateway::*; use chorus::types::{self, ChannelModifySchema, RoleCreateModifySchema, RoleObject}; -use chorus::{gateway::*, GatewayHandle}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; async fn test_gateway_establish() { let bundle = common::setup().await; - let _: GatewayHandle = DefaultGateway::spawn(bundle.urls.wss.clone()) - .await - .unwrap(); + let _: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone()).await.unwrap(); common::teardown(bundle).await } @@ -21,9 +19,7 @@ async fn test_gateway_establish() { async fn test_gateway_authenticate() { let bundle = common::setup().await; - let gateway: GatewayHandle = DefaultGateway::spawn(bundle.urls.wss.clone()) - .await - .unwrap(); + let gateway: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone()).await.unwrap(); let mut identify = types::GatewayIdentifyPayload::common(); identify.token = bundle.user.token.clone();