Fix gateway heartbeat blocking (#162)

fix gateway heartbeat blocking
This commit is contained in:
SpecificProtagonist 2023-07-21 13:59:40 +02:00 committed by GitHub
parent b5a8562d89
commit f97d26bc6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 40 deletions

View File

@ -10,7 +10,7 @@ backend = ["poem", "sqlx"]
client = [] client = []
[dependencies] [dependencies]
tokio = {version = "1.29.1"} tokio = {version = "1.29.1", features = ["macros"]}
serde = {version = "1.0.171", features = ["derive"]} serde = {version = "1.0.171", features = ["derive"]}
serde_json = {version= "1.0.103", features = ["raw_value"]} serde_json = {version= "1.0.103", features = ["raw_value"]}
serde-aux = "4.2.0" serde-aux = "4.2.0"

View File

@ -4,6 +4,8 @@ use crate::types;
use crate::types::WebSocketEvent; use crate::types::WebSocketEvent;
use async_trait::async_trait; use async_trait::async_trait;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream; use futures_util::stream::SplitStream;
@ -12,7 +14,6 @@ use futures_util::StreamExt;
use log::{info, trace, warn}; use log::{info, trace, warn};
use native_tls::TlsConnector; use native_tls::TlsConnector;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task; use tokio::task;
@ -71,7 +72,7 @@ const GATEWAY_CALL_SYNC: u8 = 13;
const GATEWAY_LAZY_REQUEST: u8 = 14; const GATEWAY_LAZY_REQUEST: u8 = 14;
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
const HEARTBEAT_ACK_TIMEOUT: u128 = 2000; const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
/// Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError]. /// Represents a messsage received from the gateway. This will be either a [GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages. /// This struct is used internally when handling messages.
@ -327,7 +328,7 @@ impl Gateway {
let mut gateway = Gateway { let mut gateway = Gateway {
events: shared_events.clone(), events: shared_events.clone(),
heartbeat_handler: HeartbeatHandler::new( heartbeat_handler: HeartbeatHandler::new(
gateway_hello.heartbeat_interval, Duration::from_millis(gateway_hello.heartbeat_interval),
shared_websocket_send.clone(), shared_websocket_send.clone(),
kill_send.subscribe(), kill_send.subscribe(),
), ),
@ -626,8 +627,8 @@ impl Gateway {
/// Handles sending heartbeats to the gateway in another thread /// Handles sending heartbeats to the gateway in another thread
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used #[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
struct HeartbeatHandler { struct HeartbeatHandler {
/// The heartbeat interval in milliseconds /// How ofter heartbeats need to be sent at a minimum
pub heartbeat_interval: u128, pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread /// The send channel for the heartbeat thread
pub send: Sender<HeartbeatThreadCommunication>, pub send: Sender<HeartbeatThreadCommunication>,
/// The handle of the thread /// The handle of the thread
@ -636,7 +637,7 @@ struct HeartbeatHandler {
impl HeartbeatHandler { impl HeartbeatHandler {
pub fn new( pub fn new(
heartbeat_interval: u128, heartbeat_interval: Duration,
websocket_tx: Arc< websocket_tx: Arc<
Mutex< Mutex<
SplitSink< SplitSink<
@ -680,7 +681,7 @@ impl HeartbeatHandler {
>, >,
>, >,
>, >,
heartbeat_interval: u128, heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>, mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>, mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) { ) {
@ -689,21 +690,25 @@ impl HeartbeatHandler {
let mut last_seq_number: Option<u64> = None; let mut last_seq_number: Option<u64> = None;
loop { loop {
let should_shutdown = kill_receive.try_recv().is_ok(); if kill_receive.try_recv().is_ok() {
if should_shutdown {
trace!("GW: Closing heartbeat task"); trace!("GW: Closing heartbeat task");
break; break;
} }
let mut should_send; let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; let mut should_send = false;
should_send = time_to_send; tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
let received_communication: Result<HeartbeatThreadCommunication, TryRecvError> = should_send = true;
receive.try_recv(); }
if let Ok(communication) = received_communication { Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number // If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() { if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number; last_seq_number = communication.sequence_number;
@ -723,13 +728,6 @@ impl HeartbeatHandler {
} }
} }
} }
// If the server hasn't acknowledged our heartbeat we should resend it
if !last_heartbeat_acknowledged
&& last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT
{
should_send = true;
info!("GW: Timed out waiting for a heartbeat ack, resending");
} }
if should_send { if should_send {

View File

@ -14,8 +14,7 @@ impl WebSocketEvent for GatewayHello {}
/// Contains info on how often the client should send heartbeats to the server; /// Contains info on how often the client should send heartbeats to the server;
pub struct HelloData { pub struct HelloData {
/// How often a client should send heartbeats, in milliseconds /// How often a client should send heartbeats, in milliseconds
// u128 because std used u128s for milliseconds pub heartbeat_interval: u64,
pub heartbeat_interval: u128,
} }
impl WebSocketEvent for HelloData {} impl WebSocketEvent for HelloData {}