From da98f3c5818d864a48615425b66770788f288ac7 Mon Sep 17 00:00:00 2001 From: kozabrada123 <59031733+kozabrada123@users.noreply.github.com> Date: Thu, 8 Jun 2023 18:24:11 +0200 Subject: [PATCH] Refactor --- src/gateway.rs | 274 ++++++++++++++++++++++++++----------------------- 1 file changed, 146 insertions(+), 128 deletions(-) diff --git a/src/gateway.rs b/src/gateway.rs index 607fabf..d826bea 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -3,12 +3,12 @@ use crate::errors::ObserverError; use crate::gateway::events::Events; use crate::types; use futures_util::stream::SplitSink; +use futures_util::stream::SplitStream; use futures_util::SinkExt; use futures_util::StreamExt; use native_tls::TlsConnector; use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::mpsc; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; @@ -282,7 +282,7 @@ impl GatewayHandle { } /// Closes the websocket connection and stops all gateway tasks - async fn close(&mut self) { + pub async fn close(&mut self) { self.kill_send.send(()).unwrap(); self.websocket_send.lock().await.close().await.unwrap(); } @@ -290,7 +290,7 @@ impl GatewayHandle { pub struct Gateway { pub events: Arc>, - heartbeat_handler: Option, + heartbeat_handler: HeartbeatHandler, pub websocket_send: Arc< Mutex< SplitSink< @@ -299,6 +299,7 @@ pub struct Gateway { >, >, >, + pub websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, } @@ -327,16 +328,7 @@ impl Gateway { let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); // Create a shared broadcast channel for killing all gateway tasks - let (kill_send, mut kill_receive) = tokio::sync::broadcast::channel::<()>(16); - - let mut gateway = Gateway { - events: Arc::new(Mutex::new(Events::default())), - heartbeat_handler: None, - websocket_send: shared_websocket_send.clone(), - kill_send: kill_send.clone(), - }; - - let shared_events = gateway.events.clone(); + let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16); // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread @@ -354,33 +346,24 @@ impl Gateway { let gateway_hello: types::HelloData = serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); - gateway.heartbeat_handler = Some(HeartbeatHandler::new( - gateway_hello.heartbeat_interval, - shared_websocket_send.clone(), - kill_send.subscribe(), - )); + + let mut gateway = Gateway { + events: Arc::new(Mutex::new(Events::default())), + heartbeat_handler: HeartbeatHandler::new( + gateway_hello.heartbeat_interval, + shared_websocket_send.clone(), + kill_send.subscribe(), + ), + websocket_send: shared_websocket_send.clone(), + websocket_receive, + kill_send: kill_send.clone(), + }; + + let shared_events = gateway.events.clone(); // Now we can continuously check for messages in a different task, since we aren't going to receive another hello let handle: JoinHandle<()> = task::spawn(async move { - loop { - let msg = websocket_receive.next().await; - - // This if chain can be much better but if let is unstable on stable rust - if msg.as_ref().is_some() { - if msg.as_ref().unwrap().is_ok() { - let msg_unwrapped = msg.unwrap().unwrap(); - gateway - .handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped)) - .await; - - continue; - } - } - - // We couldn't receive the next message or it was an error, something is wrong with the websocket, close - println!("GW: Websocket is broken, stopping gateway"); - break; - } + gateway.gateway_listen_task().await; }); return Ok(GatewayHandle { @@ -392,6 +375,30 @@ impl Gateway { }); } + /// The main gateway listener task; + /// + /// Can only be stopped by closing the websocket, cannot be made to listen for kill + pub async fn gateway_listen_task(&mut self) { + loop { + let msg = self.websocket_receive.next().await; + + // This if chain can be much better but if let is unstable on stable rust + if msg.as_ref().is_some() { + if msg.as_ref().unwrap().is_ok() { + let msg_unwrapped = msg.unwrap().unwrap(); + self.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped)) + .await; + + continue; + } + } + + // We couldn't receive the next message or it was an error, something is wrong with the websocket, close + println!("GW: Websocket is broken, stopping gateway"); + break; + } + } + /// Closes the websocket connection and stops all tasks async fn close(&mut self) { self.kill_send.send(()).unwrap(); @@ -1297,10 +1304,6 @@ impl Gateway { GATEWAY_HEARTBEAT => { println!("GW: Received Heartbeat // Heartbeat Request"); - if self.heartbeat_handler.is_none() { - return; - } - // Tell the heartbeat handler it should send a heartbeat right away let heartbeat_communication = HeartbeatThreadCommunication { @@ -1309,8 +1312,6 @@ impl Gateway { }; self.heartbeat_handler - .as_mut() - .unwrap() .send .send(heartbeat_communication) .await @@ -1330,10 +1331,6 @@ impl Gateway { GATEWAY_HEARTBEAT_ACK => { println!("GW: Received Heartbeat ACK"); - if self.heartbeat_handler.is_none() { - return; - } - // Tell the heartbeat handler we received an ack let heartbeat_communication = HeartbeatThreadCommunication { @@ -1342,8 +1339,6 @@ impl Gateway { }; self.heartbeat_handler - .as_mut() - .unwrap() .send .send(heartbeat_communication) .await @@ -1366,8 +1361,8 @@ impl Gateway { } } - // If we have an active heartbeat thread and we received a seq number we should let it know - if gateway_payload.sequence_number.is_some() && self.heartbeat_handler.is_some() { + // If we we received a seq number we should let it know + if gateway_payload.sequence_number.is_some() { let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: Some(gateway_payload.sequence_number.unwrap()), // Op code is irrelevant here @@ -1375,8 +1370,6 @@ impl Gateway { }; self.heartbeat_handler - .as_mut() - .unwrap() .send .send(heartbeat_communication) .await @@ -1410,83 +1403,17 @@ impl HeartbeatHandler { >, kill_rc: tokio::sync::broadcast::Receiver<()>, ) -> HeartbeatHandler { - let (send, mut receive) = mpsc::channel(32); - let mut kill_receive = kill_rc.resubscribe(); + let (send, receive) = tokio::sync::mpsc::channel(32); + let kill_receive = kill_rc.resubscribe(); let handle: JoinHandle<()> = task::spawn(async move { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - - loop { - let should_shutdown = kill_receive.try_recv().is_ok(); - if should_shutdown { - break; - } - - let mut should_send; - - let time_to_send = - last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; - - should_send = time_to_send; - - let received_communication: Result = - receive.try_recv(); - if received_communication.is_ok() { - let communication = received_communication.unwrap(); - - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = Some(communication.sequence_number.unwrap()); - } - - if communication.op_code.is_some() { - match communication.op_code.unwrap() { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - - // If the server hasn't acknowledged our heartbeat we should resend it - if !last_heartbeat_acknowledged - && last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT - { - should_send = true; - println!("GW: Timed out waiting for a heartbeat ack, resending"); - } - - if should_send { - println!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx.lock().await.send(msg).await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - println!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } + HeartbeatHandler::heartbeat_task( + websocket_tx, + heartbeat_interval, + receive, + kill_receive, + ) + .await; }); Self { @@ -1495,6 +1422,97 @@ impl HeartbeatHandler { handle, } } + + /// The main heartbeat task; + /// + /// Can be killed by the kill broadcast; + /// If the websocket is closed, will die out next time it tries to send a heartbeat; + pub async fn heartbeat_task( + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + heartbeat_interval: u128, + mut receive: tokio::sync::mpsc::Receiver, + mut kill_receive: tokio::sync::broadcast::Receiver<()>, + ) { + let mut last_heartbeat_timestamp: Instant = time::Instant::now(); + let mut last_heartbeat_acknowledged = true; + let mut last_seq_number: Option = None; + + loop { + let should_shutdown = kill_receive.try_recv().is_ok(); + if should_shutdown { + break; + } + + let mut should_send; + + let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; + + should_send = time_to_send; + + let received_communication: Result = + receive.try_recv(); + if received_communication.is_ok() { + let communication = received_communication.unwrap(); + + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = Some(communication.sequence_number.unwrap()); + } + + if communication.op_code.is_some() { + match communication.op_code.unwrap() { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} + } + } + } + + // If the server hasn't acknowledged our heartbeat we should resend it + if !last_heartbeat_acknowledged + && last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT + { + should_send = true; + println!("GW: Timed out waiting for a heartbeat ack, resending"); + } + + if should_send { + println!("GW: Sending Heartbeat.."); + + let heartbeat = types::GatewayHeartbeat { + op: GATEWAY_HEARTBEAT, + d: last_seq_number, + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + let send_result = websocket_tx.lock().await.send(msg).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + println!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } + + last_heartbeat_timestamp = time::Instant::now(); + last_heartbeat_acknowledged = false; + } + } + } } /**