Rethink websockets, fix thread blocks

This commit is contained in:
kozabrada123 2023-05-07 11:58:12 +02:00
parent d4a546efa3
commit ea6bacd7b8
1 changed files with 35 additions and 57 deletions

View File

@ -1,32 +1,24 @@
use std::sync::Arc; use std::sync::Arc;
use std::thread;
use crate::api::types::*; use crate::api::types::*;
use crate::api::WebSocketEvent; use crate::api::WebSocketEvent;
use crate::errors::ObserverError; use crate::errors::ObserverError;
use crate::gateway::events::Events; use crate::gateway::events::Events;
use crate::URLBundle;
use futures_util::stream::{FilterMap, SplitSink, SplitStream};
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt; use futures_util::StreamExt;
use futures_util::stream::SplitSink;
use native_tls::TlsConnector; use native_tls::TlsConnector;
use reqwest::Url;
use serde::Deserialize;
use serde::Serialize;
use serde_json::from_str;
use tokio::io;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task; use tokio::task;
use tokio::time; use tokio::time;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_tungstenite::tungstenite::error::UrlError; use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::{WebSocketStream, Connector, connect_async_tls_with_config};
use tokio_tungstenite::{connect_async, connect_async_tls_with_config};
use tokio_tungstenite::{Connector, MaybeTlsStream};
/** /**
Represents a Gateway connection. A Gateway connection will create observable Represents a Gateway connection. A Gateway connection will create observable
@ -54,7 +46,13 @@ impl<'a> Gateway<'a> {
/// This function reads all messages from the gateway's websocket and updates its events along with the events' observers /// This function reads all messages from the gateway's websocket and updates its events along with the events' observers
pub async fn update_events(&mut self) { pub async fn update_events(&mut self) {
while let Some(msg) = self.websocket.rx.lock().await.recv().await {
while let Ok(msg) = self.websocket.rx.lock().await.try_recv() {
if msg.to_string() == String::new() {
continue;
}
println!("Debug GW: Received WSE: {}", msg.to_string()); println!("Debug GW: Received WSE: {}", msg.to_string());
let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap(); let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap();
@ -151,6 +149,8 @@ impl<'a> Gateway<'a> {
"STAGE_INSTANCE_CREATE" => {} "STAGE_INSTANCE_CREATE" => {}
"STAGE_INSTANCE_UPDATE" => {} "STAGE_INSTANCE_UPDATE" => {}
"STAGE_INSTANCE_DELETE" => {} "STAGE_INSTANCE_DELETE" => {}
// Not documented in discord docs, I assume this isnt for bots / apps but is for users?
"SESSIONS_REPLACE" => {}
"TYPING_START" => { "TYPING_START" => {
let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.user.typing_start_event.update_data(new_data); self.events.user.typing_start_event.update_data(new_data);
@ -242,8 +242,8 @@ struct HeartbeatHandler {
} }
impl HeartbeatHandler { impl HeartbeatHandler {
pub fn new(heartbeat_interval: u128, websocket_tx: Arc<Mutex<Sender<tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler { pub fn new(heartbeat_interval: u128, websocket_tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler {
let (mut tx, mut rx) = mpsc::channel(32); let (tx, mut rx) = mpsc::channel(32);
task::spawn(async move { task::spawn(async move {
let mut last_heartbeat: Instant = time::Instant::now(); let mut last_heartbeat: Instant = time::Instant::now();
@ -252,8 +252,8 @@ impl HeartbeatHandler {
loop { loop {
// If we received a seq number update, use that as the last seq number // If we received a seq number update, use that as the last seq number
let hb_communication: Option<HeartbeatThreadCommunication> = rx.recv().await; let hb_communication: Result<HeartbeatThreadCommunication, TryRecvError> = rx.try_recv();
while hb_communication.is_some() { if hb_communication.is_ok() {
last_seq_number = Some(hb_communication.unwrap().d); last_seq_number = Some(hb_communication.unwrap().d);
} }
@ -278,7 +278,6 @@ impl HeartbeatHandler {
last_heartbeat = time::Instant::now(); last_heartbeat = time::Instant::now();
} }
} }
}); });
@ -300,49 +299,35 @@ struct HeartbeatThreadCommunication {
struct WebSocketConnection { struct WebSocketConnection {
rx: Arc<Mutex<Receiver<tokio_tungstenite::tungstenite::Message>>>, rx: Arc<Mutex<Receiver<tokio_tungstenite::tungstenite::Message>>>,
tx: Arc<Mutex<Sender<tokio_tungstenite::tungstenite::Message>>>, tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>,
} }
impl<'a> WebSocketConnection { impl<'a> WebSocketConnection {
async fn new(websocket_url: String) -> WebSocketConnection { async fn new(websocket_url: String) -> WebSocketConnection {
let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap();
/*if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" {
return Err(tokio_tungstenite::tungstenite::Error::Url(
UrlError::UnsupportedUrlScheme,
));
}*/
let (mut send_channel_write, mut send_channel_read): ( let (receive_channel_write, receive_channel_read): (
Sender<tokio_tungstenite::tungstenite::Message>, Sender<tokio_tungstenite::tungstenite::Message>,
Receiver<tokio_tungstenite::tungstenite::Message>, Receiver<tokio_tungstenite::tungstenite::Message>,
) = channel(32); ) = channel(32);
let (mut receive_channel_write, mut receive_channel_read): (
Sender<tokio_tungstenite::tungstenite::Message>,
Receiver<tokio_tungstenite::tungstenite::Message>,
) = channel(32);
let shared_send_channel_write = Arc::new(Mutex::new(send_channel_write));
let shared_receive_channel_read = Arc::new(Mutex::new(receive_channel_read)); let shared_receive_channel_read = Arc::new(Mutex::new(receive_channel_read));
let (ws_stream, _) = match connect_async_tls_with_config(
&websocket_url,
None,
Some(Connector::NativeTls(
TlsConnector::builder().build().unwrap(),
)),
)
.await
{
Ok(ws_stream) => ws_stream,
Err(e) => panic!("{:?}", e),
};
let (ws_tx, mut ws_rx) = ws_stream.split();
task::spawn(async move { task::spawn(async move {
let (mut ws_stream, _) = match connect_async_tls_with_config(
&websocket_url,
None,
Some(Connector::NativeTls(
TlsConnector::builder().build().unwrap(),
)),
)
.await
{
Ok(ws_stream) => ws_stream,
Err(_) => return, /*return Err(tokio_tungstenite::tungstenite::Error::Io(
io::ErrorKind::ConnectionAborted.into(),
))*/
};
let (mut ws_tx, mut ws_rx) = ws_stream.split();
loop { loop {
// Write received messages to the receive channel // Write received messages to the receive channel
@ -354,18 +339,11 @@ impl<'a> WebSocketConnection {
.await .await
.unwrap(); .unwrap();
}; };
// Send messages from the send channel
let msg = send_channel_read.recv().await;
if msg.as_ref().is_some() {
let msg_unwrapped = msg.unwrap();
ws_tx.send(msg_unwrapped).await.unwrap();
}
} }
}); });
WebSocketConnection { WebSocketConnection {
tx: shared_send_channel_write, tx: Arc::new(Mutex::new(ws_tx)),
rx: shared_receive_channel_read, rx: shared_receive_channel_read,
} }
} }