Fixed most errors, simplified new generic traits

This commit is contained in:
bitfl0wer 2023-11-18 18:39:01 +01:00
parent ce7ff49ee4
commit 16bdd68d98
7 changed files with 129 additions and 85 deletions

View File

@ -1,5 +1,6 @@
use async_trait::async_trait;
use chorus::gateway::{GatewayCapable, GatewayHandleCapable};
use chorus::gateway::GatewayCapable;
use chorus::GatewayHandle;
use chorus::{
self,
gateway::Observer,
@ -32,7 +33,7 @@ async fn main() {
let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string();
// Initiate the gateway connection
let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap();
let gateway: GatewayHandle = Gateway::get_handle(websocket_url_spacebar).await.unwrap();
// Create an instance of our observer
let observer = ExampleObserver {};

View File

@ -1,6 +1,7 @@
use std::time::Duration;
use chorus::gateway::{GatewayCapable, GatewayHandleCapable};
use chorus::GatewayHandle;
use chorus::{self, types::GatewayIdentifyPayload, Gateway};
use tokio::time::sleep;
@ -11,7 +12,7 @@ async fn main() {
let websocket_url_spacebar = "wss://gateway.old.server.spacebar.chat/".to_string();
// Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another
let gateway = Gateway::get_handle(websocket_url_spacebar).await.unwrap();
let gateway: GatewayHandle = Gateway::get_handle(websocket_url_spacebar).await.unwrap();
// At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated

View File

@ -1,7 +1,7 @@
use chorus::instance::Instance;
use chorus::UrlBundle;
#[tokio::main]
#[tokio::main(flavor = "current_thread")]
async fn main() {
let bundle = UrlBundle::new(
"https://example.com/api".to_string(),

View File

@ -2,7 +2,7 @@ use chorus::instance::Instance;
use chorus::types::LoginSchema;
use chorus::UrlBundle;
#[tokio::main]
#[tokio::main(flavor = "current_thread")]
async fn main() {
let bundle = UrlBundle::new(
"https://example.com/api".to_string(),

View File

@ -1,5 +1,6 @@
use super::*;
// TODO: Make me not a trait and delete this file
#[async_trait]
impl
HeartbeatHandlerCapable<
@ -14,4 +15,36 @@ impl
fn get_heartbeat_interval(&self) -> Duration {
self.heartbeat_interval
}
fn new(
heartbeat_interval: Duration,
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move {
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
});
Self {
heartbeat_interval,
send,
handle,
}
}
}

View File

@ -87,7 +87,7 @@ const GATEWAY_CALL_SYNC: u8 = 13;
/// See [types::LazyRequest]
const GATEWAY_LAZY_REQUEST: u8 = 14;
pub trait MessageCapable {
pub trait MessageCapable: From<tokio_tungstenite::tungstenite::Message> {
fn as_string(&self) -> Option<String>;
fn as_bytes(&self) -> Option<Vec<u8>>;
fn is_empty(&self) -> bool;
@ -461,6 +461,85 @@ pub struct HeartbeatHandler {
handle: JoinHandle<()>,
}
impl HeartbeatHandler {
pub async fn heartbeat_task<T: MessageCapable + Send + 'static, S: Sink<T> + Send>(
websocket_tx: Arc<Mutex<SplitSink<S, T>>>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
safina_timer::start_timer_thread();
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg.into()).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
}
#[async_trait(?Send)]
pub trait GatewayHandleCapable<T, S>
where
@ -570,84 +649,14 @@ where
}
#[async_trait]
// TODO: Make me not a trait!!
pub trait HeartbeatHandlerCapable<T: MessageCapable + Send + 'static, S: Sink<T>> {
fn get_send(&self) -> &Sender<HeartbeatThreadCommunication>;
fn as_arc_mutex(&self) -> Arc<Mutex<Self>>;
fn get_heartbeat_interval(&self) -> Duration;
async fn heartbeat_task(
websocket_tx: Arc<Mutex<SplitSink<S, T>>>,
#[allow(clippy::new_ret_no_self)]
fn new(
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
safina_timer::start_timer_thread();
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
websocket_tx: Arc<Mutex<SplitSink<S, T>>>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler;
}

View File

@ -2,15 +2,15 @@ mod common;
use std::sync::{Arc, RwLock};
use chorus::gateway::*;
use chorus::types::{self, ChannelModifySchema, RoleCreateModifySchema, RoleObject};
use chorus::{gateway::*, GatewayHandle};
#[tokio::test]
/// Tests establishing a connection (hello and heartbeats) on the local gateway;
async fn test_gateway_establish() {
let bundle = common::setup().await;
DefaultGateway::get_handle(bundle.urls.wss.clone())
let _: GatewayHandle = DefaultGateway::get_handle(bundle.urls.wss.clone())
.await
.unwrap();
common::teardown(bundle).await
@ -21,7 +21,7 @@ async fn test_gateway_establish() {
async fn test_gateway_authenticate() {
let bundle = common::setup().await;
let gateway = DefaultGateway::get_handle(bundle.urls.wss.clone())
let gateway: GatewayHandle = DefaultGateway::get_handle(bundle.urls.wss.clone())
.await
.unwrap();