diff --git a/examples/gateway_observers.rs b/examples/gateway_observers.rs index 0a54d31..562a366 100644 --- a/examples/gateway_observers.rs +++ b/examples/gateway_observers.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use chorus::gateway::GatewayCapable; use chorus::{ self, gateway::{Gateway, Observer}, @@ -30,7 +31,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection - let gateway = Gateway::new(websocket_url_spacebar).await.unwrap(); + let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); // Create an instance of our observer let observer = ExampleObserver {}; diff --git a/examples/gateway_simple.rs b/examples/gateway_simple.rs index 276d95f..4549a2c 100644 --- a/examples/gateway_simple.rs +++ b/examples/gateway_simple.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use chorus::gateway::GatewayCapable; use chorus::{self, gateway::Gateway, types::GatewayIdentifyPayload}; use tokio::time::sleep; @@ -10,7 +11,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another - let gateway = Gateway::new(websocket_url_spacebar).await.unwrap(); + let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); // At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 15215ac..0c6b095 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -8,124 +8,25 @@ use crate::types::{ pub type GatewayStore = Arc>>>>; -#[derive(Debug)] -pub struct Gateway { - events: Arc>, - heartbeat_handler: HeartbeatHandler, - websocket_send: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - websocket_receive: SplitStream>>, - kill_send: tokio::sync::broadcast::Sender<()>, - store: GatewayStore, - url: String, -} - +#[allow(clippy::type_complexity)] #[async_trait] -impl - GatewayCapable< - WebSocketStream>, - WebSocketStream>, - GatewayHandle, - > for Gateway +pub trait GatewayCapable +where + R: Stream, + S: Sink, + G: GatewayHandleCapable, + H: HeartbeatHandlerCapable + Send + Sync, { - #[allow(clippy::new_ret_no_self)] - async fn get_handle(websocket_url: String) -> Result { - let mut roots = rustls::RootCertStore::empty(); - for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - roots.add(&rustls::Certificate(cert.0)).unwrap(); - } - let (websocket_stream, _) = match connect_async_tls_with_config( - &websocket_url, - None, - false, - Some(Connector::Rustls( - rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots) - .with_no_client_auth() - .into(), - )), - ) - .await - { - Ok(websocket_stream) => websocket_stream, - Err(e) => { - return Err(GatewayError::CannotConnect { - error: e.to_string(), - }) - } - }; - - let (websocket_send, mut websocket_receive) = websocket_stream.split(); - - let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); - - // Create a shared broadcast channel for killing all gateway tasks - let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16); - - // Wait for the first hello and then spawn both tasks so we avoid nested tasks - // This automatically spawns the heartbeat task, but from the main thread - let msg = websocket_receive.next().await.unwrap().unwrap(); - let gateway_payload: types::GatewayReceivePayload = - serde_json::from_str(msg.to_text().unwrap()).unwrap(); - - if gateway_payload.op_code != GATEWAY_HELLO { - return Err(GatewayError::NonHelloOnInitiate { - opcode: gateway_payload.op_code, - }); - } - - info!("GW: Received Hello"); - - let gateway_hello: types::HelloData = - serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); - - let events = Events::default(); - let shared_events = Arc::new(Mutex::new(events)); - - let store = Arc::new(Mutex::new(HashMap::new())); - - let mut gateway = Gateway { - events: shared_events.clone(), - heartbeat_handler: HeartbeatHandler::new( - Duration::from_millis(gateway_hello.heartbeat_interval), - shared_websocket_send.clone(), - kill_send.subscribe(), - ), - websocket_send: shared_websocket_send.clone(), - websocket_receive, - kill_send: kill_send.clone(), - store: store.clone(), - url: websocket_url.clone(), - }; - - // Now we can continuously check for messages in a different task, since we aren't going to receive another hello - task::spawn(async move { - gateway.gateway_listen_task().await; - }); - - Ok(GatewayHandle { - url: websocket_url.clone(), - events: shared_events, - websocket_send: shared_websocket_send.clone(), - kill_send: kill_send.clone(), - store, - }) - } - - /// Closes the websocket connection and stops all tasks - async fn close(&mut self) { - self.kill_send.send(()).unwrap(); - self.websocket_send.lock().await.close().await.unwrap(); - } - + fn get_events(&self) -> Arc>; + fn get_websocket_send(&self) -> Arc>>; + fn get_store(&self) -> GatewayStore; + fn get_url(&self) -> String; + fn get_heartbeat_handler(&self) -> &H; + /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] + /// + /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation + async fn get_handle(websocket_url: String) -> Result; + async fn close(&mut self); /// This handles a message as a websocket event and updates its events along with the events' observers async fn handle_message(&mut self, msg: GatewayMessage) { if msg.is_empty() { @@ -147,12 +48,16 @@ impl self.close().await; - self.events.lock().await.error.notify(error).await; + let events = self.get_events(); + let events = events.lock().await; + + events.error.notify(error).await; return; } let gateway_payload = msg.payload().unwrap(); + println!("gateway payload: {:#?}", &gateway_payload); // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes match gateway_payload.op_code { @@ -169,14 +74,16 @@ impl ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { match event_name.as_str() { $($name => { - let event = &mut self.events.lock().await.$($path).+; + let events = self.get_events(); + let event = &mut events.lock().await.$($path).+; let json = gateway_payload.event_data.unwrap().get(); match serde_json::from_str(json) { Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), Ok(message) => { $( let mut message: $message_type = message; - let store = self.store.lock().await; + let store = self.get_store(); + let store = store.lock().await; let id = if message.id().is_some() { message.id().unwrap() } else { @@ -195,7 +102,7 @@ impl let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; drop(inner_object); message.set_json(json.to_string()); - message.set_source_url(self.url.clone()); + message.set_source_url(self.get_url().clone()); message.update(downcasted.clone()); } else { warn!("Received {} for {}, but it has been observed to be a different type!", $name, id) @@ -220,7 +127,9 @@ impl return; } Ok(sessions) => { - self.events.lock().await.session.replace.notify( + let events = self.get_events(); + let events = events.lock().await; + events.session.replace.notify( types::SessionsReplace {sessions} ).await; } @@ -320,8 +229,9 @@ impl op_code: Some(GATEWAY_HEARTBEAT), }; - self.heartbeat_handler - .send + let heartbeat_thread_communicator = self.get_heartbeat_handler().get_send(); + + heartbeat_thread_communicator .send(heartbeat_communication) .await .unwrap(); @@ -347,8 +257,10 @@ impl op_code: Some(GATEWAY_HEARTBEAT_ACK), }; - self.heartbeat_handler - .send + let heartbeat_handler = self.get_heartbeat_handler(); + let heartbeat_thread_communicator = heartbeat_handler.get_send(); + + heartbeat_thread_communicator .send(heartbeat_communication) .await .unwrap(); @@ -378,13 +290,138 @@ impl op_code: None, }; - self.heartbeat_handler - .send + let heartbeat_handler = self.get_heartbeat_handler(); + let heartbeat_thread_communicator = heartbeat_handler.get_send(); + heartbeat_thread_communicator .send(heartbeat_communication) .await .unwrap(); } } +} + +#[derive(Debug)] +pub struct Gateway { + events: Arc>, + heartbeat_handler: HeartbeatHandler, + websocket_send: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + websocket_receive: SplitStream>>, + kill_send: tokio::sync::broadcast::Sender<()>, + store: GatewayStore, + url: String, +} + +#[async_trait] +impl + GatewayCapable< + WebSocketStream>, + WebSocketStream>, + GatewayHandle, + HeartbeatHandler, + > for Gateway +{ + fn get_heartbeat_handler(&self) -> &HeartbeatHandler { + &self.heartbeat_handler + } + + #[allow(clippy::new_ret_no_self)] + async fn get_handle(websocket_url: String) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + roots.add(&rustls::Certificate(cert.0)).unwrap(); + } + let (websocket_stream, _) = match connect_async_tls_with_config( + &websocket_url, + None, + false, + Some(Connector::Rustls( + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth() + .into(), + )), + ) + .await + { + Ok(websocket_stream) => websocket_stream, + Err(e) => { + return Err(GatewayError::CannotConnect { + error: e.to_string(), + }) + } + }; + + let (websocket_send, mut websocket_receive) = websocket_stream.split(); + + let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); + + // Create a shared broadcast channel for killing all gateway tasks + let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16); + + // Wait for the first hello and then spawn both tasks so we avoid nested tasks + // This automatically spawns the heartbeat task, but from the main thread + let msg = websocket_receive.next().await.unwrap().unwrap(); + let gateway_payload: types::GatewayReceivePayload = + serde_json::from_str(msg.to_text().unwrap()).unwrap(); + + if gateway_payload.op_code != GATEWAY_HELLO { + return Err(GatewayError::NonHelloOnInitiate { + opcode: gateway_payload.op_code, + }); + } + + info!("GW: Received Hello"); + + let gateway_hello: types::HelloData = + serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); + + let events = Events::default(); + let shared_events = Arc::new(Mutex::new(events)); + + let store = Arc::new(Mutex::new(HashMap::new())); + + let mut gateway = Gateway { + events: shared_events.clone(), + heartbeat_handler: HeartbeatHandler::new( + Duration::from_millis(gateway_hello.heartbeat_interval), + shared_websocket_send.clone(), + kill_send.subscribe(), + ), + websocket_send: shared_websocket_send.clone(), + websocket_receive, + kill_send: kill_send.clone(), + store: store.clone(), + url: websocket_url.clone(), + }; + + // Now we can continuously check for messages in a different task, since we aren't going to receive another hello + task::spawn(async move { + gateway.gateway_listen_task().await; + }); + + Ok(GatewayHandle { + url: websocket_url.clone(), + events: shared_events, + websocket_send: shared_websocket_send.clone(), + kill_send: kill_send.clone(), + store, + }) + } + + /// Closes the websocket connection and stops all tasks + async fn close(&mut self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } fn get_events(&self) -> Arc> { self.events.clone() @@ -415,7 +452,9 @@ impl Gateway { // This if chain can be much better but if let is unstable on stable rust if let Some(Ok(message)) = msg { - self.handle_message(GatewayMessage::from_tungstenite_message(message)); + let _ = self + .handle_message(GatewayMessage::from_tungstenite_message(message)) + .await; continue; } diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs index 3a87c5a..8a18a5d 100644 --- a/src/gateway/handle.rs +++ b/src/gateway/handle.rs @@ -1,6 +1,13 @@ use super::{event::Events, *}; use crate::types::{self, Composite}; +pub trait GatewayHandleCapable +where + R: Stream, + S: Sink, +{ +} + /// Represents a handle to a Gateway connection. A Gateway connection will create observable /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// implemented types with the trait [`WebSocketEvent`] diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs index dd162b7..d0c43b3 100644 --- a/src/gateway/heartbeat.rs +++ b/src/gateway/heartbeat.rs @@ -5,10 +5,21 @@ use super::*; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; +pub trait HeartbeatHandlerCapable> { + fn new( + heartbeat_interval: Duration, + websocket_tx: Arc>>, + kill_rc: tokio::sync::broadcast::Receiver<()>, + ) -> Self; + + fn get_send(&self) -> &Sender; + fn get_heartbeat_interval(&self) -> Duration; +} + /// Handles sending heartbeats to the gateway in another thread #[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used #[derive(Debug)] -pub(super) struct HeartbeatHandler { +pub struct HeartbeatHandler { /// How ofter heartbeats need to be sent at a minimum pub heartbeat_interval: Duration, /// The send channel for the heartbeat thread @@ -17,8 +28,8 @@ pub(super) struct HeartbeatHandler { handle: JoinHandle<()>, } -impl HeartbeatHandler { - pub fn new( +impl HeartbeatHandlerCapable>> for HeartbeatHandler { + fn new( heartbeat_interval: Duration, websocket_tx: Arc< Mutex< @@ -50,6 +61,16 @@ impl HeartbeatHandler { } } + fn get_send(&self) -> &Sender { + &self.send + } + + fn get_heartbeat_interval(&self) -> Duration { + self.heartbeat_interval + } +} + +impl HeartbeatHandler { /// The main heartbeat task; /// /// Can be killed by the kill broadcast; @@ -141,9 +162,9 @@ impl HeartbeatHandler { /// Used for communications between the heartbeat and gateway thread. /// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server #[derive(Clone, Copy, Debug)] -pub(super) struct HeartbeatThreadCommunication { +pub struct HeartbeatThreadCommunication { /// The opcode for the communication we received, if relevant - pub(super) op_code: Option, + pub op_code: Option, /// The sequence number we got from discord, if any - pub(super) sequence_number: Option, + pub sequence_number: Option, } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2b5d414..b5888b4 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,7 +10,7 @@ pub use message::*; use tokio_tungstenite::tungstenite::Message; use crate::errors::GatewayError; -use crate::types::{self, Snowflake, WebSocketEvent}; +use crate::types::{Snowflake, WebSocketEvent}; use async_trait::async_trait; use std::any::Any; @@ -35,8 +35,6 @@ use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; -use self::event::Events; - // Gateway opcodes /// Opcode received when the server dispatches a [crate::types::WebSocketEvent] const GATEWAY_DISPATCH: u8 = 0; @@ -138,290 +136,6 @@ impl GatewayEvent { } } -#[allow(clippy::type_complexity)] -#[async_trait] -pub trait GatewayCapable -where - R: Stream, - S: Sink, - G: GatewayHandleCapable, -{ - fn get_events(&self) -> Arc>; - fn get_websocket_send(&self) -> Arc>>; - fn get_store(&self) -> GatewayStore; - fn get_url(&self) -> String; - /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] - /// - /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation - async fn get_handle(websocket_url: String) -> Result; - async fn close(&mut self); - async fn handle_message(&mut self, msg: GatewayMessage) { - if msg.is_empty() { - return; - } - - if !msg.is_error() && !msg.is_payload() { - warn!( - "Message unrecognised: {:?}, please open an issue on the chorus github", - msg.message.to_string() - ); - return; - } - - if msg.is_error() { - let error = msg.error().unwrap(); - - warn!("GW: Received error {:?}, connection will close..", error); - - self.close().await; - - self.get_events().lock().await.error.notify(error).await; - - return; - } - - let gateway_payload = msg.payload().unwrap(); - - // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes - match gateway_payload.op_code { - // An event was dispatched, we need to look at the gateway event name t - GATEWAY_DISPATCH => { - let Some(event_name) = gateway_payload.event_name else { - warn!("Gateway dispatch op without event_name"); - return; - }; - - trace!("Gateway: Received {event_name}"); - - macro_rules! handle { - ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { - match event_name.as_str() { - $($name => { - let event = &mut self.get_events().lock().await.$($path).+; - let json = gateway_payload.event_data.unwrap().get(); - match serde_json::from_str(json) { - Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), - Ok(message) => { - $( - let mut message: $message_type = message; - let store = self.get_store().lock().await; - let id = if message.id().is_some() { - message.id().unwrap() - } else { - event.notify(message).await; - return; - }; - if let Some(to_update) = store.get(&id) { - let object = to_update.clone(); - let inner_object = object.read().unwrap(); - if let Some(_) = inner_object.downcast_ref::<$update_type>() { - let ptr = Arc::into_raw(object.clone()); - // SAFETY: - // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. - // - This operation doesn't read or write any shared data, and thus cannot cause a data race - // - The reference count is not being modified - let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; - drop(inner_object); - message.set_json(json.to_string()); - message.set_source_url(self.get_url().clone()); - message.update(downcasted.clone()); - } else { - warn!("Received {} for {}, but it has been observed to be a different type!", $name, id) - } - } - )? - event.notify(message).await; - } - } - },)* - "RESUMED" => (), - "SESSIONS_REPLACE" => { - let result: Result, serde_json::Error> = - serde_json::from_str(gateway_payload.event_data.unwrap().get()); - match result { - Err(err) => { - warn!( - "Failed to parse gateway event {} ({})", - event_name, - err - ); - return; - } - Ok(sessions) => { - self.events.lock().await.session.replace.notify( - types::SessionsReplace {sessions} - ).await; - } - } - }, - _ => { - warn!("Received unrecognized gateway event ({event_name})! Please open an issue on the chorus github so we can implement it"); - } - } - }; - } - - // See https://discord.com/developers/docs/topics/gateway-events#receive-events - // "Some" of these are undocumented - handle!( - "READY" => session.ready, - "READY_SUPPLEMENTAL" => session.ready_supplemental, - "APPLICATION_COMMAND_PERMISSIONS_UPDATE" => application.command_permissions_update, - "AUTO_MODERATION_RULE_CREATE" =>auto_moderation.rule_create, - "AUTO_MODERATION_RULE_UPDATE" =>auto_moderation.rule_update AutoModerationRuleUpdate: AutoModerationRule, - "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, - "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, - "CHANNEL_CREATE" => channel.create ChannelCreate: Guild, - "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, - "CHANNEL_UNREAD_UPDATE" => channel.unread_update, - "CHANNEL_DELETE" => channel.delete ChannelDelete: Guild, - "CHANNEL_PINS_UPDATE" => channel.pins_update, - "CALL_CREATE" => call.create, - "CALL_UPDATE" => call.update, - "CALL_DELETE" => call.delete, - "THREAD_CREATE" => thread.create, // TODO - "THREAD_UPDATE" => thread.update ThreadUpdate: Channel, - "THREAD_DELETE" => thread.delete, // TODO - "THREAD_LIST_SYNC" => thread.list_sync, // TODO - "THREAD_MEMBER_UPDATE" => thread.member_update, // TODO - "THREAD_MEMBERS_UPDATE" => thread.members_update, // TODO - "GUILD_CREATE" => guild.create, // TODO - "GUILD_UPDATE" => guild.update, // TODO - "GUILD_DELETE" => guild.delete, // TODO - "GUILD_AUDIT_LOG_ENTRY_CREATE" => guild.audit_log_entry_create, - "GUILD_BAN_ADD" => guild.ban_add, // TODO - "GUILD_BAN_REMOVE" => guild.ban_remove, // TODO - "GUILD_EMOJIS_UPDATE" => guild.emojis_update, // TODO - "GUILD_STICKERS_UPDATE" => guild.stickers_update, // TODO - "GUILD_INTEGRATIONS_UPDATE" => guild.integrations_update, - "GUILD_MEMBER_ADD" => guild.member_add, - "GUILD_MEMBER_REMOVE" => guild.member_remove, - "GUILD_MEMBER_UPDATE" => guild.member_update, // TODO - "GUILD_MEMBERS_CHUNK" => guild.members_chunk, // TODO - "GUILD_ROLE_CREATE" => guild.role_create GuildRoleCreate: Guild, - "GUILD_ROLE_UPDATE" => guild.role_update GuildRoleUpdate: RoleObject, - "GUILD_ROLE_DELETE" => guild.role_delete, // TODO - "GUILD_SCHEDULED_EVENT_CREATE" => guild.role_scheduled_event_create, // TODO - "GUILD_SCHEDULED_EVENT_UPDATE" => guild.role_scheduled_event_update, // TODO - "GUILD_SCHEDULED_EVENT_DELETE" => guild.role_scheduled_event_delete, // TODO - "GUILD_SCHEDULED_EVENT_USER_ADD" => guild.role_scheduled_event_user_add, - "GUILD_SCHEDULED_EVENT_USER_REMOVE" => guild.role_scheduled_event_user_remove, - "PASSIVE_UPDATE_V1" => guild.passive_update_v1, // TODO - "INTEGRATION_CREATE" => integration.create, // TODO - "INTEGRATION_UPDATE" => integration.update, // TODO - "INTEGRATION_DELETE" => integration.delete, // TODO - "INTERACTION_CREATE" => interaction.create, // TODO - "INVITE_CREATE" => invite.create, // TODO - "INVITE_DELETE" => invite.delete, // TODO - "MESSAGE_CREATE" => message.create, - "MESSAGE_UPDATE" => message.update, // TODO - "MESSAGE_DELETE" => message.delete, - "MESSAGE_DELETE_BULK" => message.delete_bulk, - "MESSAGE_REACTION_ADD" => message.reaction_add, // TODO - "MESSAGE_REACTION_REMOVE" => message.reaction_remove, // TODO - "MESSAGE_REACTION_REMOVE_ALL" => message.reaction_remove_all, // TODO - "MESSAGE_REACTION_REMOVE_EMOJI" => message.reaction_remove_emoji, // TODO - "MESSAGE_ACK" => message.ack, - "PRESENCE_UPDATE" => user.presence_update, // TODO - "RELATIONSHIP_ADD" => relationship.add, - "RELATIONSHIP_REMOVE" => relationship.remove, - "STAGE_INSTANCE_CREATE" => stage_instance.create, - "STAGE_INSTANCE_UPDATE" => stage_instance.update, // TODO - "STAGE_INSTANCE_DELETE" => stage_instance.delete, - "TYPING_START" => user.typing_start, - "USER_UPDATE" => user.update, // TODO - "USER_GUILD_SETTINGS_UPDATE" => user.guild_settings_update, - "VOICE_STATE_UPDATE" => voice.state_update, // TODO - "VOICE_SERVER_UPDATE" => voice.server_update, - "WEBHOOKS_UPDATE" => webhooks.update - ); - } - // We received a heartbeat from the server - // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." - GATEWAY_HEARTBEAT => { - trace!("GW: Received Heartbeat // Heartbeat Request"); - - // Tell the heartbeat handler it should send a heartbeat right away - - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT), - }; - - self.heartbeat_handler - .send - .send(heartbeat_communication) - .await - .unwrap(); - } - GATEWAY_RECONNECT => { - todo!() - } - GATEWAY_INVALID_SESSION => { - todo!() - } - // Starts our heartbeat - // We should have already handled this in gateway init - GATEWAY_HELLO => { - warn!("Received hello when it was unexpected"); - } - GATEWAY_HEARTBEAT_ACK => { - trace!("GW: Received Heartbeat ACK"); - - // Tell the heartbeat handler we received an ack - - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT_ACK), - }; - - self.heartbeat_handler - .send - .send(heartbeat_communication) - .await - .unwrap(); - } - GATEWAY_IDENTIFY - | GATEWAY_UPDATE_PRESENCE - | GATEWAY_UPDATE_VOICE_STATE - | GATEWAY_RESUME - | GATEWAY_REQUEST_GUILD_MEMBERS - | GATEWAY_CALL_SYNC - | GATEWAY_LAZY_REQUEST => { - info!( - "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", - gateway_payload.op_code - ); - } - _ => { - warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); - } - } - - // If we we received a seq number we should let it know - if let Some(seq_num) = gateway_payload.sequence_number { - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: Some(seq_num), - // Op code is irrelevant here - op_code: None, - }; - - self.heartbeat_handler - .send - .send(heartbeat_communication) - .await - .unwrap(); - } - } -} - -pub trait GatewayHandleCapable -where - R: Stream, - S: Sink, -{ -} - #[cfg(test)] mod test { use crate::types; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index ce42578..7c9d652 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, RwLock}; -use chorus::gateway::Gateway; +use chorus::gateway::{Gateway, GatewayCapable}; use chorus::{ instance::{ChorusUser, Instance}, types::{ @@ -43,7 +43,9 @@ impl TestBundle { limits: self.user.limits.clone(), settings: self.user.settings.clone(), object: self.user.object.clone(), - gateway: Gateway::new(self.instance.urls.wss.clone()).await.unwrap(), + gateway: Gateway::get_handle(self.instance.urls.wss.clone()) + .await + .unwrap(), } } } diff --git a/tests/gateway.rs b/tests/gateway.rs index 991c9f2..eb2dbc3 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -10,7 +10,7 @@ use chorus::types::{self, ChannelModifySchema, RoleCreateModifySchema, RoleObjec async fn test_gateway_establish() { let bundle = common::setup().await; - Gateway::new(bundle.urls.wss.clone()).await.unwrap(); + Gateway::get_handle(bundle.urls.wss.clone()).await.unwrap(); common::teardown(bundle).await } @@ -19,7 +19,7 @@ async fn test_gateway_establish() { async fn test_gateway_authenticate() { let bundle = common::setup().await; - let gateway = Gateway::new(bundle.urls.wss.clone()).await.unwrap(); + let gateway = Gateway::get_handle(bundle.urls.wss.clone()).await.unwrap(); let mut identify = types::GatewayIdentifyPayload::common(); identify.token = bundle.user.token.clone();