Gateway basic error handling

This commit is contained in:
kozabrada123 2023-06-08 17:34:52 +02:00
parent 7d17a1c976
commit f4f17f7454
2 changed files with 219 additions and 29 deletions

View File

@ -31,3 +31,35 @@ custom_error! {
pub ObserverError pub ObserverError
AlreadySubscribedError = "Each event can only be subscribed to once." AlreadySubscribedError = "Each event can only be subscribed to once."
} }
custom_error! {
/// For errors we receive from the gateway, see https://discord-userdoccers.vercel.app/topics/opcodes-and-status-codes#gateway-close-event-codes;
///
/// Supposed to be sent as numbers, though they are sent as string most of the time?
///
/// Also includes errors when initiating a connection and unexpected opcodes
#[derive(PartialEq, Eq)]
pub GatewayError
// Errors we have received from the gateway
UnknownError = "We're not sure what went wrong. Try reconnecting?",
UnknownOpcodeError = "You sent an invalid Gateway opcode or an invalid payload for an opcode",
DecodeError = "Gateway server couldn't decode payload",
NotAuthenticatedError = "You sent a payload prior to identifying",
AuthenticationFailedError = "The account token sent with your identify payload is invalid",
AlreadyAuthenticatedError = "You've already identified, no need to reauthenticate",
InvalidSequenceNumberError = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session",
RateLimitedError = "You are being rate limited!",
SessionTimedOutError = "Your session timed out. Reconnect and start a new one",
InvalidShardError = "You sent us an invalid shard when identifying",
ShardingRequiredError = "The session would have handled too many guilds - you are required to shard your connection in order to connect",
InvalidAPIVersionError = "You sent an invalid Gateway version",
InvalidIntentsError = "You sent an invalid intent",
DisallowedIntentsError = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for",
// Errors when initiating a gateway connection
CannotConnectError{error: String} = "Cannot connect due to a tungstenite error: {error}",
NonHelloOnInitiateError{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
// Other misc errors
UnexpectedOpcodeReceivedError{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}",
}

View File

@ -1,3 +1,4 @@
use crate::errors::GatewayError;
use crate::errors::ObserverError; use crate::errors::ObserverError;
use crate::gateway::events::Events; use crate::gateway::events::Events;
use crate::types; use crate::types;
@ -69,6 +70,106 @@ const GATEWAY_LAZY_REQUEST: u8 = 14;
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
const HEARTBEAT_ACK_TIMEOUT: u128 = 2000; const HEARTBEAT_ACK_TIMEOUT: u128 = 2000;
#[derive(Clone, Debug)]
/**
Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError].
This struct is used internally when handling messages.
*/
pub struct GatewayMessage {
/// The message we received from the server
message: tokio_tungstenite::tungstenite::Message,
}
impl GatewayMessage {
/// Creates self from a tungstenite message
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
Self { message }
}
/// Parses the message as an error;
/// Returns the error if succesfully parsed, None if the message isn't an error
pub fn error(&self) -> Option<GatewayError> {
let content = self.message.to_string();
// Some error strings have dots on the end, which we don't care about
let processed_content = content.clone().to_lowercase().replace(".", "");
match processed_content.as_str() {
"unknown error" | "4000" => {
return Some(GatewayError::UnknownError);
}
"unknown opcode" | "4001" => {
return Some(GatewayError::UnknownOpcodeError);
}
"decode error" | "4002" => {
return Some(GatewayError::DecodeError);
}
"not authenticated" | "4003" => {
return Some(GatewayError::NotAuthenticatedError);
}
"authentication failed" | "4004" => {
return Some(GatewayError::AuthenticationFailedError);
}
"already authenticated" | "4005" => {
return Some(GatewayError::AlreadyAuthenticatedError);
}
"invalid seq" | "4007" => {
return Some(GatewayError::InvalidSequenceNumberError);
}
"rate limited" | "4008" => {
return Some(GatewayError::RateLimitedError);
}
"session timed out" | "4009" => {
return Some(GatewayError::SessionTimedOutError);
}
"invalid shard" | "4010" => {
return Some(GatewayError::InvalidShardError);
}
"sharding required" | "4011" => {
return Some(GatewayError::ShardingRequiredError);
}
"invalid api version" | "4012" => {
return Some(GatewayError::InvalidAPIVersionError);
}
"invalid intent(s)" | "invalid intent" | "4013" => {
return Some(GatewayError::InvalidIntentsError);
}
"disallowed intent(s)" | "disallowed intents" | "4014" => {
return Some(GatewayError::DisallowedIntentsError);
}
_ => {
return None;
}
}
}
/// Returns whether or not the message is an error
pub fn is_error(&self) -> bool {
return self.error().is_some();
}
/// Parses the message as a payload;
/// Returns a result of deserializing
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
return serde_json::from_str(self.message.to_text().unwrap());
}
/// Returns whether or not the message is a payload
pub fn is_payload(&self) -> bool {
// close messages are never payloads, payloads are only text messages
if self.message.is_close() | !self.message.is_text() {
return false;
}
return self.payload().is_ok();
}
/// Returns whether or not the message is empty
pub fn is_empty(&self) -> bool {
return self.message.is_empty();
}
}
#[derive(Debug)] #[derive(Debug)]
/** /**
Represents a handle to a Gateway connection. A Gateway connection will create observable Represents a handle to a Gateway connection. A Gateway connection will create observable
@ -88,6 +189,8 @@ pub struct GatewayHandle {
>, >,
>, >,
pub handle: JoinHandle<()>, pub handle: JoinHandle<()>,
/// Tells gateway tasks to close
kill_send: tokio::sync::broadcast::Sender<()>,
} }
impl GatewayHandle { impl GatewayHandle {
@ -177,6 +280,12 @@ impl GatewayHandle {
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
.await; .await;
} }
/// Closes the websocket connection and stops all gateway tasks
async fn close(&mut self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
}
} }
pub struct Gateway { pub struct Gateway {
@ -190,12 +299,11 @@ pub struct Gateway {
>, >,
>, >,
>, >,
kill_send: tokio::sync::broadcast::Sender<()>,
} }
impl Gateway { impl Gateway {
pub async fn new( pub async fn new(websocket_url: String) -> Result<GatewayHandle, GatewayError> {
websocket_url: String,
) -> Result<GatewayHandle, tokio_tungstenite::tungstenite::Error> {
let (websocket_stream, _) = match connect_async_tls_with_config( let (websocket_stream, _) = match connect_async_tls_with_config(
&websocket_url, &websocket_url,
None, None,
@ -207,34 +315,39 @@ impl Gateway {
.await .await
{ {
Ok(websocket_stream) => websocket_stream, Ok(websocket_stream) => websocket_stream,
Err(e) => return Err(e), Err(e) => {
return Err(GatewayError::CannotConnectError {
error: e.to_string(),
})
}
}; };
let (gateway_send, mut gateway_receive) = websocket_stream.split(); let (websocket_send, mut websocket_receive) = websocket_stream.split();
let shared_gateway_send = Arc::new(Mutex::new(gateway_send)); let shared_websocket_send = Arc::new(Mutex::new(websocket_send));
// Create a shared broadcast channel for killing all gateway tasks
let (kill_send, mut kill_receive) = tokio::sync::broadcast::channel::<()>(16);
let mut gateway = Gateway { let mut gateway = Gateway {
events: Arc::new(Mutex::new(Events::default())), events: Arc::new(Mutex::new(Events::default())),
heartbeat_handler: None, heartbeat_handler: None,
websocket_send: shared_gateway_send.clone(), websocket_send: shared_websocket_send.clone(),
kill_send: kill_send.clone(),
}; };
let shared_events = gateway.events.clone(); let shared_events = gateway.events.clone();
// Wait for the first hello and then spawn both tasks so we avoid nested tasks // Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread // This automatically spawns the heartbeat task, but from the main thread
let msg = gateway_receive.next().await.unwrap().unwrap(); let msg = websocket_receive.next().await.unwrap().unwrap();
let gateway_payload: types::GatewayReceivePayload = let gateway_payload: types::GatewayReceivePayload =
serde_json::from_str(msg.to_text().unwrap()).unwrap(); serde_json::from_str(msg.to_text().unwrap()).unwrap();
if gateway_payload.op_code != GATEWAY_HELLO { if gateway_payload.op_code != GATEWAY_HELLO {
println!("Received non hello on gateway init, what is happening?"); return Err(GatewayError::NonHelloOnInitiateError {
return Err(tokio_tungstenite::tungstenite::Error::Protocol( opcode: gateway_payload.op_code,
tokio_tungstenite::tungstenite::error::ProtocolError::InvalidOpcode( });
gateway_payload.op_code,
),
));
} }
println!("GW: Received Hello"); println!("GW: Received Hello");
@ -243,36 +356,69 @@ impl Gateway {
serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap();
gateway.heartbeat_handler = Some(HeartbeatHandler::new( gateway.heartbeat_handler = Some(HeartbeatHandler::new(
gateway_hello.heartbeat_interval, gateway_hello.heartbeat_interval,
shared_gateway_send.clone(), shared_websocket_send.clone(),
kill_send.subscribe(),
)); ));
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello // Now we can continuously check for messages in a different task, since we aren't going to receive another hello
let handle: JoinHandle<()> = task::spawn(async move { let handle: JoinHandle<()> = task::spawn(async move {
loop { loop {
let msg = gateway_receive.next().await; let msg = websocket_receive.next().await;
// This if chain can be much better but if let is unstable on stable rust
if msg.as_ref().is_some() { if msg.as_ref().is_some() {
let msg_unwrapped = msg.unwrap().unwrap(); if msg.as_ref().unwrap().is_ok() {
gateway.handle_event(msg_unwrapped).await; let msg_unwrapped = msg.unwrap().unwrap();
}; gateway
.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped))
.await;
continue;
}
}
// We couldn't receive the next message or it was an error, something is wrong with the websocket, close
println!("GW: Websocket is broken, stopping gateway");
break;
} }
}); });
return Ok(GatewayHandle { return Ok(GatewayHandle {
url: websocket_url.clone(), url: websocket_url.clone(),
events: shared_events, events: shared_events,
websocket_send: shared_gateway_send.clone(), websocket_send: shared_websocket_send.clone(),
handle, handle,
kill_send: kill_send.clone(),
}); });
} }
/// Closes the websocket connection and stops all tasks
async fn close(&mut self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
}
/// This handles a message as a websocket event and updates its events along with the events' observers /// This handles a message as a websocket event and updates its events along with the events' observers
pub async fn handle_event(&mut self, msg: tokio_tungstenite::tungstenite::Message) { pub async fn handle_event(&mut self, msg: GatewayMessage) {
if msg.to_string() == String::new() { if msg.is_empty() {
return; return;
} }
let gateway_payload: types::GatewayReceivePayload = // To:do: handle errors in a good way, maybe observers like events?
serde_json::from_str(msg.to_text().unwrap()).unwrap(); if msg.is_error() {
println!("GW: Received error, connection will close..");
let error = msg.error();
match error {
_ => {}
}
self.close().await;
return;
}
let gateway_payload = msg.payload().unwrap();
// See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes
match gateway_payload.op_code { match gateway_payload.op_code {
@ -1210,10 +1356,10 @@ impl Gateway {
| GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_REQUEST_GUILD_MEMBERS
| GATEWAY_CALL_SYNC | GATEWAY_CALL_SYNC
| GATEWAY_LAZY_REQUEST => { | GATEWAY_LAZY_REQUEST => {
panic!( let error = GatewayError::UnexpectedOpcodeReceivedError {
"Received gateway op code that's meant to be sent, not received ({})", opcode: gateway_payload.op_code,
gateway_payload.op_code };
) Err::<(), GatewayError>(error).unwrap();
} }
_ => { _ => {
println!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); println!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
@ -1262,8 +1408,10 @@ impl HeartbeatHandler {
>, >,
>, >,
>, >,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler { ) -> HeartbeatHandler {
let (send, mut receive) = mpsc::channel(32); let (send, mut receive) = mpsc::channel(32);
let mut kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move { let handle: JoinHandle<()> = task::spawn(async move {
let mut last_heartbeat_timestamp: Instant = time::Instant::now(); let mut last_heartbeat_timestamp: Instant = time::Instant::now();
@ -1271,6 +1419,11 @@ impl HeartbeatHandler {
let mut last_seq_number: Option<u64> = None; let mut last_seq_number: Option<u64> = None;
loop { loop {
let should_shutdown = kill_receive.try_recv().is_ok();
if should_shutdown {
break;
}
let mut should_send; let mut should_send;
let time_to_send = let time_to_send =
@ -1323,7 +1476,12 @@ impl HeartbeatHandler {
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
websocket_tx.lock().await.send(msg).await.unwrap(); let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
println!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now(); last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false; last_heartbeat_acknowledged = false;