Refactor
This commit is contained in:
parent
74785b4b2b
commit
da98f3c581
126
src/gateway.rs
126
src/gateway.rs
|
@ -3,12 +3,12 @@ use crate::errors::ObserverError;
|
|||
use crate::gateway::events::Events;
|
||||
use crate::types;
|
||||
use futures_util::stream::SplitSink;
|
||||
use futures_util::stream::SplitStream;
|
||||
use futures_util::SinkExt;
|
||||
use futures_util::StreamExt;
|
||||
use native_tls::TlsConnector;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::Mutex;
|
||||
|
@ -282,7 +282,7 @@ impl GatewayHandle {
|
|||
}
|
||||
|
||||
/// Closes the websocket connection and stops all gateway tasks
|
||||
async fn close(&mut self) {
|
||||
pub async fn close(&mut self) {
|
||||
self.kill_send.send(()).unwrap();
|
||||
self.websocket_send.lock().await.close().await.unwrap();
|
||||
}
|
||||
|
@ -290,7 +290,7 @@ impl GatewayHandle {
|
|||
|
||||
pub struct Gateway {
|
||||
pub events: Arc<Mutex<Events>>,
|
||||
heartbeat_handler: Option<HeartbeatHandler>,
|
||||
heartbeat_handler: HeartbeatHandler,
|
||||
pub websocket_send: Arc<
|
||||
Mutex<
|
||||
SplitSink<
|
||||
|
@ -299,6 +299,7 @@ pub struct Gateway {
|
|||
>,
|
||||
>,
|
||||
>,
|
||||
pub websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
kill_send: tokio::sync::broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
|
@ -327,16 +328,7 @@ impl Gateway {
|
|||
let shared_websocket_send = Arc::new(Mutex::new(websocket_send));
|
||||
|
||||
// Create a shared broadcast channel for killing all gateway tasks
|
||||
let (kill_send, mut kill_receive) = tokio::sync::broadcast::channel::<()>(16);
|
||||
|
||||
let mut gateway = Gateway {
|
||||
events: Arc::new(Mutex::new(Events::default())),
|
||||
heartbeat_handler: None,
|
||||
websocket_send: shared_websocket_send.clone(),
|
||||
kill_send: kill_send.clone(),
|
||||
};
|
||||
|
||||
let shared_events = gateway.events.clone();
|
||||
let (kill_send, mut _kill_receive) = tokio::sync::broadcast::channel::<()>(16);
|
||||
|
||||
// Wait for the first hello and then spawn both tasks so we avoid nested tasks
|
||||
// This automatically spawns the heartbeat task, but from the main thread
|
||||
|
@ -354,23 +346,47 @@ impl Gateway {
|
|||
|
||||
let gateway_hello: types::HelloData =
|
||||
serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap();
|
||||
gateway.heartbeat_handler = Some(HeartbeatHandler::new(
|
||||
|
||||
let mut gateway = Gateway {
|
||||
events: Arc::new(Mutex::new(Events::default())),
|
||||
heartbeat_handler: HeartbeatHandler::new(
|
||||
gateway_hello.heartbeat_interval,
|
||||
shared_websocket_send.clone(),
|
||||
kill_send.subscribe(),
|
||||
));
|
||||
),
|
||||
websocket_send: shared_websocket_send.clone(),
|
||||
websocket_receive,
|
||||
kill_send: kill_send.clone(),
|
||||
};
|
||||
|
||||
let shared_events = gateway.events.clone();
|
||||
|
||||
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
|
||||
let handle: JoinHandle<()> = task::spawn(async move {
|
||||
gateway.gateway_listen_task().await;
|
||||
});
|
||||
|
||||
return Ok(GatewayHandle {
|
||||
url: websocket_url.clone(),
|
||||
events: shared_events,
|
||||
websocket_send: shared_websocket_send.clone(),
|
||||
handle,
|
||||
kill_send: kill_send.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
/// The main gateway listener task;
|
||||
///
|
||||
/// Can only be stopped by closing the websocket, cannot be made to listen for kill
|
||||
pub async fn gateway_listen_task(&mut self) {
|
||||
loop {
|
||||
let msg = websocket_receive.next().await;
|
||||
let msg = self.websocket_receive.next().await;
|
||||
|
||||
// This if chain can be much better but if let is unstable on stable rust
|
||||
if msg.as_ref().is_some() {
|
||||
if msg.as_ref().unwrap().is_ok() {
|
||||
let msg_unwrapped = msg.unwrap().unwrap();
|
||||
gateway
|
||||
.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped))
|
||||
self.handle_event(GatewayMessage::from_tungstenite_message(msg_unwrapped))
|
||||
.await;
|
||||
|
||||
continue;
|
||||
|
@ -381,15 +397,6 @@ impl Gateway {
|
|||
println!("GW: Websocket is broken, stopping gateway");
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
return Ok(GatewayHandle {
|
||||
url: websocket_url.clone(),
|
||||
events: shared_events,
|
||||
websocket_send: shared_websocket_send.clone(),
|
||||
handle,
|
||||
kill_send: kill_send.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Closes the websocket connection and stops all tasks
|
||||
|
@ -1297,10 +1304,6 @@ impl Gateway {
|
|||
GATEWAY_HEARTBEAT => {
|
||||
println!("GW: Received Heartbeat // Heartbeat Request");
|
||||
|
||||
if self.heartbeat_handler.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Tell the heartbeat handler it should send a heartbeat right away
|
||||
|
||||
let heartbeat_communication = HeartbeatThreadCommunication {
|
||||
|
@ -1309,8 +1312,6 @@ impl Gateway {
|
|||
};
|
||||
|
||||
self.heartbeat_handler
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.send
|
||||
.send(heartbeat_communication)
|
||||
.await
|
||||
|
@ -1330,10 +1331,6 @@ impl Gateway {
|
|||
GATEWAY_HEARTBEAT_ACK => {
|
||||
println!("GW: Received Heartbeat ACK");
|
||||
|
||||
if self.heartbeat_handler.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Tell the heartbeat handler we received an ack
|
||||
|
||||
let heartbeat_communication = HeartbeatThreadCommunication {
|
||||
|
@ -1342,8 +1339,6 @@ impl Gateway {
|
|||
};
|
||||
|
||||
self.heartbeat_handler
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.send
|
||||
.send(heartbeat_communication)
|
||||
.await
|
||||
|
@ -1366,8 +1361,8 @@ impl Gateway {
|
|||
}
|
||||
}
|
||||
|
||||
// If we have an active heartbeat thread and we received a seq number we should let it know
|
||||
if gateway_payload.sequence_number.is_some() && self.heartbeat_handler.is_some() {
|
||||
// If we we received a seq number we should let it know
|
||||
if gateway_payload.sequence_number.is_some() {
|
||||
let heartbeat_communication = HeartbeatThreadCommunication {
|
||||
sequence_number: Some(gateway_payload.sequence_number.unwrap()),
|
||||
// Op code is irrelevant here
|
||||
|
@ -1375,8 +1370,6 @@ impl Gateway {
|
|||
};
|
||||
|
||||
self.heartbeat_handler
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.send
|
||||
.send(heartbeat_communication)
|
||||
.await
|
||||
|
@ -1410,10 +1403,43 @@ impl HeartbeatHandler {
|
|||
>,
|
||||
kill_rc: tokio::sync::broadcast::Receiver<()>,
|
||||
) -> HeartbeatHandler {
|
||||
let (send, mut receive) = mpsc::channel(32);
|
||||
let mut kill_receive = kill_rc.resubscribe();
|
||||
let (send, receive) = tokio::sync::mpsc::channel(32);
|
||||
let kill_receive = kill_rc.resubscribe();
|
||||
|
||||
let handle: JoinHandle<()> = task::spawn(async move {
|
||||
HeartbeatHandler::heartbeat_task(
|
||||
websocket_tx,
|
||||
heartbeat_interval,
|
||||
receive,
|
||||
kill_receive,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Self {
|
||||
heartbeat_interval,
|
||||
send,
|
||||
handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// The main heartbeat task;
|
||||
///
|
||||
/// Can be killed by the kill broadcast;
|
||||
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
|
||||
pub async fn heartbeat_task(
|
||||
websocket_tx: Arc<
|
||||
Mutex<
|
||||
SplitSink<
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
tokio_tungstenite::tungstenite::Message,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
heartbeat_interval: u128,
|
||||
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
|
||||
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
|
||||
) {
|
||||
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
|
||||
let mut last_heartbeat_acknowledged = true;
|
||||
let mut last_seq_number: Option<u64> = None;
|
||||
|
@ -1426,8 +1452,7 @@ impl HeartbeatHandler {
|
|||
|
||||
let mut should_send;
|
||||
|
||||
let time_to_send =
|
||||
last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval;
|
||||
let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval;
|
||||
|
||||
should_send = time_to_send;
|
||||
|
||||
|
@ -1487,13 +1512,6 @@ impl HeartbeatHandler {
|
|||
last_heartbeat_acknowledged = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
heartbeat_interval,
|
||||
send,
|
||||
handle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue