diff --git a/src/api/invites/mod.rs b/src/api/invites/mod.rs index 332570b..80b47d2 100644 --- a/src/api/invites/mod.rs +++ b/src/api/invites/mod.rs @@ -28,11 +28,11 @@ impl ChorusUser { .header("Authorization", self.token()), limit_type: LimitType::Global, }; - if session_id.is_some() { + if let Some(session_id) = session_id { request.request = request .request .header("Content-Type", "application/json") - .body(to_string(session_id.unwrap()).unwrap()); + .body(to_string(session_id).unwrap()); } request.deserialize_response::(self).await } diff --git a/src/gateway.rs b/src/gateway/gateway.rs similarity index 55% rename from src/gateway.rs rename to src/gateway/gateway.rs index edd402d..30d0610 100644 --- a/src/gateway.rs +++ b/src/gateway/gateway.rs @@ -1,331 +1,10 @@ -//! Gateway connection, communication and handling, as well as object caching and updating. - -use crate::errors::GatewayError; -use crate::gateway::events::Events; +use self::event::Events; +use super::*; use crate::types::{ self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, - ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, - Snowflake, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent, + ChannelUpdate, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, SourceUrlField, + ThreadUpdate, UpdateMessage, WebSocketEvent, }; -use async_trait::async_trait; -use std::any::Any; -use std::collections::HashMap; -use std::fmt::Debug; -use std::sync::{Arc, RwLock}; -use std::time::Duration; -use tokio::time::sleep_until; - -use futures_util::stream::SplitSink; -use futures_util::stream::SplitStream; -use futures_util::SinkExt; -use futures_util::StreamExt; -use log::{info, trace, warn}; -use tokio::net::TcpStream; -use tokio::sync::mpsc::Sender; -use tokio::sync::Mutex; -use tokio::task; -use tokio::task::JoinHandle; -use tokio::time; -use tokio::time::Instant; -use tokio_tungstenite::MaybeTlsStream; -use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; - -// Gateway opcodes -/// Opcode received when the server dispatches a [crate::types::WebSocketEvent] -const GATEWAY_DISPATCH: u8 = 0; -/// Opcode sent when sending a heartbeat -const GATEWAY_HEARTBEAT: u8 = 1; -/// Opcode sent to initiate a session -/// -/// See [types::GatewayIdentifyPayload] -const GATEWAY_IDENTIFY: u8 = 2; -/// Opcode sent to update our presence -/// -/// See [types::GatewayUpdatePresence] -const GATEWAY_UPDATE_PRESENCE: u8 = 3; -/// Opcode sent to update our state in vc -/// -/// Like muting, deafening, leaving, joining.. -/// -/// See [types::UpdateVoiceState] -const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; -/// Opcode sent to resume a session -/// -/// See [types::GatewayResume] -const GATEWAY_RESUME: u8 = 6; -/// Opcode received to tell the client to reconnect -const GATEWAY_RECONNECT: u8 = 7; -/// Opcode sent to request guild member data -/// -/// See [types::GatewayRequestGuildMembers] -const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; -/// Opcode received to tell the client their token / session is invalid -const GATEWAY_INVALID_SESSION: u8 = 9; -/// Opcode received when initially connecting to the gateway, starts our heartbeat -/// -/// See [types::HelloData] -const GATEWAY_HELLO: u8 = 10; -/// Opcode received to acknowledge a heartbeat -const GATEWAY_HEARTBEAT_ACK: u8 = 11; -/// Opcode sent to get the voice state of users in a given DM/group channel -/// -/// See [types::CallSync] -const GATEWAY_CALL_SYNC: u8 = 13; -/// Opcode sent to get data for a server (Lazy Loading request) -/// -/// Sent by the official client when switching to a server -/// -/// See [types::LazyRequest] -const GATEWAY_LAZY_REQUEST: u8 = 14; - -/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms -const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; - -/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. -/// This struct is used internally when handling messages. -#[derive(Clone, Debug)] -pub struct GatewayMessage { - /// The message we received from the server - message: tokio_tungstenite::tungstenite::Message, -} - -impl GatewayMessage { - /// Creates self from a tungstenite message - pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { - Self { message } - } - - /// Parses the message as an error; - /// Returns the error if succesfully parsed, None if the message isn't an error - pub fn error(&self) -> Option { - let content = self.message.to_string(); - - // Some error strings have dots on the end, which we don't care about - let processed_content = content.to_lowercase().replace('.', ""); - - match processed_content.as_str() { - "unknown error" | "4000" => Some(GatewayError::Unknown), - "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), - "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), - "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), - "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), - "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), - "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), - "rate limited" | "4008" => Some(GatewayError::RateLimited), - "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), - "invalid shard" | "4010" => Some(GatewayError::InvalidShard), - "sharding required" | "4011" => Some(GatewayError::ShardingRequired), - "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), - "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), - "disallowed intent(s)" | "disallowed intents" | "4014" => { - Some(GatewayError::DisallowedIntents) - } - _ => None, - } - } - - /// Returns whether or not the message is an error - pub fn is_error(&self) -> bool { - self.error().is_some() - } - - /// Parses the message as a payload; - /// Returns a result of deserializing - pub fn payload(&self) -> Result { - return serde_json::from_str(self.message.to_text().unwrap()); - } - - /// Returns whether or not the message is a payload - pub fn is_payload(&self) -> bool { - // close messages are never payloads, payloads are only text messages - if self.message.is_close() | !self.message.is_text() { - return false; - } - - return self.payload().is_ok(); - } - - /// Returns whether or not the message is empty - pub fn is_empty(&self) -> bool { - self.message.is_empty() - } -} - -pub type ObservableObject = dyn Send + Sync + Any; - -/// Represents a handle to a Gateway connection. A Gateway connection will create observable -/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently -/// implemented types with the trait [`WebSocketEvent`] -/// Using this handle you can also send Gateway Events directly. -#[derive(Debug, Clone)] -pub struct GatewayHandle { - pub url: String, - pub events: Arc>, - pub websocket_send: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - /// Tells gateway tasks to close - kill_send: tokio::sync::broadcast::Sender<()>, - pub(crate) store: Arc>>>>, -} - -/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. -pub trait Updateable: 'static + Send + Sync { - fn id(&self) -> Snowflake; -} - -impl GatewayHandle { - /// Sends json to the gateway with an opcode - async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { - let gateway_payload = types::GatewaySendPayload { - op_code, - event_data: Some(to_send), - sequence_number: None, - }; - - let payload_json = serde_json::to_string(&gateway_payload).unwrap(); - - let message = tokio_tungstenite::tungstenite::Message::text(payload_json); - - self.websocket_send - .lock() - .await - .send(message) - .await - .unwrap(); - } - - pub async fn observe>( - &self, - object: Arc>, - ) -> Arc> { - let mut store = self.store.lock().await; - let id = object.read().unwrap().id(); - if let Some(channel) = store.get(&id) { - let object = channel.clone(); - drop(store); - object - .read() - .unwrap() - .downcast_ref::() - .unwrap_or_else(|| { - panic!( - "Snowflake {} already exists in the store, but it is not of type T.", - id - ) - }); - let ptr = Arc::into_raw(object.clone()); - // SAFETY: - // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. - // - This operation doesn't read or write any shared data, and thus cannot cause a data race - // - The reference count is not being modified - let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; - let object = downcasted.read().unwrap().clone(); - - let watched_object = object.watch_whole(self).await; - *downcasted.write().unwrap() = watched_object; - downcasted - } else { - let id = object.read().unwrap().id(); - let object = object.read().unwrap().clone(); - let object = object.clone().watch_whole(self).await; - let wrapped = Arc::new(RwLock::new(object)); - store.insert(id, wrapped.clone()); - wrapped - } - } - - /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` - /// with all of its observable fields being observed. - pub async fn observe_and_into_inner>( - &self, - object: Arc>, - ) -> T { - let channel = self.observe(object.clone()).await; - let object = channel.read().unwrap().clone(); - object - } - - /// Sends an identify event to the gateway - pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Identify.."); - - self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; - } - - /// Sends a resume event to the gateway - pub async fn send_resume(&self, to_send: types::GatewayResume) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Resume.."); - - self.send_json_event(GATEWAY_RESUME, to_send_value).await; - } - - /// Sends an update presence event to the gateway - pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Update Presence.."); - - self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) - .await; - } - - /// Sends a request guild members to the server - pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Request Guild Members.."); - - self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) - .await; - } - - /// Sends an update voice state to the server - pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { - let to_send_value = serde_json::to_value(to_send).unwrap(); - - trace!("GW: Sending Update Voice State.."); - - self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) - .await; - } - - /// Sends a call sync to the server - pub async fn send_call_sync(&self, to_send: types::CallSync) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Call Sync.."); - - self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; - } - - /// Sends a Lazy Request - pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Lazy Request.."); - - self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) - .await; - } - - /// Closes the websocket connection and stops all gateway tasks; - /// - /// Esentially pulls the plug on the gateway, leaving it possible to resume; - pub async fn close(&self) { - self.kill_send.send(()).unwrap(); - self.websocket_send.lock().await.close().await.unwrap(); - } -} #[derive(Debug)] pub struct Gateway { @@ -709,10 +388,10 @@ impl Gateway { | GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_CALL_SYNC | GATEWAY_LAZY_REQUEST => { - let error = GatewayError::UnexpectedOpcodeReceived { - opcode: gateway_payload.op_code, - }; - Err::<(), GatewayError>(error).unwrap(); + info!( + "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", + gateway_payload.op_code + ); } _ => { warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); @@ -736,196 +415,7 @@ impl Gateway { } } -/// Handles sending heartbeats to the gateway in another thread -#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used -#[derive(Debug)] -struct HeartbeatHandler { - /// How ofter heartbeats need to be sent at a minimum - pub heartbeat_interval: Duration, - /// The send channel for the heartbeat thread - pub send: Sender, - /// The handle of the thread - handle: JoinHandle<()>, -} - -impl HeartbeatHandler { - pub fn new( - heartbeat_interval: Duration, - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> HeartbeatHandler { - let (send, receive) = tokio::sync::mpsc::channel(32); - let kill_receive = kill_rc.resubscribe(); - - let handle: JoinHandle<()> = task::spawn(async move { - HeartbeatHandler::heartbeat_task( - websocket_tx, - heartbeat_interval, - receive, - kill_receive, - ) - .await; - }); - - Self { - heartbeat_interval, - send, - handle, - } - } - - /// The main heartbeat task; - /// - /// Can be killed by the kill broadcast; - /// If the websocket is closed, will die out next time it tries to send a heartbeat; - pub async fn heartbeat_task( - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, - mut kill_receive: tokio::sync::broadcast::Receiver<()>, - ) { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - - loop { - if kill_receive.try_recv().is_ok() { - trace!("GW: Closing heartbeat task"); - break; - } - - let timeout = if last_heartbeat_acknowledged { - heartbeat_interval - } else { - // If the server hasn't acknowledged our heartbeat we should resend it - Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) - }; - - let mut should_send = false; - - tokio::select! { - () = sleep_until(last_heartbeat_timestamp + timeout) => { - should_send = true; - } - Some(communication) = receive.recv() => { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; - } - - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - } - - if should_send { - trace!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx.lock().await.send(msg).await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - warn!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } - } -} - -/// Used for communications between the heartbeat and gateway thread. -/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server -#[derive(Clone, Copy, Debug)] -struct HeartbeatThreadCommunication { - /// The opcode for the communication we received, if relevant - op_code: Option, - /// The sequence number we got from discord, if any - sequence_number: Option, -} - -/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to -/// an Observable. The Observer is notified when the Observable's data changes. -/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. -/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing. -#[async_trait] -pub trait Observer: Sync + Send + std::fmt::Debug { - async fn update(&self, data: &T); -} - -/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a -/// change in the WebSocketEvent. GatewayEvents are observable. -#[derive(Default, Debug)] -pub struct GatewayEvent { - observers: Vec>>, -} - -impl GatewayEvent { - /// Returns true if the GatewayEvent is observed by at least one Observer. - pub fn is_observed(&self) -> bool { - !self.observers.is_empty() - } - - /// Subscribes an Observer to the GatewayEvent. - pub fn subscribe(&mut self, observable: Arc>) { - self.observers.push(observable); - } - - /// Unsubscribes an Observer from the GatewayEvent. - pub fn unsubscribe(&mut self, observable: &dyn Observer) { - // .retain()'s closure retains only those elements of the vector, which have a different - // pointer value than observable. - // The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same - // anddd there is no way to do that without using format - let to_remove = format!("{:?}", observable); - self.observers - .retain(|obs| format!("{:?}", obs) != to_remove); - } - - /// Notifies the observers of the GatewayEvent. - async fn notify(&self, new_event_data: T) { - for observer in &self.observers { - observer.update(&new_event_data).await; - } - } -} - -pub mod events { +pub mod event { use super::*; #[derive(Default, Debug)] @@ -1086,52 +576,3 @@ pub mod events { pub update: GatewayEvent, } } - -#[cfg(test)] -mod example { - use super::*; - use std::sync::atomic::{AtomicI32, Ordering::Relaxed}; - - #[derive(Debug)] - struct Consumer { - _name: String, - events_received: AtomicI32, - } - - #[async_trait] - impl Observer for Consumer { - async fn update(&self, _data: &types::GatewayResume) { - self.events_received.fetch_add(1, Relaxed); - } - } - - #[tokio::test] - async fn test_observer_behavior() { - let mut event = GatewayEvent::default(); - - let new_data = types::GatewayResume { - token: "token_3276ha37am3".to_string(), - session_id: "89346671230".to_string(), - seq: "3".to_string(), - }; - - let consumer = Arc::new(Consumer { - _name: "first".into(), - events_received: 0.into(), - }); - event.subscribe(consumer.clone()); - - let second_consumer = Arc::new(Consumer { - _name: "second".into(), - events_received: 0.into(), - }); - event.subscribe(second_consumer.clone()); - - event.notify(new_data.clone()).await; - event.unsubscribe(&*consumer); - event.notify(new_data).await; - - assert_eq!(consumer.events_received.load(Relaxed), 1); - assert_eq!(second_consumer.events_received.load(Relaxed), 2); - } -} diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs new file mode 100644 index 0000000..1200b30 --- /dev/null +++ b/src/gateway/handle.rs @@ -0,0 +1,171 @@ +use super::{event::Events, *}; +use crate::types::{self, Composite}; + +/// Represents a handle to a Gateway connection. A Gateway connection will create observable +/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently +/// implemented types with the trait [`WebSocketEvent`] +/// Using this handle you can also send Gateway Events directly. +#[derive(Debug, Clone)] +pub struct GatewayHandle { + pub url: String, + pub events: Arc>, + pub websocket_send: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + /// Tells gateway tasks to close + pub(super) kill_send: tokio::sync::broadcast::Sender<()>, + pub(crate) store: Arc>>>>, +} + +impl GatewayHandle { + /// Sends json to the gateway with an opcode + async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { + let gateway_payload = types::GatewaySendPayload { + op_code, + event_data: Some(to_send), + sequence_number: None, + }; + + let payload_json = serde_json::to_string(&gateway_payload).unwrap(); + + let message = tokio_tungstenite::tungstenite::Message::text(payload_json); + + self.websocket_send + .lock() + .await + .send(message) + .await + .unwrap(); + } + + pub async fn observe>( + &self, + object: Arc>, + ) -> Arc> { + let mut store = self.store.lock().await; + let id = object.read().unwrap().id(); + if let Some(channel) = store.get(&id) { + let object = channel.clone(); + drop(store); + object + .read() + .unwrap() + .downcast_ref::() + .unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + id + ) + }); + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; + let object = downcasted.read().unwrap().clone(); + + let watched_object = object.watch_whole(self).await; + *downcasted.write().unwrap() = watched_object; + downcasted + } else { + let id = object.read().unwrap().id(); + let object = object.read().unwrap().clone(); + let object = object.clone().watch_whole(self).await; + let wrapped = Arc::new(RwLock::new(object)); + store.insert(id, wrapped.clone()); + wrapped + } + } + + /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` + /// with all of its observable fields being observed. + pub async fn observe_and_into_inner>( + &self, + object: Arc>, + ) -> T { + let channel = self.observe(object.clone()).await; + let object = channel.read().unwrap().clone(); + object + } + + /// Sends an identify event to the gateway + pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Identify.."); + + self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; + } + + /// Sends a resume event to the gateway + pub async fn send_resume(&self, to_send: types::GatewayResume) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Resume.."); + + self.send_json_event(GATEWAY_RESUME, to_send_value).await; + } + + /// Sends an update presence event to the gateway + pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Update Presence.."); + + self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) + .await; + } + + /// Sends a request guild members to the server + pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Request Guild Members.."); + + self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) + .await; + } + + /// Sends an update voice state to the server + pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { + let to_send_value = serde_json::to_value(to_send).unwrap(); + + trace!("GW: Sending Update Voice State.."); + + self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) + .await; + } + + /// Sends a call sync to the server + pub async fn send_call_sync(&self, to_send: types::CallSync) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Call Sync.."); + + self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; + } + + /// Sends a Lazy Request + pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Lazy Request.."); + + self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) + .await; + } + + /// Closes the websocket connection and stops all gateway tasks; + /// + /// Esentially pulls the plug on the gateway, leaving it possible to resume; + pub async fn close(&self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } +} diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs new file mode 100644 index 0000000..dd162b7 --- /dev/null +++ b/src/gateway/heartbeat.rs @@ -0,0 +1,149 @@ +use crate::types; + +use super::*; + +/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms +const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; + +/// Handles sending heartbeats to the gateway in another thread +#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used +#[derive(Debug)] +pub(super) struct HeartbeatHandler { + /// How ofter heartbeats need to be sent at a minimum + pub heartbeat_interval: Duration, + /// The send channel for the heartbeat thread + pub send: Sender, + /// The handle of the thread + handle: JoinHandle<()>, +} + +impl HeartbeatHandler { + pub fn new( + heartbeat_interval: Duration, + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + kill_rc: tokio::sync::broadcast::Receiver<()>, + ) -> HeartbeatHandler { + let (send, receive) = tokio::sync::mpsc::channel(32); + let kill_receive = kill_rc.resubscribe(); + + let handle: JoinHandle<()> = task::spawn(async move { + HeartbeatHandler::heartbeat_task( + websocket_tx, + heartbeat_interval, + receive, + kill_receive, + ) + .await; + }); + + Self { + heartbeat_interval, + send, + handle, + } + } + + /// The main heartbeat task; + /// + /// Can be killed by the kill broadcast; + /// If the websocket is closed, will die out next time it tries to send a heartbeat; + pub async fn heartbeat_task( + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + heartbeat_interval: Duration, + mut receive: tokio::sync::mpsc::Receiver, + mut kill_receive: tokio::sync::broadcast::Receiver<()>, + ) { + let mut last_heartbeat_timestamp: Instant = time::Instant::now(); + let mut last_heartbeat_acknowledged = true; + let mut last_seq_number: Option = None; + + loop { + if kill_receive.try_recv().is_ok() { + trace!("GW: Closing heartbeat task"); + break; + } + + let timeout = if last_heartbeat_acknowledged { + heartbeat_interval + } else { + // If the server hasn't acknowledged our heartbeat we should resend it + Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) + }; + + let mut should_send = false; + + tokio::select! { + () = sleep_until(last_heartbeat_timestamp + timeout) => { + should_send = true; + } + Some(communication) = receive.recv() => { + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = communication.sequence_number; + } + + if let Some(op_code) = communication.op_code { + match op_code { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} + } + } + } + } + + if should_send { + trace!("GW: Sending Heartbeat.."); + + let heartbeat = types::GatewayHeartbeat { + op: GATEWAY_HEARTBEAT, + d: last_seq_number, + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + let send_result = websocket_tx.lock().await.send(msg).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + warn!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } + + last_heartbeat_timestamp = time::Instant::now(); + last_heartbeat_acknowledged = false; + } + } + } +} + +/// Used for communications between the heartbeat and gateway thread. +/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server +#[derive(Clone, Copy, Debug)] +pub(super) struct HeartbeatThreadCommunication { + /// The opcode for the communication we received, if relevant + pub(super) op_code: Option, + /// The sequence number we got from discord, if any + pub(super) sequence_number: Option, +} diff --git a/src/gateway/message.rs b/src/gateway/message.rs new file mode 100644 index 0000000..edee9dd --- /dev/null +++ b/src/gateway/message.rs @@ -0,0 +1,73 @@ +use crate::types; + +use super::*; + +/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. +/// This struct is used internally when handling messages. +#[derive(Clone, Debug)] +pub struct GatewayMessage { + /// The message we received from the server + pub(super) message: tokio_tungstenite::tungstenite::Message, +} + +impl GatewayMessage { + /// Creates self from a tungstenite message + pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { + Self { message } + } + + /// Parses the message as an error; + /// Returns the error if succesfully parsed, None if the message isn't an error + pub fn error(&self) -> Option { + let content = self.message.to_string(); + + // Some error strings have dots on the end, which we don't care about + let processed_content = content.to_lowercase().replace('.', ""); + + match processed_content.as_str() { + "unknown error" | "4000" => Some(GatewayError::Unknown), + "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), + "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), + "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), + "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), + "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), + "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), + "rate limited" | "4008" => Some(GatewayError::RateLimited), + "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), + "invalid shard" | "4010" => Some(GatewayError::InvalidShard), + "sharding required" | "4011" => Some(GatewayError::ShardingRequired), + "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), + "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), + "disallowed intent(s)" | "disallowed intents" | "4014" => { + Some(GatewayError::DisallowedIntents) + } + _ => None, + } + } + + /// Returns whether or not the message is an error + pub fn is_error(&self) -> bool { + self.error().is_some() + } + + /// Parses the message as a payload; + /// Returns a result of deserializing + pub fn payload(&self) -> Result { + return serde_json::from_str(self.message.to_text().unwrap()); + } + + /// Returns whether or not the message is a payload + pub fn is_payload(&self) -> bool { + // close messages are never payloads, payloads are only text messages + if self.message.is_close() | !self.message.is_text() { + return false; + } + + return self.payload().is_ok(); + } + + /// Returns whether or not the message is empty + pub fn is_empty(&self) -> bool { + self.message.is_empty() + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs new file mode 100644 index 0000000..ebd06cc --- /dev/null +++ b/src/gateway/mod.rs @@ -0,0 +1,187 @@ +pub mod gateway; +pub mod handle; +pub mod heartbeat; +pub mod message; + +pub use gateway::*; +pub use handle::*; +use heartbeat::*; +pub use message::*; + +use crate::errors::GatewayError; +use crate::types::{Snowflake, WebSocketEvent}; + +use async_trait::async_trait; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use tokio::time::sleep_until; + +use futures_util::stream::SplitSink; +use futures_util::stream::SplitStream; +use futures_util::SinkExt; +use futures_util::StreamExt; +use log::{info, trace, warn}; +use tokio::net::TcpStream; +use tokio::sync::mpsc::Sender; +use tokio::sync::Mutex; +use tokio::task; +use tokio::task::JoinHandle; +use tokio::time; +use tokio::time::Instant; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; + +// Gateway opcodes +/// Opcode received when the server dispatches a [crate::types::WebSocketEvent] +const GATEWAY_DISPATCH: u8 = 0; +/// Opcode sent when sending a heartbeat +const GATEWAY_HEARTBEAT: u8 = 1; +/// Opcode sent to initiate a session +/// +/// See [types::GatewayIdentifyPayload] +const GATEWAY_IDENTIFY: u8 = 2; +/// Opcode sent to update our presence +/// +/// See [types::GatewayUpdatePresence] +const GATEWAY_UPDATE_PRESENCE: u8 = 3; +/// Opcode sent to update our state in vc +/// +/// Like muting, deafening, leaving, joining.. +/// +/// See [types::UpdateVoiceState] +const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; +/// Opcode sent to resume a session +/// +/// See [types::GatewayResume] +const GATEWAY_RESUME: u8 = 6; +/// Opcode received to tell the client to reconnect +const GATEWAY_RECONNECT: u8 = 7; +/// Opcode sent to request guild member data +/// +/// See [types::GatewayRequestGuildMembers] +const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; +/// Opcode received to tell the client their token / session is invalid +const GATEWAY_INVALID_SESSION: u8 = 9; +/// Opcode received when initially connecting to the gateway, starts our heartbeat +/// +/// See [types::HelloData] +const GATEWAY_HELLO: u8 = 10; +/// Opcode received to acknowledge a heartbeat +const GATEWAY_HEARTBEAT_ACK: u8 = 11; +/// Opcode sent to get the voice state of users in a given DM/group channel +/// +/// See [types::CallSync] +const GATEWAY_CALL_SYNC: u8 = 13; +/// Opcode sent to get data for a server (Lazy Loading request) +/// +/// Sent by the official client when switching to a server +/// +/// See [types::LazyRequest] +const GATEWAY_LAZY_REQUEST: u8 = 14; + +pub type ObservableObject = dyn Send + Sync + Any; + +/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. +pub trait Updateable: 'static + Send + Sync { + fn id(&self) -> Snowflake; +} + +/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to +/// an Observable. The Observer is notified when the Observable's data changes. +/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. +/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing. +#[async_trait] +pub trait Observer: Sync + Send + std::fmt::Debug { + async fn update(&self, data: &T); +} + +/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a +/// change in the WebSocketEvent. GatewayEvents are observable. +#[derive(Default, Debug)] +pub struct GatewayEvent { + observers: Vec>>, +} + +impl GatewayEvent { + /// Returns true if the GatewayEvent is observed by at least one Observer. + pub fn is_observed(&self) -> bool { + !self.observers.is_empty() + } + + /// Subscribes an Observer to the GatewayEvent. + pub fn subscribe(&mut self, observable: Arc>) { + self.observers.push(observable); + } + + /// Unsubscribes an Observer from the GatewayEvent. + pub fn unsubscribe(&mut self, observable: &dyn Observer) { + // .retain()'s closure retains only those elements of the vector, which have a different + // pointer value than observable. + // The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same + // anddd there is no way to do that without using format + let to_remove = format!("{:?}", observable); + self.observers + .retain(|obs| format!("{:?}", obs) != to_remove); + } + + /// Notifies the observers of the GatewayEvent. + async fn notify(&self, new_event_data: T) { + for observer in &self.observers { + observer.update(&new_event_data).await; + } + } +} + +#[cfg(test)] +mod example { + use crate::types; + + use super::*; + use std::sync::atomic::{AtomicI32, Ordering::Relaxed}; + + #[derive(Debug)] + struct Consumer { + _name: String, + events_received: AtomicI32, + } + + #[async_trait] + impl Observer for Consumer { + async fn update(&self, _data: &types::GatewayResume) { + self.events_received.fetch_add(1, Relaxed); + } + } + + #[tokio::test] + async fn test_observer_behavior() { + let mut event = GatewayEvent::default(); + + let new_data = types::GatewayResume { + token: "token_3276ha37am3".to_string(), + session_id: "89346671230".to_string(), + seq: "3".to_string(), + }; + + let consumer = Arc::new(Consumer { + _name: "first".into(), + events_received: 0.into(), + }); + event.subscribe(consumer.clone()); + + let second_consumer = Arc::new(Consumer { + _name: "second".into(), + events_received: 0.into(), + }); + event.subscribe(second_consumer.clone()); + + event.notify(new_data.clone()).await; + event.unsubscribe(&*consumer); + event.notify(new_data).await; + + assert_eq!(consumer.events_received.load(Relaxed), 1); + assert_eq!(second_consumer.events_received.load(Relaxed), 2); + } +} diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index 842db3f..d08ad1d 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::types::{Snowflake, WelcomeScreenObject}; use super::guild::GuildScheduledEvent; -use super::{Application, Channel, GuildMember, User}; +use super::{Application, Channel, GuildMember, NSFWLevel, User}; /// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users. /// See @@ -56,17 +56,6 @@ pub struct InviteGuild { pub welcome_screen: Option, } -/// See for an explanation on what -/// the levels mean. -#[derive(Debug, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum NSFWLevel { - Default = 0, - Explicit = 1, - Safe = 2, - AgeRestricted = 3, -} - /// See #[derive(Debug, Serialize, Deserialize)] pub struct InviteStageInstance {