From 16bdd68d98156921b8d9648a98b01a42671e1aa5 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Sat, 18 Nov 2023 18:39:01 +0100 Subject: [PATCH] Fixed most errors, simplified new generic traits --- examples/gateway_observers.rs | 5 +- examples/gateway_simple.rs | 3 +- examples/instance.rs | 2 +- examples/login.rs | 2 +- src/gateway/default/heartbeat.rs | 33 +++++++ src/gateway/mod.rs | 163 ++++++++++++++++--------------- tests/gateway.rs | 6 +- 7 files changed, 129 insertions(+), 85 deletions(-) diff --git a/examples/gateway_observers.rs b/examples/gateway_observers.rs index 511c568..921a048 100644 --- a/examples/gateway_observers.rs +++ b/examples/gateway_observers.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; -use chorus::gateway::{GatewayCapable, GatewayHandleCapable}; +use chorus::gateway::GatewayCapable; +use chorus::GatewayHandle; use chorus::{ self, gateway::Observer, @@ -32,7 +33,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection - let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); + let gateway: GatewayHandle = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); // Create an instance of our observer let observer = ExampleObserver {}; diff --git a/examples/gateway_simple.rs b/examples/gateway_simple.rs index e46d122..331fc6c 100644 --- a/examples/gateway_simple.rs +++ b/examples/gateway_simple.rs @@ -1,6 +1,7 @@ use std::time::Duration; use chorus::gateway::{GatewayCapable, GatewayHandleCapable}; +use chorus::GatewayHandle; use chorus::{self, types::GatewayIdentifyPayload, Gateway}; use tokio::time::sleep; @@ -11,7 +12,7 @@ async fn main() { let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string(); // Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another - let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); + let gateway: GatewayHandle = Gateway::get_handle(websocket_url_spacebar).await.unwrap(); // At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated diff --git a/examples/instance.rs b/examples/instance.rs index 337482b..d2a042f 100644 --- a/examples/instance.rs +++ b/examples/instance.rs @@ -1,7 +1,7 @@ use chorus::instance::Instance; use chorus::UrlBundle; -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() { let bundle = UrlBundle::new( "https://example.com/api".to_string(), diff --git a/examples/login.rs b/examples/login.rs index 4595a06..b06eade 100644 --- a/examples/login.rs +++ b/examples/login.rs @@ -2,7 +2,7 @@ use chorus::instance::Instance; use chorus::types::LoginSchema; use chorus::UrlBundle; -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() { let bundle = UrlBundle::new( "https://example.com/api".to_string(), diff --git a/src/gateway/default/heartbeat.rs b/src/gateway/default/heartbeat.rs index e7b7129..4b9fff7 100644 --- a/src/gateway/default/heartbeat.rs +++ b/src/gateway/default/heartbeat.rs @@ -1,5 +1,6 @@ use super::*; +// TODO: Make me not a trait and delete this file #[async_trait] impl HeartbeatHandlerCapable< @@ -14,4 +15,36 @@ impl fn get_heartbeat_interval(&self) -> Duration { self.heartbeat_interval } + + fn new( + heartbeat_interval: Duration, + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + kill_rc: tokio::sync::broadcast::Receiver<()>, + ) -> HeartbeatHandler { + let (send, receive) = tokio::sync::mpsc::channel(32); + let kill_receive = kill_rc.resubscribe(); + + let handle: JoinHandle<()> = task::spawn(async move { + HeartbeatHandler::heartbeat_task( + websocket_tx, + heartbeat_interval, + receive, + kill_receive, + ) + .await; + }); + + Self { + heartbeat_interval, + send, + handle, + } + } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 071e52f..3d54137 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -87,7 +87,7 @@ const GATEWAY_CALL_SYNC: u8 = 13; /// See [types::LazyRequest] const GATEWAY_LAZY_REQUEST: u8 = 14; -pub trait MessageCapable { +pub trait MessageCapable: From { fn as_string(&self) -> Option; fn as_bytes(&self) -> Option>; fn is_empty(&self) -> bool; @@ -461,6 +461,85 @@ pub struct HeartbeatHandler { handle: JoinHandle<()>, } +impl HeartbeatHandler { + pub async fn heartbeat_task + Send>( + websocket_tx: Arc>>, + heartbeat_interval: Duration, + mut receive: tokio::sync::mpsc::Receiver, + mut kill_receive: tokio::sync::broadcast::Receiver<()>, + ) { + let mut last_heartbeat_timestamp: Instant = time::Instant::now(); + let mut last_heartbeat_acknowledged = true; + let mut last_seq_number: Option = None; + safina_timer::start_timer_thread(); + + loop { + if kill_receive.try_recv().is_ok() { + trace!("GW: Closing heartbeat task"); + break; + } + + let timeout = if last_heartbeat_acknowledged { + heartbeat_interval + } else { + // If the server hasn't acknowledged our heartbeat we should resend it + Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) + }; + + let mut should_send = false; + + tokio::select! { + () = sleep_until(last_heartbeat_timestamp + timeout) => { + should_send = true; + } + Some(communication) = receive.recv() => { + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = communication.sequence_number; + } + + if let Some(op_code) = communication.op_code { + match op_code { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} + } + } + } + } + + if should_send { + trace!("GW: Sending Heartbeat.."); + + let heartbeat = types::GatewayHeartbeat { + op: GATEWAY_HEARTBEAT, + d: last_seq_number, + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + let send_result = websocket_tx.lock().await.send(msg.into()).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + warn!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } + + last_heartbeat_timestamp = time::Instant::now(); + last_heartbeat_acknowledged = false; + } + } + } +} + #[async_trait(?Send)] pub trait GatewayHandleCapable where @@ -570,84 +649,14 @@ where } #[async_trait] +// TODO: Make me not a trait!! pub trait HeartbeatHandlerCapable> { fn get_send(&self) -> &Sender; - fn as_arc_mutex(&self) -> Arc>; fn get_heartbeat_interval(&self) -> Duration; - async fn heartbeat_task( - websocket_tx: Arc>>, + #[allow(clippy::new_ret_no_self)] + fn new( heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, - mut kill_receive: tokio::sync::broadcast::Receiver<()>, - ) { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - safina_timer::start_timer_thread(); - - loop { - if kill_receive.try_recv().is_ok() { - trace!("GW: Closing heartbeat task"); - break; - } - - let timeout = if last_heartbeat_acknowledged { - heartbeat_interval - } else { - // If the server hasn't acknowledged our heartbeat we should resend it - Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) - }; - - let mut should_send = false; - - tokio::select! { - () = sleep_until(last_heartbeat_timestamp + timeout) => { - should_send = true; - } - Some(communication) = receive.recv() => { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; - } - - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - } - - if should_send { - trace!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx.lock().await.send(msg).await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - warn!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } - } + websocket_tx: Arc>>, + kill_rc: tokio::sync::broadcast::Receiver<()>, + ) -> HeartbeatHandler; } diff --git a/tests/gateway.rs b/tests/gateway.rs index 34c841f..f71c1fc 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -2,15 +2,15 @@ mod common; use std::sync::{Arc, RwLock}; -use chorus::gateway::*; use chorus::types::{self, ChannelModifySchema, RoleCreateModifySchema, RoleObject}; +use chorus::{gateway::*, GatewayHandle}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; async fn test_gateway_establish() { let bundle = common::setup().await; - DefaultGateway::get_handle(bundle.urls.wss.clone()) + let _: GatewayHandle = DefaultGateway::get_handle(bundle.urls.wss.clone()) .await .unwrap(); common::teardown(bundle).await @@ -21,7 +21,7 @@ async fn test_gateway_establish() { async fn test_gateway_authenticate() { let bundle = common::setup().await; - let gateway = DefaultGateway::get_handle(bundle.urls.wss.clone()) + let gateway: GatewayHandle = DefaultGateway::get_handle(bundle.urls.wss.clone()) .await .unwrap();