diff --git a/src/errors.rs b/src/errors.rs index 749c8f7..a2843aa 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -31,3 +31,35 @@ custom_error! { pub ObserverError AlreadySubscribedError = "Each event can only be subscribed to once." } + +custom_error! { + /// For errors we receive from the gateway, see https://discord-userdoccers.vercel.app/topics/opcodes-and-status-codes#gateway-close-event-codes; + /// + /// Supposed to be sent as numbers, though they are sent as string most of the time? + /// + /// Also includes errors when initiating a connection and unexpected opcodes + #[derive(PartialEq, Eq)] + pub GatewayError + // Errors we have received from the gateway + UnknownError = "We're not sure what went wrong. Try reconnecting?", + UnknownOpcodeError = "You sent an invalid Gateway opcode or an invalid payload for an opcode", + DecodeError = "Gateway server couldn't decode payload", + NotAuthenticatedError = "You sent a payload prior to identifying", + AuthenticationFailedError = "The account token sent with your identify payload is invalid", + AlreadyAuthenticatedError = "You've already identified, no need to reauthenticate", + InvalidSequenceNumberError = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session", + RateLimitedError = "You are being rate limited!", + SessionTimedOutError = "Your session timed out. Reconnect and start a new one", + InvalidShardError = "You sent us an invalid shard when identifying", + ShardingRequiredError = "The session would have handled too many guilds - you are required to shard your connection in order to connect", + InvalidAPIVersionError = "You sent an invalid Gateway version", + InvalidIntentsError = "You sent an invalid intent", + DisallowedIntentsError = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for", + + // Errors when initiating a gateway connection + CannotConnectError{error: String} = "Cannot connect due to a tungstenite error: {error}", + NonHelloOnInitiateError{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", + + // Other misc errors + UnexpectedOpcodeReceivedError{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}", +} diff --git a/src/gateway.rs b/src/gateway.rs index ae696f7..607fabf 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,3 +1,4 @@ +use crate::errors::GatewayError; use crate::errors::ObserverError; use crate::gateway::events::Events; use crate::types; @@ -69,6 +70,106 @@ const GATEWAY_LAZY_REQUEST: u8 = 14; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms const HEARTBEAT_ACK_TIMEOUT: u128 = 2000; +#[derive(Clone, Debug)] +/** +Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError]. +This struct is used internally when handling messages. +*/ +pub struct GatewayMessage { + /// The message we received from the server + message: tokio_tungstenite::tungstenite::Message, +} + +impl GatewayMessage { + /// Creates self from a tungstenite message + pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { + Self { message } + } + + /// Parses the message as an error; + /// Returns the error if succesfully parsed, None if the message isn't an error + pub fn error(&self) -> Option { + let content = self.message.to_string(); + + // Some error strings have dots on the end, which we don't care about + let processed_content = content.clone().to_lowercase().replace(".", ""); + + match processed_content.as_str() { + "unknown error" | "4000" => { + return Some(GatewayError::UnknownError); + } + "unknown opcode" | "4001" => { + return Some(GatewayError::UnknownOpcodeError); + } + "decode error" | "4002" => { + return Some(GatewayError::DecodeError); + } + "not authenticated" | "4003" => { + return Some(GatewayError::NotAuthenticatedError); + } + "authentication failed" | "4004" => { + return Some(GatewayError::AuthenticationFailedError); + } + "already authenticated" | "4005" => { + return Some(GatewayError::AlreadyAuthenticatedError); + } + "invalid seq" | "4007" => { + return Some(GatewayError::InvalidSequenceNumberError); + } + "rate limited" | "4008" => { + return Some(GatewayError::RateLimitedError); + } + "session timed out" | "4009" => { + return Some(GatewayError::SessionTimedOutError); + } + "invalid shard" | "4010" => { + return Some(GatewayError::InvalidShardError); + } + "sharding required" | "4011" => { + return Some(GatewayError::ShardingRequiredError); + } + "invalid api version" | "4012" => { + return Some(GatewayError::InvalidAPIVersionError); + } + "invalid intent(s)" | "invalid intent" | "4013" => { + return Some(GatewayError::InvalidIntentsError); + } + "disallowed intent(s)" | "disallowed intents" | "4014" => { + return Some(GatewayError::DisallowedIntentsError); + } + _ => { + return None; + } + } + } + + /// Returns whether or not the message is an error + pub fn is_error(&self) -> bool { + return self.error().is_some(); + } + + /// Parses the message as a payload; + /// Returns a result of deserializing + pub fn payload(&self) -> Result { + return serde_json::from_str(self.message.to_text().unwrap()); + } + + /// Returns whether or not the message is a payload + pub fn is_payload(&self) -> bool { + // close messages are never payloads, payloads are only text messages + if self.message.is_close() | !self.message.is_text() { + return false; + } + + return self.payload().is_ok(); + } + + /// Returns whether or not the message is empty + pub fn is_empty(&self) -> bool { + return self.message.is_empty(); + } +} + #[derive(Debug)] /** Represents a handle to a Gateway connection. A Gateway connection will create observable @@ -88,6 +189,8 @@ pub struct GatewayHandle { >, >, pub handle: JoinHandle<()>, + /// Tells gateway tasks to close + kill_send: tokio::sync::broadcast::Sender<()>, } impl GatewayHandle { @@ -177,6 +280,12 @@ impl GatewayHandle { self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) .await; } + + /// Closes the websocket connection and stops all gateway tasks + async fn close(&mut self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } } pub struct Gateway { @@ -190,12 +299,11 @@ pub struct Gateway { >, >, >, + kill_send: tokio::sync::broadcast::Sender<()>, } impl Gateway { - pub async fn new( - websocket_url: String, - ) -> Result { + pub async fn new(websocket_url: String) -> Result { let (websocket_stream, _) = match connect_async_tls_with_config( &websocket_url, None, @@ -207,34 +315,39 @@ impl Gateway { .await { Ok(websocket_stream) => websocket_stream, - Err(e) => return Err(e), + Err(e) => { + return Err(GatewayError::CannotConnectError { + error: e.to_string(), + }) + } }; - let (gateway_send, mut gateway_receive) = websocket_stream.split(); + let (websocket_send, mut websocket_receive) = websocket_stream.split(); - let shared_gateway_send = Arc::new(Mutex::new(gateway_send)); + let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); + + // Create a shared broadcast channel for killing all gateway tasks + let (kill_send, mut kill_receive) = tokio::sync::broadcast::channel::<()>(16); let mut gateway = Gateway { events: Arc::new(Mutex::new(Events::default())), heartbeat_handler: None, - websocket_send: shared_gateway_send.clone(), + websocket_send: shared_websocket_send.clone(), + kill_send: kill_send.clone(), }; let shared_events = gateway.events.clone(); // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread - let msg = gateway_receive.next().await.unwrap().unwrap(); + let msg = websocket_receive.next().await.unwrap().unwrap(); let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(msg.to_text().unwrap()).unwrap(); if gateway_payload.op_code != GATEWAY_HELLO { - println!("Received non hello on gateway init, what is happening?"); - return Err(tokio_tungstenite::tungstenite::Error::Protocol( - tokio_tungstenite::tungstenite::error::ProtocolError::InvalidOpcode( - gateway_payload.op_code, - ), - )); + return Err(GatewayError::NonHelloOnInitiateError { + opcode: gateway_payload.op_code, + }); } println!("GW: Received Hello"); @@ -243,36 +356,69 @@ impl Gateway { serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); gateway.heartbeat_handler = Some(HeartbeatHandler::new( gateway_hello.heartbeat_interval, - shared_gateway_send.clone(), + shared_websocket_send.clone(), + kill_send.subscribe(), )); // Now we can continuously check for messages in a different task, since we aren't going to receive another hello let handle: JoinHandle<()> = task::spawn(async move { loop { - let msg = gateway_receive.next().await; + let msg = websocket_receive.next().await; + + // This if chain can be much better but if let is unstable on stable rust if msg.as_ref().is_some() { - let msg_unwrapped = msg.unwrap().unwrap(); - gateway.handle_event(msg_unwrapped).await; - }; + if msg.as_ref().unwrap().is_ok() { + let msg_unwrapped = msg.unwrap().unwrap(); + gateway + .handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped)) + .await; + + continue; + } + } + + // We couldn't receive the next message or it was an error, something is wrong with the websocket, close + println!("GW: Websocket is broken, stopping gateway"); + break; } }); return Ok(GatewayHandle { url: websocket_url.clone(), events: shared_events, - websocket_send: shared_gateway_send.clone(), + websocket_send: shared_websocket_send.clone(), handle, + kill_send: kill_send.clone(), }); } + /// Closes the websocket connection and stops all tasks + async fn close(&mut self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } + /// This handles a message as a websocket event and updates its events along with the events' observers - pub async fn handle_event(&mut self, msg: tokio_tungstenite::tungstenite::Message) { - if msg.to_string() == String::new() { + pub async fn handle_event(&mut self, msg: GatewayMessage) { + if msg.is_empty() { return; } - let gateway_payload: types::GatewayReceivePayload = - serde_json::from_str(msg.to_text().unwrap()).unwrap(); + // To:do: handle errors in a good way, maybe observers like events? + if msg.is_error() { + println!("GW: Received error, connection will close.."); + + let error = msg.error(); + + match error { + _ => {} + } + + self.close().await; + return; + } + + let gateway_payload = msg.payload().unwrap(); // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes match gateway_payload.op_code { @@ -1210,10 +1356,10 @@ impl Gateway { | GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_CALL_SYNC | GATEWAY_LAZY_REQUEST => { - panic!( - "Received gateway op code that's meant to be sent, not received ({})", - gateway_payload.op_code - ) + let error = GatewayError::UnexpectedOpcodeReceivedError { + opcode: gateway_payload.op_code, + }; + Err::<(), GatewayError>(error).unwrap(); } _ => { println!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); @@ -1262,8 +1408,10 @@ impl HeartbeatHandler { >, >, >, + kill_rc: tokio::sync::broadcast::Receiver<()>, ) -> HeartbeatHandler { let (send, mut receive) = mpsc::channel(32); + let mut kill_receive = kill_rc.resubscribe(); let handle: JoinHandle<()> = task::spawn(async move { let mut last_heartbeat_timestamp: Instant = time::Instant::now(); @@ -1271,6 +1419,11 @@ impl HeartbeatHandler { let mut last_seq_number: Option = None; loop { + let should_shutdown = kill_receive.try_recv().is_ok(); + if should_shutdown { + break; + } + let mut should_send; let time_to_send = @@ -1323,7 +1476,12 @@ impl HeartbeatHandler { let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - websocket_tx.lock().await.send(msg).await.unwrap(); + let send_result = websocket_tx.lock().await.send(msg).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + println!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } last_heartbeat_timestamp = time::Instant::now(); last_heartbeat_acknowledged = false;