use crate::errors::GatewayError; use crate::errors::ObserverError; use crate::gateway::events::Events; use crate::types; use futures_util::stream::SplitSink; use futures_util::stream::SplitStream; use futures_util::SinkExt; use futures_util::StreamExt; use native_tls::TlsConnector; use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; use tokio::task; use tokio::task::JoinHandle; use tokio::time; use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; // Gateway opcodes /// Opcode received when the server dispatches a [crate::types::WebSocketEvent] const GATEWAY_DISPATCH: u8 = 0; /// Opcode sent when sending a heartbeat const GATEWAY_HEARTBEAT: u8 = 1; /// Opcode sent to initiate a session /// /// See [types::GatewayIdentifyPayload] const GATEWAY_IDENTIFY: u8 = 2; /// Opcode sent to update our presence /// /// See [types::GatewayUpdatePresence] const GATEWAY_UPDATE_PRESENCE: u8 = 3; /// Opcode sent to update our state in vc /// /// Like muting, deafening, leaving, joining.. 
/// /// See [types::UpdateVoiceState] const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; /// Opcode sent to resume a session /// /// See [types::GatewayResume] const GATEWAY_RESUME: u8 = 6; /// Opcode received to tell the client to reconnect const GATEWAY_RECONNECT: u8 = 7; /// Opcode sent to request guild member data /// /// See [types::GatewayRequestGuildMembers] const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; /// Opcode received to tell the client their token / session is invalid const GATEWAY_INVALID_SESSION: u8 = 9; /// Opcode received when initially connecting to the gateway, starts our heartbeat /// /// See [types::HelloData] const GATEWAY_HELLO: u8 = 10; /// Opcode received to acknowledge a heartbeat const GATEWAY_HEARTBEAT_ACK: u8 = 11; /// Opcode sent to get the voice state of users in a given DM/group channel /// /// See [types::CallSync] const GATEWAY_CALL_SYNC: u8 = 13; /// Opcode sent to get data for a server (Lazy Loading request) /// /// Sent by the official client when switching to a server /// /// See [types::LazyRequest] const GATEWAY_LAZY_REQUEST: u8 = 14; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms const HEARTBEAT_ACK_TIMEOUT: u128 = 2000; #[derive(Clone, Debug)] /** Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError]. This struct is used internally when handling messages. 
*/ pub struct GatewayMessage { /// The message we received from the server message: tokio_tungstenite::tungstenite::Message, } impl GatewayMessage { /// Creates self from a tungstenite message pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { Self { message } } /// Parses the message as an error; /// Returns the error if succesfully parsed, None if the message isn't an error pub fn error(&self) -> Option { let content = self.message.to_string(); // Some error strings have dots on the end, which we don't care about let processed_content = content.clone().to_lowercase().replace(".", ""); match processed_content.as_str() { "unknown error" | "4000" => { return Some(GatewayError::UnknownError); } "unknown opcode" | "4001" => { return Some(GatewayError::UnknownOpcodeError); } "decode error" | "4002" => { return Some(GatewayError::DecodeError); } "not authenticated" | "4003" => { return Some(GatewayError::NotAuthenticatedError); } "authentication failed" | "4004" => { return Some(GatewayError::AuthenticationFailedError); } "already authenticated" | "4005" => { return Some(GatewayError::AlreadyAuthenticatedError); } "invalid seq" | "4007" => { return Some(GatewayError::InvalidSequenceNumberError); } "rate limited" | "4008" => { return Some(GatewayError::RateLimitedError); } "session timed out" | "4009" => { return Some(GatewayError::SessionTimedOutError); } "invalid shard" | "4010" => { return Some(GatewayError::InvalidShardError); } "sharding required" | "4011" => { return Some(GatewayError::ShardingRequiredError); } "invalid api version" | "4012" => { return Some(GatewayError::InvalidAPIVersionError); } "invalid intent(s)" | "invalid intent" | "4013" => { return Some(GatewayError::InvalidIntentsError); } "disallowed intent(s)" | "disallowed intents" | "4014" => { return Some(GatewayError::DisallowedIntentsError); } _ => { return None; } } } /// Returns whether or not the message is an error pub fn is_error(&self) -> bool { return 
self.error().is_some(); } /// Parses the message as a payload; /// Returns a result of deserializing pub fn payload(&self) -> Result { return serde_json::from_str(self.message.to_text().unwrap()); } /// Returns whether or not the message is a payload pub fn is_payload(&self) -> bool { // close messages are never payloads, payloads are only text messages if self.message.is_close() | !self.message.is_text() { return false; } return self.payload().is_ok(); } /// Returns whether or not the message is empty pub fn is_empty(&self) -> bool { return self.message.is_empty(); } } #[derive(Debug)] /** Represents a handle to a Gateway connection. A Gateway connection will create observable [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently implemented [Types] with the trait [`WebSocketEvent`] Using this handle you can also send Gateway Events directly. */ pub struct GatewayHandle { pub url: String, pub events: Arc>, pub websocket_send: Arc< Mutex< SplitSink< WebSocketStream>, tokio_tungstenite::tungstenite::Message, >, >, >, pub handle: JoinHandle<()>, /// Tells gateway tasks to close kill_send: tokio::sync::broadcast::Sender<()>, } impl GatewayHandle { /// Sends json to the gateway with an opcode async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { let gateway_payload = types::GatewaySendPayload { op_code, event_data: Some(to_send), sequence_number: None, }; let payload_json = serde_json::to_string(&gateway_payload).unwrap(); let message = tokio_tungstenite::tungstenite::Message::text(payload_json); self.websocket_send .lock() .await .send(message) .await .unwrap(); } /// Sends an identify event to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Identify.."); self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; } /// Sends a resume event to the gateway pub async fn 
send_resume(&self, to_send: types::GatewayResume) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Resume.."); self.send_json_event(GATEWAY_RESUME, to_send_value).await; } /// Sends an update presence event to the gateway pub async fn send_update_presence(&self, to_send: types::PresenceUpdate) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Presence Update.."); self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) .await; } /// Sends a request guild members to the server pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Request Guild Members.."); self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) .await; } /// Sends an update voice state to the server pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Update Voice State.."); self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) .await; } /// Sends a call sync to the server pub async fn send_call_sync(&self, to_send: types::CallSync) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Call Sync.."); self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; } /// Sends a Lazy Request pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { let to_send_value = serde_json::to_value(&to_send).unwrap(); println!("GW: Sending Lazy Request.."); self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) .await; } /// Closes the websocket connection and stops all gateway tasks pub async fn close(&mut self) { self.kill_send.send(()).unwrap(); self.websocket_send.lock().await.close().await.unwrap(); } } pub struct Gateway { pub events: Arc>, heartbeat_handler: HeartbeatHandler, pub websocket_send: Arc< Mutex< SplitSink< WebSocketStream>, 
tokio_tungstenite::tungstenite::Message, >, >, >, pub websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, } impl Gateway { pub async fn new(websocket_url: String) -> Result { let (websocket_stream, _) = match connect_async_tls_with_config( &websocket_url, None, false, Some(Connector::NativeTls( TlsConnector::builder().build().unwrap(), )), ) .await { Ok(websocket_stream) => websocket_stream, Err(e) => { return Err(GatewayError::CannotConnectError { error: e.to_string(), }) } }; let (websocket_send, mut websocket_receive) = websocket_stream.split(); let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); // Create a shared broadcast channel for killing all gateway tasks let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16); // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread let msg = websocket_receive.next().await.unwrap().unwrap(); let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(msg.to_text().unwrap()).unwrap(); if gateway_payload.op_code != GATEWAY_HELLO { return Err(GatewayError::NonHelloOnInitiateError { opcode: gateway_payload.op_code, }); } println!("GW: Received Hello"); let gateway_hello: types::HelloData = serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); let mut gateway = Gateway { events: Arc::new(Mutex::new(Events::default())), heartbeat_handler: HeartbeatHandler::new( gateway_hello.heartbeat_interval, shared_websocket_send.clone(), kill_send.subscribe(), ), websocket_send: shared_websocket_send.clone(), websocket_receive, kill_send: kill_send.clone(), }; let shared_events = gateway.events.clone(); // Now we can continuously check for messages in a different task, since we aren't going to receive another hello let handle: JoinHandle<()> = task::spawn(async move { gateway.gateway_listen_task().await; }); return Ok(GatewayHandle { url: 
websocket_url.clone(), events: shared_events, websocket_send: shared_websocket_send.clone(), handle, kill_send: kill_send.clone(), }); } /// The main gateway listener task; /// /// Can only be stopped by closing the websocket, cannot be made to listen for kill pub async fn gateway_listen_task(&mut self) { loop { let msg = self.websocket_receive.next().await; // This if chain can be much better but if let is unstable on stable rust if msg.as_ref().is_some() { if msg.as_ref().unwrap().is_ok() { let msg_unwrapped = msg.unwrap().unwrap(); self.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped)) .await; continue; } } // We couldn't receive the next message or it was an error, something is wrong with the websocket, close println!("GW: Websocket is broken, stopping gateway"); break; } } /// Closes the websocket connection and stops all tasks async fn close(&mut self) { self.kill_send.send(()).unwrap(); self.websocket_send.lock().await.close().await.unwrap(); } /// This handles a message as a websocket event and updates its events along with the events' observers pub async fn handle_event(&mut self, msg: GatewayMessage) { if msg.is_empty() { return; } // To:do: handle errors in a good way, maybe observers like events? 
if msg.is_error() { println!("GW: Received error, connection will close.."); let error = msg.error(); match error { _ => {} } self.close().await; return; } let gateway_payload = msg.payload().unwrap(); // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes match gateway_payload.op_code { // An event was dispatched, we need to look at the gateway event name t GATEWAY_DISPATCH => { let gateway_payload_t = gateway_payload.clone().event_name.unwrap(); println!("GW: Received {}..", gateway_payload_t); //println!("Event data dump: {}", gateway_payload.d.clone().unwrap().get()); // See https://discord.com/developers/docs/topics/gateway-events#receive-events // "Some" of these are undocumented match gateway_payload_t.as_str() { "READY" => { let new_data: types::GatewayReady = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .session .ready .update_data(new_data) .await; } "READY_SUPPLEMENTAL" => { let new_data: types::GatewayReadySupplemental = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .session .ready_supplemental .update_data(new_data) .await; } "RESUMED" => {} "APPLICATION_COMMAND_PERMISSIONS_UPDATE" => { let new_data: types::ApplicationCommandPermissionsUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .application .command_permissions_update .update_data(new_data) .await; } "AUTO_MODERATION_RULE_CREATE" => { let new_data: types::AutoModerationRuleCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .auto_moderation .rule_create .update_data(new_data) .await; } "AUTO_MODERATION_RULE_UPDATE" => { let new_data: types::AutoModerationRuleUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .auto_moderation .rule_update .update_data(new_data) 
.await; } "AUTO_MODERATION_RULE_DELETE" => { let new_data: types::AutoModerationRuleDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .auto_moderation .rule_delete .update_data(new_data) .await; } "AUTO_MODERATION_ACTION_EXECUTION" => { let new_data: types::AutoModerationActionExecution = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .auto_moderation .action_execution .update_data(new_data) .await; } "CHANNEL_CREATE" => { let new_data: types::ChannelCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .channel .create .update_data(new_data) .await; } "CHANNEL_UPDATE" => { let new_data: types::ChannelUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .channel .update .update_data(new_data) .await; } "CHANNEL_UNREAD_UPDATE" => { let new_data: types::ChannelUnreadUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .channel .unread_update .update_data(new_data) .await; } "CHANNEL_DELETE" => { let new_data: types::ChannelDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .channel .delete .update_data(new_data) .await; } "CHANNEL_PINS_UPDATE" => { let new_data: types::ChannelPinsUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .channel .pins_update .update_data(new_data) .await; } "CALL_CREATE" => { let new_data: types::CallCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .call .create .update_data(new_data) .await; } "CALL_UPDATE" => { let new_data: types::CallUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .call .update 
.update_data(new_data) .await; } "CALL_DELETE" => { let new_data: types::CallDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .call .delete .update_data(new_data) .await; } "THREAD_CREATE" => { let new_data: types::ThreadCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .create .update_data(new_data) .await; } "THREAD_UPDATE" => { let new_data: types::ThreadUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .update .update_data(new_data) .await; } "THREAD_DELETE" => { let new_data: types::ThreadDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .delete .update_data(new_data) .await; } "THREAD_LIST_SYNC" => { let new_data: types::ThreadListSync = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .list_sync .update_data(new_data) .await; } "THREAD_MEMBER_UPDATE" => { let new_data: types::ThreadMemberUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .member_update .update_data(new_data) .await; } "THREAD_MEMBERS_UPDATE" => { let new_data: types::ThreadMembersUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .thread .members_update .update_data(new_data) .await; } "GUILD_CREATE" => { let new_data: types::GuildCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .create .update_data(new_data) .await; } "GUILD_UPDATE" => { let new_data: types::GuildUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .update .update_data(new_data) .await; } "GUILD_DELETE" => { let new_data: types::GuildDelete = 
serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .delete .update_data(new_data) .await; } "GUILD_AUDIT_LOG_ENTRY_CREATE" => { let new_data: types::GuildAuditLogEntryCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .audit_log_entry_create .update_data(new_data) .await; } "GUILD_BAN_ADD" => { let new_data: types::GuildBanAdd = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .ban_add .update_data(new_data) .await; } "GUILD_BAN_REMOVE" => { let new_data: types::GuildBanRemove = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .ban_remove .update_data(new_data) .await; } "GUILD_EMOJIS_UPDATE" => { let new_data: types::GuildEmojisUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .emojis_update .update_data(new_data) .await; } "GUILD_STICKERS_UPDATE" => { let new_data: types::GuildStickersUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .stickers_update .update_data(new_data) .await; } "GUILD_INTEGRATIONS_UPDATE" => { let new_data: types::GuildIntegrationsUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .integrations_update .update_data(new_data) .await; } "GUILD_MEMBER_ADD" => { let new_data: types::GuildMemberAdd = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .member_add .update_data(new_data) .await; } "GUILD_MEMBER_REMOVE" => { let new_data: types::GuildMemberRemove = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .member_remove .update_data(new_data) .await; } "GUILD_MEMBER_UPDATE" => { let 
new_data: types::GuildMemberUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .member_update .update_data(new_data) .await; } "GUILD_MEMBERS_CHUNK" => { let new_data: types::GuildMembersChunk = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .members_chunk .update_data(new_data) .await; } "GUILD_ROLE_CREATE" => { let new_data: types::GuildRoleCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_create .update_data(new_data) .await; } "GUILD_ROLE_UPDATE" => { let new_data: types::GuildRoleUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_update .update_data(new_data) .await; } "GUILD_ROLE_DELETE" => { let new_data: types::GuildRoleDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_delete .update_data(new_data) .await; } "GUILD_SCHEDULED_EVENT_CREATE" => { let new_data: types::GuildScheduledEventCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_scheduled_event_create .update_data(new_data) .await; } "GUILD_SCHEDULED_EVENT_UPDATE" => { let new_data: types::GuildScheduledEventUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_scheduled_event_update .update_data(new_data) .await; } "GUILD_SCHEDULED_EVENT_DELETE" => { let new_data: types::GuildScheduledEventDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_scheduled_event_delete .update_data(new_data) .await; } "GUILD_SCHEDULED_EVENT_USER_ADD" => { let new_data: types::GuildScheduledEventUserAdd = serde_json::from_str(gateway_payload.event_data.unwrap().get()) 
.unwrap(); self.events .lock() .await .guild .role_scheduled_event_user_add .update_data(new_data) .await; } "GUILD_SCHEDULED_EVENT_USER_REMOVE" => { let new_data: types::GuildScheduledEventUserRemove = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .role_scheduled_event_user_remove .update_data(new_data) .await; } "PASSIVE_UPDATE_V1" => { let new_data: types::PassiveUpdateV1 = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .guild .passive_update_v1 .update_data(new_data) .await; } "INTEGRATION_CREATE" => { let new_data: types::IntegrationCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .integration .create .update_data(new_data) .await; } "INTEGRATION_UPDATE" => { let new_data: types::IntegrationUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .integration .update .update_data(new_data) .await; } "INTEGRATION_DELETE" => { let new_data: types::IntegrationDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .integration .delete .update_data(new_data) .await; } "INTERACTION_CREATE" => { let new_data: types::InteractionCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .interaction .create .update_data(new_data) .await; } "INVITE_CREATE" => { let new_data: types::InviteCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .invite .create .update_data(new_data) .await; } "INVITE_DELETE" => { let new_data: types::InviteDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .invite .delete .update_data(new_data) .await; } "MESSAGE_CREATE" => { let new_data: types::MessageCreate = 
serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .create .update_data(new_data) .await; } "MESSAGE_UPDATE" => { let new_data: types::MessageUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .update .update_data(new_data) .await; } "MESSAGE_DELETE" => { let new_data: types::MessageDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .delete .update_data(new_data) .await; } "MESSAGE_DELETE_BULK" => { let new_data: types::MessageDeleteBulk = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .delete_bulk .update_data(new_data) .await; } "MESSAGE_REACTION_ADD" => { let new_data: types::MessageReactionAdd = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .reaction_add .update_data(new_data) .await; } "MESSAGE_REACTION_REMOVE" => { let new_data: types::MessageReactionRemove = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .reaction_remove .update_data(new_data) .await; } "MESSAGE_REACTION_REMOVE_ALL" => { let new_data: types::MessageReactionRemoveAll = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .reaction_remove_all .update_data(new_data) .await; } "MESSAGE_REACTION_REMOVE_EMOJI" => { let new_data: types::MessageReactionRemoveEmoji = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .reaction_remove_emoji .update_data(new_data) .await; } "MESSAGE_ACK" => { let new_data: types::MessageACK = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .message .ack .update_data(new_data) .await; } "PRESENCE_UPDATE" => { let 
new_data: types::PresenceUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .user .presence_update .update_data(new_data) .await; } "RELATIONSHIP_ADD" => { let new_data: types::RelationshipAdd = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .relationship .add .update_data(new_data) .await; } "RELATIONSHIP_REMOVE" => { let new_data: types::RelationshipRemove = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .relationship .remove .update_data(new_data) .await; } "STAGE_INSTANCE_CREATE" => { let new_data: types::StageInstanceCreate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .stage_instance .create .update_data(new_data) .await; } "STAGE_INSTANCE_UPDATE" => { let new_data: types::StageInstanceUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .stage_instance .update .update_data(new_data) .await; } "STAGE_INSTANCE_DELETE" => { let new_data: types::StageInstanceDelete = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .stage_instance .delete .update_data(new_data) .await; } "SESSIONS_REPLACE" => { let sessions: Vec = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); let new_data = types::SessionsReplace { sessions }; self.events .lock() .await .session .replace .update_data(new_data) .await; } "TYPING_START" => { let new_data: types::TypingStartEvent = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .user .typing_start_event .update_data(new_data) .await; } "USER_UPDATE" => { let new_data: types::UserUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .user .update .update_data(new_data) .await; } 
"USER_GUILD_SETTINGS_UPDATE" => { let new_data: types::UserGuildSettingsUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .user .guild_settings_update .update_data(new_data) .await; } "VOICE_STATE_UPDATE" => { let new_data: types::VoiceStateUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .voice .state_update .update_data(new_data) .await; } "VOICE_SERVER_UPDATE" => { let new_data: types::VoiceServerUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .voice .server_update .update_data(new_data) .await; } "WEBHOOKS_UPDATE" => { let new_data: types::WebhooksUpdate = serde_json::from_str(gateway_payload.event_data.unwrap().get()) .unwrap(); self.events .lock() .await .webhooks .update .update_data(new_data) .await; } _ => { println!("Received unrecognized gateway event ({})! Please open an issue on the chorus github so we can implement it", &gateway_payload_t); } } } // We received a heartbeat from the server // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." 
GATEWAY_HEARTBEAT => { println!("GW: Received Heartbeat // Heartbeat Request"); // Tell the heartbeat handler it should send a heartbeat right away let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: gateway_payload.sequence_number, op_code: Some(GATEWAY_HEARTBEAT), }; self.heartbeat_handler .send .send(heartbeat_communication) .await .unwrap(); } GATEWAY_RECONNECT => { todo!() } GATEWAY_INVALID_SESSION => { todo!() } // Starts our heartbeat // We should have already handled this in gateway init GATEWAY_HELLO => { panic!("Received hello when it was unexpected"); } GATEWAY_HEARTBEAT_ACK => { println!("GW: Received Heartbeat ACK"); // Tell the heartbeat handler we received an ack let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: gateway_payload.sequence_number, op_code: Some(GATEWAY_HEARTBEAT_ACK), }; self.heartbeat_handler .send .send(heartbeat_communication) .await .unwrap(); } GATEWAY_IDENTIFY | GATEWAY_UPDATE_PRESENCE | GATEWAY_UPDATE_VOICE_STATE | GATEWAY_RESUME | GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_CALL_SYNC | GATEWAY_LAZY_REQUEST => { let error = GatewayError::UnexpectedOpcodeReceivedError { opcode: gateway_payload.op_code, }; Err::<(), GatewayError>(error).unwrap(); } _ => { println!("Received unrecognized gateway op code ({})! 
Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); } } // If we we received a seq number we should let it know if gateway_payload.sequence_number.is_some() { let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: Some(gateway_payload.sequence_number.unwrap()), // Op code is irrelevant here op_code: None, }; self.heartbeat_handler .send .send(heartbeat_communication) .await .unwrap(); } } } /** Handles sending heartbeats to the gateway in another thread */ struct HeartbeatHandler { /// The heartbeat interval in milliseconds pub heartbeat_interval: u128, /// The send channel for the heartbeat thread pub send: Sender, /// The handle of the thread handle: JoinHandle<()>, } impl HeartbeatHandler { pub fn new( heartbeat_interval: u128, websocket_tx: Arc< Mutex< SplitSink< WebSocketStream>, tokio_tungstenite::tungstenite::Message, >, >, >, kill_rc: tokio::sync::broadcast::Receiver<()>, ) -> HeartbeatHandler { let (send, receive) = tokio::sync::mpsc::channel(32); let kill_receive = kill_rc.resubscribe(); let handle: JoinHandle<()> = task::spawn(async move { HeartbeatHandler::heartbeat_task( websocket_tx, heartbeat_interval, receive, kill_receive, ) .await; }); Self { heartbeat_interval, send, handle, } } /// The main heartbeat task; /// /// Can be killed by the kill broadcast; /// If the websocket is closed, will die out next time it tries to send a heartbeat; pub async fn heartbeat_task( websocket_tx: Arc< Mutex< SplitSink< WebSocketStream>, tokio_tungstenite::tungstenite::Message, >, >, >, heartbeat_interval: u128, mut receive: tokio::sync::mpsc::Receiver, mut kill_receive: tokio::sync::broadcast::Receiver<()>, ) { let mut last_heartbeat_timestamp: Instant = time::Instant::now(); let mut last_heartbeat_acknowledged = true; let mut last_seq_number: Option = None; loop { let should_shutdown = kill_receive.try_recv().is_ok(); if should_shutdown { break; } let mut should_send; let time_to_send = 
last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; should_send = time_to_send; let received_communication: Result = receive.try_recv(); if received_communication.is_ok() { let communication = received_communication.unwrap(); // If we received a seq number update, use that as the last seq number if communication.sequence_number.is_some() { last_seq_number = Some(communication.sequence_number.unwrap()); } if communication.op_code.is_some() { match communication.op_code.unwrap() { GATEWAY_HEARTBEAT => { // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately should_send = true; } GATEWAY_HEARTBEAT_ACK => { // The server received our heartbeat last_heartbeat_acknowledged = true; } _ => {} } } } // If the server hasn't acknowledged our heartbeat we should resend it if !last_heartbeat_acknowledged && last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT { should_send = true; println!("GW: Timed out waiting for a heartbeat ack, resending"); } if should_send { println!("GW: Sending Heartbeat.."); let heartbeat = types::GatewayHeartbeat { op: GATEWAY_HEARTBEAT, d: last_seq_number, }; let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); let send_result = websocket_tx.lock().await.send(msg).await; if send_result.is_err() { // We couldn't send, the websocket is broken println!("GW: Couldnt send heartbeat, websocket seems broken"); break; } last_heartbeat_timestamp = time::Instant::now(); last_heartbeat_acknowledged = false; } } } } /** Used for communications between the heartbeat and gateway thread. 
Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server
*/
#[derive(Clone, Copy, Debug)]
struct HeartbeatThreadCommunication {
    /// The opcode for the communication we received, if relevant
    op_code: Option<u8>,
    /// The sequence number we got from discord, if any
    sequence_number: Option<u64>,
}

/**
Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
an Observable. The Observer is notified when the Observable's data changes.
In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
 */
pub trait Observer<T: types::WebSocketEvent>: std::fmt::Debug {
    fn update(&self, data: &T);
}

/** GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
change in the WebSocketEvent. GatewayEvents are observable.
*/
#[derive(Default, Debug)]
pub struct GatewayEvent<T: types::WebSocketEvent> {
    observers: Vec<Arc<Mutex<dyn Observer<T> + Sync + Send>>>,
    pub event_data: T,
    pub is_observed: bool,
}

impl<T: types::WebSocketEvent> GatewayEvent<T> {
    fn new(event_data: T) -> Self {
        Self {
            is_observed: false,
            observers: Vec::new(),
            event_data,
        }
    }

    /**
    Returns true if the GatewayEvent is observed by at least one Observer.
     */
    pub fn is_observed(&self) -> bool {
        self.is_observed
    }

    /**
    Subscribes an Observer to the GatewayEvent. Returns an error if the GatewayEvent is already
    observed.
    # Errors
    Returns an error if the GatewayEvent is already observed.
    Error type: [`ObserverError::AlreadySubscribedError`]
     */
    pub fn subscribe(
        &mut self,
        observable: Arc<Mutex<dyn Observer<T> + Sync + Send>>,
    ) -> Result<(), ObserverError> {
        if self.is_observed {
            return Err(ObserverError::AlreadySubscribedError);
        }
        self.is_observed = true;
        self.observers.push(observable);
        Ok(())
    }

    /**
    Unsubscribes an Observer from the GatewayEvent.

    Observers are matched by identity: only entries pointing at the same allocation as
    `observable` (i.e. clones of the `Arc` that was subscribed) are removed.
     */
    pub fn unsubscribe(&mut self, observable: Arc<Mutex<dyn Observer<T> + Sync + Send>>) {
        // Compare only the data addresses: the vtable half of a `dyn` pointer is not guaranteed
        // to be unique, so comparing whole fat pointers could miss a match
        let target = Arc::as_ptr(&observable).cast::<()>();
        self.observers
            .retain(|obs| Arc::as_ptr(obs).cast::<()>() != target);
        self.is_observed = !self.observers.is_empty();
    }

    /**
    Updates the GatewayEvent's data and notifies the observers.
     */
    async fn update_data(&mut self, new_event_data: T) {
        self.event_data = new_event_data;
        self.notify().await;
    }

    /**
    Notifies the observers of the GatewayEvent.
     */
    async fn notify(&self) {
        for observer in &self.observers {
            observer.lock().await.update(&self.event_data);
        }
    }
}

mod events {
    use super::*;
    #[derive(Default, Debug)]
    pub struct Events {
        pub application: Application,
        pub auto_moderation: AutoModeration,
        pub session: Session,
        pub message: Message,
        pub user: User,
        pub relationship: Relationship,
        pub channel: Channel,
        pub thread: Thread,
        pub guild: Guild,
        pub invite: Invite,
        pub integration: Integration,
        pub interaction: Interaction,
        pub stage_instance: StageInstance,
        pub call: Call,
        pub voice: Voice,
        pub webhooks: Webhooks,
        pub gateway_identify_payload: GatewayEvent,
        pub gateway_resume: GatewayEvent,
    }

    #[derive(Default, Debug)]
    pub struct Application {
        pub command_permissions_update: GatewayEvent,
    }

    #[derive(Default, Debug)]
    pub struct AutoModeration {
        pub rule_create: GatewayEvent,
        pub rule_update: GatewayEvent,
        pub rule_delete: GatewayEvent,
        pub action_execution: GatewayEvent,
    }

    #[derive(Default, Debug)]
    pub struct Session {
        pub ready: GatewayEvent,
        pub ready_supplemental: GatewayEvent,
        pub replace: GatewayEvent,
    }

    #[derive(Default, Debug)]
    pub struct StageInstance {
        pub create: GatewayEvent,
        pub update: GatewayEvent,
        pub delete: GatewayEvent,
    }

    #[derive(Default, Debug)]
    pub struct Message {
        pub create: GatewayEvent,
        pub update: GatewayEvent,
        pub delete: GatewayEvent,
        pub delete_bulk: GatewayEvent,
        pub reaction_add: GatewayEvent,
        pub
reaction_remove: GatewayEvent, pub reaction_remove_all: GatewayEvent, pub reaction_remove_emoji: GatewayEvent, pub ack: GatewayEvent, } #[derive(Default, Debug)] pub struct User { pub update: GatewayEvent, pub guild_settings_update: GatewayEvent, pub presence_update: GatewayEvent, pub typing_start_event: GatewayEvent, } #[derive(Default, Debug)] pub struct Relationship { pub add: GatewayEvent, pub remove: GatewayEvent, } #[derive(Default, Debug)] pub struct Channel { pub create: GatewayEvent, pub update: GatewayEvent, pub unread_update: GatewayEvent, pub delete: GatewayEvent, pub pins_update: GatewayEvent, } #[derive(Default, Debug)] pub struct Thread { pub create: GatewayEvent, pub update: GatewayEvent, pub delete: GatewayEvent, pub list_sync: GatewayEvent, pub member_update: GatewayEvent, pub members_update: GatewayEvent, } #[derive(Default, Debug)] pub struct Guild { pub create: GatewayEvent, pub update: GatewayEvent, pub delete: GatewayEvent, pub audit_log_entry_create: GatewayEvent, pub ban_add: GatewayEvent, pub ban_remove: GatewayEvent, pub emojis_update: GatewayEvent, pub stickers_update: GatewayEvent, pub integrations_update: GatewayEvent, pub member_add: GatewayEvent, pub member_remove: GatewayEvent, pub member_update: GatewayEvent, pub members_chunk: GatewayEvent, pub role_create: GatewayEvent, pub role_update: GatewayEvent, pub role_delete: GatewayEvent, pub role_scheduled_event_create: GatewayEvent, pub role_scheduled_event_update: GatewayEvent, pub role_scheduled_event_delete: GatewayEvent, pub role_scheduled_event_user_add: GatewayEvent, pub role_scheduled_event_user_remove: GatewayEvent, pub passive_update_v1: GatewayEvent, } #[derive(Default, Debug)] pub struct Invite { pub create: GatewayEvent, pub delete: GatewayEvent, } #[derive(Default, Debug)] pub struct Integration { pub create: GatewayEvent, pub update: GatewayEvent, pub delete: GatewayEvent, } #[derive(Default, Debug)] pub struct Interaction { pub create: GatewayEvent, } #[derive(Default, 
Debug)] pub struct Call { pub create: GatewayEvent, pub update: GatewayEvent, pub delete: GatewayEvent, } #[derive(Default, Debug)] pub struct Voice { pub state_update: GatewayEvent, pub server_update: GatewayEvent, } #[derive(Default, Debug)] pub struct Webhooks { pub update: GatewayEvent, } } #[cfg(test)] mod example { use super::*; #[derive(Debug)] struct Consumer; impl Observer for Consumer { fn update(&self, data: &types::GatewayResume) { println!("{}", data.token) } } #[tokio::test] async fn test_observer_behavior() { let mut event = GatewayEvent::new(types::GatewayResume { token: "start".to_string(), session_id: "start".to_string(), seq: "start".to_string(), }); let new_data = types::GatewayResume { token: "token_3276ha37am3".to_string(), session_id: "89346671230".to_string(), seq: "3".to_string(), }; let consumer = Consumer; let arc_mut_consumer = Arc::new(Mutex::new(consumer)); event.subscribe(arc_mut_consumer.clone()); event.notify().await; event.update_data(new_data).await; let second_consumer = Consumer; let arc_mut_second_consumer = Arc::new(Mutex::new(second_consumer)); match event.subscribe(arc_mut_second_consumer.clone()).err() { None => assert!(false), Some(err) => println!("You cannot subscribe twice: {}", err), } event.unsubscribe(arc_mut_consumer.clone()); event.subscribe(arc_mut_second_consumer.clone()).unwrap(); } }