Make heartbeathandler shared struct

This commit is contained in:
bitfl0wer 2023-11-16 20:02:01 +01:00
parent 0a8105d7fc
commit 595832ef63
No known key found for this signature in database
GPG Key ID: 0ACD574FCF5226CF
4 changed files with 103 additions and 120 deletions

View File

@ -7,7 +7,7 @@ use crate::types::{self, WebSocketEvent};
#[derive(Debug)] #[derive(Debug)]
pub struct DefaultGateway { pub struct DefaultGateway {
events: Arc<Mutex<Events>>, events: Arc<Mutex<Events>>,
heartbeat_handler: DefaultHeartbeatHandler, heartbeat_handler: HeartbeatHandler,
websocket_send: Arc< websocket_send: Arc<
Mutex< Mutex<
SplitSink< SplitSink<
@ -28,10 +28,10 @@ impl
WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketStream<MaybeTlsStream<TcpStream>>,
WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketStream<MaybeTlsStream<TcpStream>>,
DefaultGatewayHandle, DefaultGatewayHandle,
DefaultHeartbeatHandler, HeartbeatHandler,
> for DefaultGateway > for DefaultGateway
{ {
fn get_heartbeat_handler(&self) -> &DefaultHeartbeatHandler { fn get_heartbeat_handler(&self) -> &HeartbeatHandler {
&self.heartbeat_handler &self.heartbeat_handler
} }
@ -95,7 +95,7 @@ impl
let mut gateway = DefaultGateway { let mut gateway = DefaultGateway {
events: shared_events.clone(), events: shared_events.clone(),
heartbeat_handler: DefaultHeartbeatHandler::new( heartbeat_handler: HeartbeatHandler::new(
Duration::from_millis(gateway_hello.heartbeat_interval), Duration::from_millis(gateway_hello.heartbeat_interval),
shared_websocket_send.clone(), shared_websocket_send.clone(),
kill_send.subscribe(), kill_send.subscribe(),

View File

@ -1,22 +1,7 @@
use crate::types;
use super::*; use super::*;
/// Handles sending heartbeats to the gateway in another thread #[async_trait]
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used impl HeartbeatHandlerCapable<WebSocketStream<MaybeTlsStream<TcpStream>>> for HeartbeatHandler {
#[derive(Debug)]
pub struct DefaultHeartbeatHandler {
/// How ofter heartbeats need to be sent at a minimum
pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread
pub send: Sender<HeartbeatThreadCommunication>,
/// The handle of the thread
handle: JoinHandle<()>,
}
impl HeartbeatHandlerCapable<WebSocketStream<MaybeTlsStream<TcpStream>>>
for DefaultHeartbeatHandler
{
fn new( fn new(
heartbeat_interval: Duration, heartbeat_interval: Duration,
websocket_tx: Arc< websocket_tx: Arc<
@ -28,12 +13,12 @@ impl HeartbeatHandlerCapable<WebSocketStream<MaybeTlsStream<TcpStream>>>
>, >,
>, >,
kill_rc: tokio::sync::broadcast::Receiver<()>, kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> DefaultHeartbeatHandler { ) -> HeartbeatHandler {
let (send, receive) = tokio::sync::mpsc::channel(32); let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe(); let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move { let handle: JoinHandle<()> = task::spawn(async move {
DefaultHeartbeatHandler::heartbeat_task( HeartbeatHandler::heartbeat_task(
websocket_tx, websocket_tx,
heartbeat_interval, heartbeat_interval,
receive, receive,
@ -57,92 +42,3 @@ impl HeartbeatHandlerCapable<WebSocketStream<MaybeTlsStream<TcpStream>>>
self.heartbeat_interval self.heartbeat_interval
} }
} }
impl DefaultHeartbeatHandler {
/// The main heartbeat task;
///
/// Can be killed by the kill broadcast;
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
pub async fn heartbeat_task(
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
}

View File

@ -5,7 +5,6 @@ pub mod heartbeat;
use super::*; use super::*;
pub use gateway::*; pub use gateway::*;
pub use handle::*; pub use handle::*;
use heartbeat::*;
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use crate::errors::GatewayError; use crate::errors::GatewayError;
@ -15,18 +14,15 @@ use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream; use futures_util::stream::SplitStream;
use log::{info, trace, warn}; use log::{info, warn};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task; use tokio::task;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};

View File

@ -8,6 +8,8 @@ pub mod wasm;
#[cfg(all(not(target_arch = "wasm32"), feature = "client"))] #[cfg(all(not(target_arch = "wasm32"), feature = "client"))]
pub use default::*; pub use default::*;
pub use message::*; pub use message::*;
use safina_timer::sleep_until;
use tokio::task::JoinHandle;
#[cfg(all(target_arch = "wasm32", feature = "client"))] #[cfg(all(target_arch = "wasm32", feature = "client"))]
pub use wasm::*; pub use wasm::*;
@ -22,7 +24,7 @@ use crate::types::{
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::{self, Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
@ -154,7 +156,7 @@ impl<T: WebSocketEvent> GatewayEvent<T> {
pub trait GatewayCapable<R, S, G, H> pub trait GatewayCapable<R, S, G, H>
where where
R: Stream, R: Stream,
S: Sink<Message>, S: Sink<Message> + Send + 'static,
G: GatewayHandleCapable<R, S>, G: GatewayHandleCapable<R, S>,
H: HeartbeatHandlerCapable<S> + Send + Sync, H: HeartbeatHandlerCapable<S> + Send + Sync,
{ {
@ -441,6 +443,18 @@ where
} }
} }
/// Handles sending heartbeats to the gateway in another thread
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
#[derive(Debug)]
pub struct HeartbeatHandler {
/// How ofter heartbeats need to be sent at a minimum
pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread
pub send: Sender<HeartbeatThreadCommunication>,
/// The handle of the thread
handle: JoinHandle<()>,
}
#[async_trait(?Send)] #[async_trait(?Send)]
pub trait GatewayHandleCapable<R, S> pub trait GatewayHandleCapable<R, S>
where where
@ -541,7 +555,8 @@ where
async fn close(&self); async fn close(&self);
} }
pub trait HeartbeatHandlerCapable<S: Sink<Message>> { #[async_trait]
pub trait HeartbeatHandlerCapable<S: Sink<Message> + Send + 'static> {
fn new( fn new(
heartbeat_interval: Duration, heartbeat_interval: Duration,
websocket_tx: Arc<Mutex<SplitSink<S, Message>>>, websocket_tx: Arc<Mutex<SplitSink<S, Message>>>,
@ -550,4 +565,80 @@ pub trait HeartbeatHandlerCapable<S: Sink<Message>> {
fn get_send(&self) -> &Sender<HeartbeatThreadCommunication>; fn get_send(&self) -> &Sender<HeartbeatThreadCommunication>;
fn get_heartbeat_interval(&self) -> Duration; fn get_heartbeat_interval(&self) -> Duration;
async fn heartbeat_task(
websocket_tx: Arc<Mutex<SplitSink<S, Message>>>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
safina_timer::start_timer_thread();
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
} }