diff --git a/Cargo.toml b/Cargo.toml index 2e51270..4126146 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ backend = ["poem", "sqlx"] client = [] [dependencies] -tokio = {version = "1.29.1"} +tokio = {version = "1.29.1", features = ["macros"]} serde = {version = "1.0.171", features = ["derive"]} serde_json = {version= "1.0.103", features = ["raw_value"]} serde-aux = "4.2.0" diff --git a/src/gateway.rs b/src/gateway.rs index 840c4b0..9e4eb9c 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -4,6 +4,8 @@ use crate::types; use crate::types::WebSocketEvent; use async_trait::async_trait; use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep_until; use futures_util::stream::SplitSink; use futures_util::stream::SplitStream; @@ -12,7 +14,6 @@ use futures_util::StreamExt; use log::{info, trace, warn}; use native_tls::TlsConnector; use tokio::net::TcpStream; -use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; use tokio::task; @@ -71,7 +72,7 @@ const GATEWAY_CALL_SYNC: u8 = 13; const GATEWAY_LAZY_REQUEST: u8 = 14; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms -const HEARTBEAT_ACK_TIMEOUT: u128 = 2000; +const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; /// Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError]. /// This struct is used internally when handling messages. @@ -327,7 +328,7 @@ impl Gateway { let mut gateway = Gateway { events: shared_events.clone(), heartbeat_handler: HeartbeatHandler::new( - gateway_hello.heartbeat_interval, + Duration::from_millis(gateway_hello.heartbeat_interval), shared_websocket_send.clone(), kill_send.subscribe(), ), @@ -626,8 +627,8 @@ impl Gateway { /// Handles sending heartbeats to the gateway in another thread #[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used struct HeartbeatHandler { - /// The heartbeat interval in milliseconds - pub heartbeat_interval: u128, + /// How ofter heartbeats need to be sent at a minimum + pub heartbeat_interval: Duration, /// The send channel for the heartbeat thread pub send: Sender, /// The handle of the thread @@ -636,7 +637,7 @@ struct HeartbeatHandler { impl HeartbeatHandler { pub fn new( - heartbeat_interval: u128, + heartbeat_interval: Duration, websocket_tx: Arc< Mutex< SplitSink< @@ -680,7 +681,7 @@ impl HeartbeatHandler { >, >, >, - heartbeat_interval: u128, + heartbeat_interval: Duration, mut receive: tokio::sync::mpsc::Receiver, mut kill_receive: tokio::sync::broadcast::Receiver<()>, ) { @@ -689,49 +690,46 @@ impl HeartbeatHandler { let mut last_seq_number: Option = None; loop { - let should_shutdown = kill_receive.try_recv().is_ok(); - if should_shutdown { + if kill_receive.try_recv().is_ok() { trace!("GW: Closing heartbeat task"); break; } - let mut should_send; + let timeout = if last_heartbeat_acknowledged { + heartbeat_interval + } else { + // If the server hasn't acknowledged our heartbeat we should resend it + Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) + }; - let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; + let mut should_send = false; - should_send = time_to_send; - - let received_communication: Result = - receive.try_recv(); - if let Ok(communication) = received_communication { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; + tokio::select! { + () = sleep_until(last_heartbeat_timestamp + timeout) => { + should_send = true; } + Some(communication) = receive.recv() => { + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = communication.sequence_number; + } - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; + if let Some(op_code) = communication.op_code { + match op_code { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} } } } - // If the server hasn't acknowledged our heartbeat we should resend it - if !last_heartbeat_acknowledged - && last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT - { - should_send = true; - info!("GW: Timed out waiting for a heartbeat ack, resending"); - } - if should_send { trace!("GW: Sending Heartbeat.."); diff --git a/src/types/events/hello.rs b/src/types/events/hello.rs index e370f9c..44f1e4f 100644 --- a/src/types/events/hello.rs +++ b/src/types/events/hello.rs @@ -14,8 +14,7 @@ impl WebSocketEvent for GatewayHello {} /// Contains info on how often the client should send heartbeats to the server; pub struct HelloData { /// How often a client should send heartbeats, in milliseconds - // u128 because std used u128s for milliseconds - pub heartbeat_interval: u128, + pub heartbeat_interval: u64, } impl WebSocketEvent for HelloData {}