This commit is contained in:
kozabrada123 2023-06-08 18:24:11 +02:00
parent 74785b4b2b
commit da98f3c581
1 changed files with 146 additions and 128 deletions

View File

@ -3,12 +3,12 @@ use crate::errors::ObserverError;
use crate::gateway::events::Events;
use crate::types;
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use native_tls::TlsConnector;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
@ -282,7 +282,7 @@ impl GatewayHandle {
}
/// Closes the websocket connection and stops all gateway tasks
async fn close(&mut self) {
pub async fn close(&mut self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
}
@ -290,7 +290,7 @@ impl GatewayHandle {
pub struct Gateway {
pub events: Arc<Mutex<Events>>,
heartbeat_handler: Option<HeartbeatHandler>,
heartbeat_handler: HeartbeatHandler,
pub websocket_send: Arc<
Mutex<
SplitSink<
@ -299,6 +299,7 @@ pub struct Gateway {
>,
>,
>,
pub websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
kill_send: tokio::sync::broadcast::Sender<()>,
}
@ -327,16 +328,7 @@ impl Gateway {
let shared_websocket_send = Arc::new(Mutex::new(websocket_send));
// Create a shared broadcast channel for killing all gateway tasks
let (kill_send, mut kill_receive) = tokio::sync::broadcast::channel::<()>(16);
let mut gateway = Gateway {
events: Arc::new(Mutex::new(Events::default())),
heartbeat_handler: None,
websocket_send: shared_websocket_send.clone(),
kill_send: kill_send.clone(),
};
let shared_events = gateway.events.clone();
let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16);
// Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread
@ -354,33 +346,24 @@ impl Gateway {
let gateway_hello: types::HelloData =
serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap();
gateway.heartbeat_handler = Some(HeartbeatHandler::new(
gateway_hello.heartbeat_interval,
shared_websocket_send.clone(),
kill_send.subscribe(),
));
let mut gateway = Gateway {
events: Arc::new(Mutex::new(Events::default())),
heartbeat_handler: HeartbeatHandler::new(
gateway_hello.heartbeat_interval,
shared_websocket_send.clone(),
kill_send.subscribe(),
),
websocket_send: shared_websocket_send.clone(),
websocket_receive,
kill_send: kill_send.clone(),
};
let shared_events = gateway.events.clone();
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
let handle: JoinHandle<()> = task::spawn(async move {
loop {
let msg = websocket_receive.next().await;
// This if chain can be much better but if let is unstable on stable rust
if msg.as_ref().is_some() {
if msg.as_ref().unwrap().is_ok() {
let msg_unwrapped = msg.unwrap().unwrap();
gateway
.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped))
.await;
continue;
}
}
// We couldn't receive the next message or it was an error, something is wrong with the websocket, close
println!("GW: Websocket is broken, stopping gateway");
break;
}
gateway.gateway_listen_task().await;
});
return Ok(GatewayHandle {
@ -392,6 +375,30 @@ impl Gateway {
});
}
/// The main gateway listener task;
///
/// Can only be stopped by closing the websocket, cannot be made to listen for kill
pub async fn gateway_listen_task(&mut self) {
loop {
let msg = self.websocket_receive.next().await;
// This if chain can be much better but if let is unstable on stable rust
if msg.as_ref().is_some() {
if msg.as_ref().unwrap().is_ok() {
let msg_unwrapped = msg.unwrap().unwrap();
self.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped))
.await;
continue;
}
}
// We couldn't receive the next message or it was an error, something is wrong with the websocket, close
println!("GW: Websocket is broken, stopping gateway");
break;
}
}
/// Closes the websocket connection and stops all tasks
async fn close(&mut self) {
self.kill_send.send(()).unwrap();
@ -1297,10 +1304,6 @@ impl Gateway {
GATEWAY_HEARTBEAT => {
println!("GW: Received Heartbeat // Heartbeat Request");
if self.heartbeat_handler.is_none() {
return;
}
// Tell the heartbeat handler it should send a heartbeat right away
let heartbeat_communication = HeartbeatThreadCommunication {
@ -1309,8 +1312,6 @@ impl Gateway {
};
self.heartbeat_handler
.as_mut()
.unwrap()
.send
.send(heartbeat_communication)
.await
@ -1330,10 +1331,6 @@ impl Gateway {
GATEWAY_HEARTBEAT_ACK => {
println!("GW: Received Heartbeat ACK");
if self.heartbeat_handler.is_none() {
return;
}
// Tell the heartbeat handler we received an ack
let heartbeat_communication = HeartbeatThreadCommunication {
@ -1342,8 +1339,6 @@ impl Gateway {
};
self.heartbeat_handler
.as_mut()
.unwrap()
.send
.send(heartbeat_communication)
.await
@ -1366,8 +1361,8 @@ impl Gateway {
}
}
// If we have an active heartbeat thread and we received a seq number we should let it know
if gateway_payload.sequence_number.is_some() && self.heartbeat_handler.is_some() {
// If we we received a seq number we should let it know
if gateway_payload.sequence_number.is_some() {
let heartbeat_communication = HeartbeatThreadCommunication {
sequence_number: Some(gateway_payload.sequence_number.unwrap()),
// Op code is irrelevant here
@ -1375,8 +1370,6 @@ impl Gateway {
};
self.heartbeat_handler
.as_mut()
.unwrap()
.send
.send(heartbeat_communication)
.await
@ -1410,83 +1403,17 @@ impl HeartbeatHandler {
>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
let (send, mut receive) = mpsc::channel(32);
let mut kill_receive = kill_rc.resubscribe();
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
loop {
let should_shutdown = kill_receive.try_recv().is_ok();
if should_shutdown {
break;
}
let mut should_send;
let time_to_send =
last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval;
should_send = time_to_send;
let received_communication: Result<HeartbeatThreadCommunication, TryRecvError> =
receive.try_recv();
if received_communication.is_ok() {
let communication = received_communication.unwrap();
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = Some(communication.sequence_number.unwrap());
}
if communication.op_code.is_some() {
match communication.op_code.unwrap() {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
// If the server hasn't acknowledged our heartbeat we should resend it
if !last_heartbeat_acknowledged
&& last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT
{
should_send = true;
println!("GW: Timed out waiting for a heartbeat ack, resending");
}
if should_send {
println!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
println!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
});
Self {
@ -1495,6 +1422,97 @@ impl HeartbeatHandler {
handle,
}
}
/// The main heartbeat task;
///
/// Can be killed by the kill broadcast;
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
pub async fn heartbeat_task(
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
heartbeat_interval: u128,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
loop {
let should_shutdown = kill_receive.try_recv().is_ok();
if should_shutdown {
break;
}
let mut should_send;
let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval;
should_send = time_to_send;
let received_communication: Result<HeartbeatThreadCommunication, TryRecvError> =
receive.try_recv();
if received_communication.is_ok() {
let communication = received_communication.unwrap();
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = Some(communication.sequence_number.unwrap());
}
if communication.op_code.is_some() {
match communication.op_code.unwrap() {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
// If the server hasn't acknowledged our heartbeat we should resend it
if !last_heartbeat_acknowledged
&& last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT
{
should_send = true;
println!("GW: Timed out waiting for a heartbeat ack, resending");
}
if should_send {
println!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
println!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
}
/**