checkpoint

This commit is contained in:
bitfl0wer 2023-11-18 18:08:12 +01:00
parent 17a58f6f40
commit ce7ff49ee4
7 changed files with 105 additions and 76 deletions

View File

@ -4,7 +4,7 @@ use reqwest::Client;
use serde_json::to_string; use serde_json::to_string;
use crate::errors::ChorusResult; use crate::errors::ChorusResult;
use crate::gateway::{GatewayCapable, GatewayHandleCapable}; use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable};
use crate::instance::{ChorusUser, Instance}; use crate::instance::{ChorusUser, Instance};
use crate::ratelimiter::ChorusRequest; use crate::ratelimiter::ChorusRequest;
use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema}; use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema};
@ -37,7 +37,8 @@ impl Instance {
self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap();
} }
let mut identify = GatewayIdentifyPayload::common(); let mut identify = GatewayIdentifyPayload::common();
let gateway = Gateway::get_handle(self.urls.wss.clone()).await.unwrap(); let gateway: DefaultGatewayHandle =
Gateway::get_handle(self.urls.wss.clone()).await.unwrap();
identify.token = login_result.token.clone(); identify.token = login_result.token.clone();
gateway.send_identify(identify).await; gateway.send_identify(identify).await;
let user = ChorusUser::new( let user = ChorusUser::new(

View File

@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock};
pub use login::*; pub use login::*;
pub use register::*; pub use register::*;
use crate::gateway::{GatewayCapable, GatewayHandleCapable}; use crate::gateway::{DefaultGatewayHandle, GatewayCapable, GatewayHandleCapable};
use crate::{ use crate::{
errors::ChorusResult, errors::ChorusResult,
gateway::DefaultGateway, gateway::DefaultGateway,
@ -26,7 +26,7 @@ impl Instance {
.await .await
.unwrap(); .unwrap();
let mut identify = GatewayIdentifyPayload::common(); let mut identify = GatewayIdentifyPayload::common();
let gateway = DefaultGateway::get_handle(self.urls.wss.clone()) let gateway: DefaultGatewayHandle = DefaultGateway::get_handle(self.urls.wss.clone())
.await .await
.unwrap(); .unwrap();
identify.token = token.clone(); identify.token = token.clone();

View File

@ -1,4 +1,5 @@
use futures_util::StreamExt; use futures_util::StreamExt;
use tokio_tungstenite::tungstenite::Message;
use super::events::Events; use super::events::Events;
use super::*; use super::*;
@ -25,18 +26,19 @@ pub struct DefaultGateway {
#[async_trait] #[async_trait]
impl impl
GatewayCapable< GatewayCapable<
tokio_tungstenite::tungstenite::Message,
WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketStream<MaybeTlsStream<TcpStream>>,
WebSocketStream<MaybeTlsStream<TcpStream>>,
DefaultGatewayHandle,
HeartbeatHandler,
> for DefaultGateway > for DefaultGateway
{ {
fn get_heartbeat_handler(&self) -> &HeartbeatHandler { fn get_heartbeat_handler(&self) -> &HeartbeatHandler {
&self.heartbeat_handler &self.heartbeat_handler
} }
#[allow(clippy::new_ret_no_self)] async fn get_handle<
async fn get_handle(websocket_url: String) -> Result<DefaultGatewayHandle, GatewayError> { G: GatewayHandleCapable<Message, WebSocketStream<MaybeTlsStream<TcpStream>>>,
>(
websocket_url: String,
) -> Result<G, GatewayError> {
let mut roots = rustls::RootCertStore::empty(); let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{ {
@ -112,13 +114,13 @@ impl
gateway.gateway_listen_task().await; gateway.gateway_listen_task().await;
}); });
Ok(DefaultGatewayHandle { Ok(G::new(
url: websocket_url.clone(), websocket_url.clone(),
events: shared_events, shared_events,
websocket_send: shared_websocket_send.clone(), shared_websocket_send.clone(),
kill_send: kill_send.clone(), kill_send.clone(),
store, store,
}) ))
} }
/// Closes the websocket connection and stops all tasks /// Closes the websocket connection and stops all tasks

View File

@ -4,7 +4,7 @@ use crate::types::{self, Composite};
#[async_trait(?Send)] #[async_trait(?Send)]
impl impl
GatewayHandleCapable< GatewayHandleCapable<
WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message,
WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketStream<MaybeTlsStream<TcpStream>>,
> for DefaultGatewayHandle > for DefaultGatewayHandle
{ {
@ -23,6 +23,29 @@ impl
self.kill_send.send(()).unwrap(); self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap(); self.websocket_send.lock().await.close().await.unwrap();
} }
fn new(
url: String,
events: Arc<Mutex<Events>>,
websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
kill_send: tokio::sync::broadcast::Sender<()>,
store: GatewayStore,
) -> Self {
Self {
url,
events,
websocket_send,
kill_send,
store,
}
}
} }
/// Represents a handle to a Gateway connection. A Gateway connection will create observable /// Represents a handle to a Gateway connection. A Gateway connection will create observable

View File

@ -1,39 +1,12 @@
use super::*; use super::*;
#[async_trait] #[async_trait]
impl HeartbeatHandlerCapable<WebSocketStream<MaybeTlsStream<TcpStream>>> for HeartbeatHandler { impl
fn new( HeartbeatHandlerCapable<
heartbeat_interval: Duration,
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message, tokio_tungstenite::tungstenite::Message,
>, WebSocketStream<MaybeTlsStream<TcpStream>>,
>, > for HeartbeatHandler
>, {
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move {
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
});
Self {
heartbeat_interval,
send,
handle,
}
}
fn get_send(&self) -> &Sender<HeartbeatThreadCommunication> { fn get_send(&self) -> &Sender<HeartbeatThreadCommunication> {
&self.send &self.send
} }

View File

@ -26,6 +26,27 @@ use tokio::task::JoinHandle;
use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
impl crate::gateway::MessageCapable for tokio_tungstenite::tungstenite::Message {
fn as_string(&self) -> Option<String> {
match self {
Message::Text(text) => Some(text.clone()),
_ => None,
}
}
fn is_empty(&self) -> bool {
todo!()
}
fn is_error(&self) -> bool {
todo!()
}
fn as_bytes(&self) -> Option<Vec<u8>> {
todo!()
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::types; use crate::types;

View File

@ -29,11 +29,10 @@ use std::time::{self, Duration, Instant};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
use futures_util::Sink; use futures_util::Sink;
use futures_util::{SinkExt, Stream}; use futures_util::SinkExt;
use log::{info, trace, warn}; use log::{info, trace, warn};
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
pub type GatewayStore = Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>; pub type GatewayStore = Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>;
@ -88,6 +87,13 @@ const GATEWAY_CALL_SYNC: u8 = 13;
/// See [types::LazyRequest] /// See [types::LazyRequest]
const GATEWAY_LAZY_REQUEST: u8 = 14; const GATEWAY_LAZY_REQUEST: u8 = 14;
pub trait MessageCapable {
fn as_string(&self) -> Option<String>;
fn as_bytes(&self) -> Option<Vec<u8>>;
fn is_empty(&self) -> bool;
fn is_error(&self) -> bool;
}
pub type ObservableObject = dyn Send + Sync + Any; pub type ObservableObject = dyn Send + Sync + Any;
/// Used for communications between the heartbeat and gateway thread. /// Used for communications between the heartbeat and gateway thread.
@ -153,22 +159,22 @@ impl<T: WebSocketEvent> GatewayEvent<T> {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
#[async_trait] #[async_trait]
pub trait GatewayCapable<R, S, G, H> pub trait GatewayCapable<T, S>
where where
R: Stream, T: MessageCapable + Send + 'static,
S: Sink<Message> + Send + 'static, S: Sink<T> + Send,
G: GatewayHandleCapable<R, S>,
H: HeartbeatHandlerCapable<S> + Send + Sync,
{ {
fn get_events(&self) -> Arc<Mutex<Events>>; fn get_events(&self) -> Arc<Mutex<Events>>;
fn get_websocket_send(&self) -> Arc<Mutex<SplitSink<S, Message>>>; fn get_websocket_send(&self) -> Arc<Mutex<SplitSink<S, T>>>;
fn get_store(&self) -> GatewayStore; fn get_store(&self) -> GatewayStore;
fn get_url(&self) -> String; fn get_url(&self) -> String;
fn get_heartbeat_handler(&self) -> &H; fn get_heartbeat_handler(&self) -> &HeartbeatHandler;
/// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`] /// Returns a Result with a matching impl of [`GatewayHandleCapable`], or a [`GatewayError`]
/// ///
/// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation /// DOCUMENTME: Explain what this method has to do to be a good get_handle() impl, or link to such documentation
async fn get_handle(websocket_url: String) -> Result<G, GatewayError>; async fn get_handle<G: GatewayHandleCapable<T, S>>(
websocket_url: String,
) -> Result<G, GatewayError>;
async fn close(&mut self); async fn close(&mut self);
/// This handles a message as a websocket event and updates its events along with the events' observers /// This handles a message as a websocket event and updates its events along with the events' observers
async fn handle_message(&mut self, msg: GatewayMessage) { async fn handle_message(&mut self, msg: GatewayMessage) {
@ -456,27 +462,35 @@ pub struct HeartbeatHandler {
} }
#[async_trait(?Send)] #[async_trait(?Send)]
pub trait GatewayHandleCapable<R, S> pub trait GatewayHandleCapable<T, S>
where where
R: Stream, T: MessageCapable + Send + 'static,
S: Sink<Message>, S: Sink<T>,
{ {
fn new(
url: String,
events: Arc<Mutex<Events>>,
websocket_send: Arc<Mutex<SplitSink<S, T>>>,
kill_send: tokio::sync::broadcast::Sender<()>,
store: GatewayStore,
) -> Self;
/// Sends json to the gateway with an opcode /// Sends json to the gateway with an opcode
async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value); async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value);
/// Observes an Item `<T: Updateable>`, which will update itself, if new information about this /// Observes an Item `<T: Updateable>`, which will update itself, if new information about this
/// item arrives on the corresponding Gateway Thread /// item arrives on the corresponding Gateway Thread
async fn observe<T: Updateable + Clone + std::fmt::Debug + Composite<T> + Send + Sync>( async fn observe<U: Updateable + Clone + std::fmt::Debug + Composite<U> + Send + Sync>(
&self, &self,
object: Arc<RwLock<T>>, object: Arc<RwLock<U>>,
) -> Arc<RwLock<T>>; ) -> Arc<RwLock<U>>;
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
/// with all of its observable fields being observed. /// with all of its observable fields being observed.
async fn observe_and_into_inner<T: Updateable + Clone + std::fmt::Debug + Composite<T>>( async fn observe_and_into_inner<U: Updateable + Clone + std::fmt::Debug + Composite<U>>(
&self, &self,
object: Arc<RwLock<T>>, object: Arc<RwLock<U>>,
) -> T { ) -> U {
let channel = self.observe(object.clone()).await; let channel = self.observe(object.clone()).await;
let object = channel.read().unwrap().clone(); let object = channel.read().unwrap().clone();
object object
@ -556,17 +570,12 @@ where
} }
#[async_trait] #[async_trait]
pub trait HeartbeatHandlerCapable<S: Sink<Message> + Send + 'static> { pub trait HeartbeatHandlerCapable<T: MessageCapable + Send + 'static, S: Sink<T>> {
fn new(
heartbeat_interval: Duration,
websocket_tx: Arc<Mutex<SplitSink<S, Message>>>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> Self;
fn get_send(&self) -> &Sender<HeartbeatThreadCommunication>; fn get_send(&self) -> &Sender<HeartbeatThreadCommunication>;
fn as_arc_mutex(&self) -> Arc<Mutex<Self>>;
fn get_heartbeat_interval(&self) -> Duration; fn get_heartbeat_interval(&self) -> Duration;
async fn heartbeat_task( async fn heartbeat_task(
websocket_tx: Arc<Mutex<SplitSink<S, Message>>>, websocket_tx: Arc<Mutex<SplitSink<S, T>>>,
heartbeat_interval: Duration, heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>, mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>, mut kill_receive: tokio::sync::broadcast::Receiver<()>,