From b4888d2f702b83d34b9b6a1adf8ac53b5d0cd561 Mon Sep 17 00:00:00 2001 From: kozabrada123 <“kozabrada123@users.noreply.github.com”> Date: Sun, 7 May 2023 11:58:12 +0200 Subject: [PATCH] Rethink websockets, fix thread blocks --- src/gateway.rs | 92 +++++++++++++++++++------------------------------- 1 file changed, 35 insertions(+), 57 deletions(-) diff --git a/src/gateway.rs b/src/gateway.rs index 756f7f7..6f65f57 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,32 +1,24 @@ use std::sync::Arc; -use std::thread; use crate::api::types::*; use crate::api::WebSocketEvent; use crate::errors::ObserverError; use crate::gateway::events::Events; -use crate::URLBundle; -use futures_util::stream::{FilterMap, SplitSink, SplitStream}; use futures_util::SinkExt; use futures_util::StreamExt; +use futures_util::stream::SplitSink; use native_tls::TlsConnector; -use reqwest::Url; -use serde::Deserialize; -use serde::Serialize; -use serde_json::from_str; -use tokio::io; use tokio::net::TcpStream; use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::Mutex; use tokio::task; use tokio::time; use tokio::time::Instant; -use tokio_tungstenite::tungstenite::error::UrlError; -use tokio_tungstenite::WebSocketStream; -use tokio_tungstenite::{connect_async, connect_async_tls_with_config}; -use tokio_tungstenite::{Connector, MaybeTlsStream}; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::{WebSocketStream, Connector, connect_async_tls_with_config}; /** Represents a Gateway connection. A Gateway connection will create observable @@ -54,7 +46,13 @@ impl<'a> Gateway<'a> { /// This function reads all messages from the gateway's websocket and updates its events along with the events' observers pub async fn update_events(&mut self) { - while let Some(msg) = self.websocket.rx.lock().await.recv().await { + + while let Ok(msg) = self.websocket.rx.lock().await.try_recv() { + + if msg.to_string() == String::new() { + continue; + } + println!("Debug GW: Received WSE: {}", msg.to_string()); let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap(); @@ -151,6 +149,8 @@ impl<'a> Gateway<'a> { "STAGE_INSTANCE_CREATE" => {} "STAGE_INSTANCE_UPDATE" => {} "STAGE_INSTANCE_DELETE" => {} + // Not documented in discord docs, I assume this isnt for bots / apps but is for users? + "SESSIONS_REPLACE" => {} "TYPING_START" => { let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); self.events.user.typing_start_event.update_data(new_data); @@ -242,8 +242,8 @@ struct HeartbeatHandler { } impl HeartbeatHandler { - pub fn new(heartbeat_interval: u128, websocket_tx: Arc>>) -> HeartbeatHandler { - let (mut tx, mut rx) = mpsc::channel(32); + pub fn new(heartbeat_interval: u128, websocket_tx: Arc>, tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler { + let (tx, mut rx) = mpsc::channel(32); task::spawn(async move { let mut last_heartbeat: Instant = time::Instant::now(); @@ -252,8 +252,8 @@ impl HeartbeatHandler { loop { // If we received a seq number update, use that as the last seq number - let hb_communication: Option = rx.recv().await; - while hb_communication.is_some() { + let hb_communication: Result = rx.try_recv(); + if hb_communication.is_ok() { last_seq_number = Some(hb_communication.unwrap().d); } @@ -278,7 +278,6 @@ impl HeartbeatHandler { last_heartbeat = time::Instant::now(); } - } }); @@ -300,49 +299,35 @@ struct HeartbeatThreadCommunication { struct WebSocketConnection { rx: Arc>>, - tx: Arc>>, + tx: Arc>, tokio_tungstenite::tungstenite::Message>>>, } impl<'a> WebSocketConnection { async fn new(websocket_url: String) -> WebSocketConnection { - let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap(); - /*if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" { - return Err(tokio_tungstenite::tungstenite::Error::Url( - UrlError::UnsupportedUrlScheme, - )); - }*/ - let (mut send_channel_write, mut send_channel_read): ( + let (receive_channel_write, receive_channel_read): ( Sender, Receiver, ) = channel(32); - let (mut receive_channel_write, mut receive_channel_read): ( - Sender, - Receiver, - ) = channel(32); - - let shared_send_channel_write = Arc::new(Mutex::new(send_channel_write)); let shared_receive_channel_read = Arc::new(Mutex::new(receive_channel_read)); + let (ws_stream, _) = match connect_async_tls_with_config( + &websocket_url, + None, + Some(Connector::NativeTls( + TlsConnector::builder().build().unwrap(), + )), + ) + .await + { + Ok(ws_stream) => ws_stream, + Err(e) => panic!("{:?}", e), + }; + + let (ws_tx, mut ws_rx) = ws_stream.split(); + task::spawn(async move { - let (mut ws_stream, _) = match connect_async_tls_with_config( - &websocket_url, - None, - Some(Connector::NativeTls( - TlsConnector::builder().build().unwrap(), - )), - ) - .await - { - Ok(ws_stream) => ws_stream, - Err(_) => return, /*return Err(tokio_tungstenite::tungstenite::Error::Io( - io::ErrorKind::ConnectionAborted.into(), - ))*/ - }; - - let (mut ws_tx, mut ws_rx) = ws_stream.split(); - loop { // Write received messages to the receive channel @@ -354,18 +339,11 @@ impl<'a> WebSocketConnection { .await .unwrap(); }; - - // Send messages from the send channel - let msg = send_channel_read.recv().await; - if msg.as_ref().is_some() { - let msg_unwrapped = msg.unwrap(); - ws_tx.send(msg_unwrapped).await.unwrap(); - } } }); WebSocketConnection { - tx: shared_send_channel_write, + tx: Arc::new(Mutex::new(ws_tx)), rx: shared_receive_channel_read, } }