From 595832ef63fd2d5f75114b80871c308675825d8c Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Thu, 16 Nov 2023 20:02:01 +0100 Subject: [PATCH] Make heartbeathandler shared struct --- src/gateway/default/gateway.rs | 8 +-- src/gateway/default/heartbeat.rs | 112 ++----------------------------- src/gateway/default/mod.rs | 6 +- src/gateway/mod.rs | 97 +++++++++++++++++++++++++- 4 files changed, 103 insertions(+), 120 deletions(-) diff --git a/src/gateway/default/gateway.rs b/src/gateway/default/gateway.rs index 49ce89e..8ec825d 100644 --- a/src/gateway/default/gateway.rs +++ b/src/gateway/default/gateway.rs @@ -7,7 +7,7 @@ use crate::types::{self, WebSocketEvent}; #[derive(Debug)] pub struct DefaultGateway { events: Arc>, - heartbeat_handler: DefaultHeartbeatHandler, + heartbeat_handler: HeartbeatHandler, websocket_send: Arc< Mutex< SplitSink< @@ -28,10 +28,10 @@ impl WebSocketStream>, WebSocketStream>, DefaultGatewayHandle, - DefaultHeartbeatHandler, + HeartbeatHandler, > for DefaultGateway { - fn get_heartbeat_handler(&self) -> &DefaultHeartbeatHandler { + fn get_heartbeat_handler(&self) -> &HeartbeatHandler { &self.heartbeat_handler } @@ -95,7 +95,7 @@ impl let mut gateway = DefaultGateway { events: shared_events.clone(), - heartbeat_handler: DefaultHeartbeatHandler::new( + heartbeat_handler: HeartbeatHandler::new( Duration::from_millis(gateway_hello.heartbeat_interval), shared_websocket_send.clone(), kill_send.subscribe(), diff --git a/src/gateway/default/heartbeat.rs b/src/gateway/default/heartbeat.rs index 8a6a8b5..ed39d8a 100644 --- a/src/gateway/default/heartbeat.rs +++ b/src/gateway/default/heartbeat.rs @@ -1,22 +1,7 @@ -use crate::types; - use super::*; -/// Handles sending heartbeats to the gateway in another thread -#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used -#[derive(Debug)] -pub struct DefaultHeartbeatHandler { - /// How ofter heartbeats need to be sent at a minimum - pub heartbeat_interval: Duration, - /// The send channel for the heartbeat thread - pub send: Sender, - /// The handle of the thread - handle: JoinHandle<()>, -} - -impl HeartbeatHandlerCapable>> - for DefaultHeartbeatHandler -{ +#[async_trait] +impl HeartbeatHandlerCapable>> for HeartbeatHandler { fn new( heartbeat_interval: Duration, websocket_tx: Arc< @@ -28,12 +13,12 @@ impl HeartbeatHandlerCapable>> >, >, kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> DefaultHeartbeatHandler { + ) -> HeartbeatHandler { let (send, receive) = tokio::sync::mpsc::channel(32); let kill_receive = kill_rc.resubscribe(); let handle: JoinHandle<()> = task::spawn(async move { - DefaultHeartbeatHandler::heartbeat_task( + HeartbeatHandler::heartbeat_task( websocket_tx, heartbeat_interval, receive, @@ -57,92 +42,3 @@ impl HeartbeatHandlerCapable>> self.heartbeat_interval } } - -impl DefaultHeartbeatHandler { - /// The main heartbeat task; - /// - /// Can be killed by the kill broadcast; - /// If the websocket is closed, will die out next time it tries to send a heartbeat; - pub async fn heartbeat_task( - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, - mut kill_receive: tokio::sync::broadcast::Receiver<()>, - ) { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - - loop { - if kill_receive.try_recv().is_ok() { - trace!("GW: Closing heartbeat task"); - break; - } - - let timeout = if last_heartbeat_acknowledged { - heartbeat_interval - } else { - // If the server hasn't acknowledged our heartbeat we should resend it - Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) - }; - - let mut should_send = false; - - tokio::select! { - () = sleep_until(last_heartbeat_timestamp + timeout) => { - should_send = true; - } - Some(communication) = receive.recv() => { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; - } - - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - } - - if should_send { - trace!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx.lock().await.send(msg).await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - warn!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } - } -} diff --git a/src/gateway/default/mod.rs b/src/gateway/default/mod.rs index 8ec29e8..14e5c1e 100644 --- a/src/gateway/default/mod.rs +++ b/src/gateway/default/mod.rs @@ -5,7 +5,6 @@ pub mod heartbeat; use super::*; pub use gateway::*; pub use handle::*; -use heartbeat::*; use tokio_tungstenite::tungstenite::Message; use crate::errors::GatewayError; @@ -15,18 +14,15 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::{Arc, RwLock}; use std::time::Duration; -use tokio::time::sleep_until; use futures_util::stream::SplitSink; use futures_util::stream::SplitStream; -use log::{info, trace, warn}; +use log::{info, warn}; use tokio::net::TcpStream; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; use tokio::task; use tokio::task::JoinHandle; -use tokio::time; -use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index bdf150d..e2e3552 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -8,6 +8,8 @@ pub mod wasm; #[cfg(all(not(target_arch = "wasm32"), feature = "client"))] pub use default::*; pub use message::*; +use safina_timer::sleep_until; +use tokio::task::JoinHandle; #[cfg(all(target_arch = "wasm32", feature = "client"))] pub use wasm::*; @@ -22,7 +24,7 @@ use crate::types::{ use std::any::Any; use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use std::time::Duration; +use std::time::{self, Duration, Instant}; use async_trait::async_trait; use futures_util::stream::SplitSink; @@ -154,7 +156,7 @@ impl GatewayEvent { pub trait GatewayCapable where R: Stream, - S: Sink, + S: Sink + Send + 'static, G: GatewayHandleCapable, H: HeartbeatHandlerCapable + Send + Sync, { @@ -441,6 +443,18 @@ where } } +/// Handles sending heartbeats to the gateway in another thread +#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used +#[derive(Debug)] +pub struct HeartbeatHandler { + /// How ofter heartbeats need to be sent at a minimum + pub heartbeat_interval: Duration, + /// The send channel for the heartbeat thread + pub send: Sender, + /// The handle of the thread + handle: JoinHandle<()>, +} + #[async_trait(?Send)] pub trait GatewayHandleCapable where @@ -541,7 +555,8 @@ where async fn close(&self); } -pub trait HeartbeatHandlerCapable> { +#[async_trait] +pub trait HeartbeatHandlerCapable + Send + 'static> { fn new( heartbeat_interval: Duration, websocket_tx: Arc>>, @@ -550,4 +565,80 @@ pub trait HeartbeatHandlerCapable> { fn get_send(&self) -> &Sender; fn get_heartbeat_interval(&self) -> Duration; + async fn heartbeat_task( + websocket_tx: Arc>>, + heartbeat_interval: Duration, + mut receive: tokio::sync::mpsc::Receiver, + mut kill_receive: tokio::sync::broadcast::Receiver<()>, + ) { + let mut last_heartbeat_timestamp: Instant = time::Instant::now(); + let mut last_heartbeat_acknowledged = true; + let mut last_seq_number: Option = None; + safina_timer::start_timer_thread(); + + loop { + if kill_receive.try_recv().is_ok() { + trace!("GW: Closing heartbeat task"); + break; + } + + let timeout = if last_heartbeat_acknowledged { + heartbeat_interval + } else { + // If the server hasn't acknowledged our heartbeat we should resend it + Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) + }; + + let mut should_send = false; + + tokio::select! { + () = sleep_until(last_heartbeat_timestamp + timeout) => { + should_send = true; + } + Some(communication) = receive.recv() => { + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = communication.sequence_number; + } + + if let Some(op_code) = communication.op_code { + match op_code { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} + } + } + } + } + + if should_send { + trace!("GW: Sending Heartbeat.."); + + let heartbeat = types::GatewayHeartbeat { + op: GATEWAY_HEARTBEAT, + d: last_seq_number, + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + let send_result = websocket_tx.lock().await.send(msg).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + warn!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } + + last_heartbeat_timestamp = time::Instant::now(); + last_heartbeat_acknowledged = false; + } + } + } }