Merge pull request #50 from kozabrada123/main

Closes #22
This commit is contained in:
Flori 2023-05-13 18:06:26 +02:00 committed by GitHub
commit 5372c6bce5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 510 additions and 115 deletions

View File

@ -10,7 +10,7 @@ serde = {version = "1.0.159", features = ["derive"]}
serde_json = "1.0.95" serde_json = "1.0.95"
reqwest = {version = "0.11.16", features = ["multipart"]} reqwest = {version = "0.11.16", features = ["multipart"]}
url = "2.3.1" url = "2.3.1"
chrono = "0.4.24" chrono = {version = "0.4.24", features = ["serde"]}
regex = "1.7.3" regex = "1.7.3"
custom_error = "1.9.2" custom_error = "1.9.2"
native-tls = "0.2.11" native-tls = "0.2.11"

View File

@ -4,6 +4,7 @@ https://discord.com/developers/docs .
I do not feel like re-documenting all of this, as everything is already perfectly explained there. I do not feel like re-documenting all of this, as everything is already perfectly explained there.
*/ */
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{api::limits::Limits, instance::Instance}; use crate::{api::limits::Limits, instance::Instance};
@ -132,6 +133,12 @@ pub struct Error {
pub code: String, pub code: String,
} }
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct UnavailableGuild {
id: String,
unavailable: bool
}
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default)]
pub struct UserObject { pub struct UserObject {
pub id: String, pub id: String,
@ -158,7 +165,7 @@ pub struct UserObject {
premium: bool, premium: bool,
purchased_flags: i32, purchased_flags: i32,
premium_usage_flags: i32, premium_usage_flags: i32,
disabled: bool, disabled: Option<bool>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -717,7 +724,7 @@ pub struct PresenceUpdate {
since: Option<i64>, since: Option<i64>,
activities: Vec<Activity>, activities: Vec<Activity>,
status: String, status: String,
afk: bool, afk: Option<bool>,
} }
impl WebSocketEvent for PresenceUpdate {} impl WebSocketEvent for PresenceUpdate {}
@ -785,6 +792,42 @@ pub struct GatewayResume {
impl WebSocketEvent for GatewayResume {} impl WebSocketEvent for GatewayResume {}
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct GatewayReady {
pub v: u8,
pub user: UserObject,
pub guilds: Vec<UnavailableGuild>,
pub session_id: String,
pub resume_gateway_url: Option<String>,
pub shard: Option<(u64, u64)>,
}
impl WebSocketEvent for GatewayReady {}
#[derive(Debug, Deserialize, Serialize, Default)]
/// See https://discord.com/developers/docs/topics/gateway-events#request-guild-members-request-guild-members-structure
pub struct GatewayRequestGuildMembers {
pub guild_id: String,
pub query: Option<String>,
pub limit: u64,
pub presence: Option<bool>,
pub user_ids: Option<String>,
pub nonce: Option<String>,
}
impl WebSocketEvent for GatewayRequestGuildMembers {}
#[derive(Debug, Deserialize, Serialize, Default)]
/// See https://discord.com/developers/docs/topics/gateway-events#update-voice-state-gateway-voice-state-update-structure
pub struct GatewayVoiceStateUpdate {
pub guild_id: String,
pub channel_id: Option<String>,
pub self_mute: bool,
pub self_deaf: bool,
}
impl WebSocketEvent for GatewayVoiceStateUpdate {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayHello { pub struct GatewayHello {
pub op: i32, pub op: i32,
@ -795,7 +838,7 @@ impl WebSocketEvent for GatewayHello {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct HelloData { pub struct HelloData {
pub heartbeat_interval: i32, pub heartbeat_interval: u128,
} }
impl WebSocketEvent for HelloData {} impl WebSocketEvent for HelloData {}
@ -803,7 +846,7 @@ impl WebSocketEvent for HelloData {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayHeartbeat { pub struct GatewayHeartbeat {
pub op: u8, pub op: u8,
pub d: u64, pub d: Option<u64>,
} }
impl WebSocketEvent for GatewayHeartbeat {} impl WebSocketEvent for GatewayHeartbeat {}
@ -815,11 +858,48 @@ pub struct GatewayHeartbeatAck {
impl WebSocketEvent for GatewayHeartbeatAck {} impl WebSocketEvent for GatewayHeartbeatAck {}
#[derive(Debug, Default, Deserialize, Serialize)]
/// See https://discord.com/developers/docs/topics/gateway-events#channel-pins-update
pub struct ChannelPinsUpdate {
pub guild_id: Option<String>,
pub channel_id: String,
pub last_pin_timestamp: Option<DateTime<Utc>>
}
impl WebSocketEvent for ChannelPinsUpdate {}
#[derive(Debug, Default, Deserialize, Serialize)]
/// See https://discord.com/developers/docs/topics/gateway-events#guild-ban-add-guild-ban-add-event-fields
pub struct GuildBanAdd {
pub guild_id: String,
pub user: UserObject,
}
impl WebSocketEvent for GuildBanAdd {}
#[derive(Debug, Default, Deserialize, Serialize)]
/// See https://discord.com/developers/docs/topics/gateway-events#guild-ban-remove
pub struct GuildBanRemove {
pub guild_id: String,
pub user: UserObject,
}
impl WebSocketEvent for GuildBanRemove {}
#[derive(Debug, Default, Deserialize, Serialize)]
/// See https://discord.com/developers/docs/topics/gateway-events#user-update
/// Not directly serialized, as the inner payload is the user object
pub struct UserUpdate {
pub user: UserObject,
}
impl WebSocketEvent for UserUpdate {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayPayload { pub struct GatewayPayload {
pub op: i32, pub op: u8,
pub d: Option<String>, pub d: Option<serde_json::Value>,
pub s: Option<i64>, pub s: Option<u64>,
pub t: Option<String>, pub t: Option<String>,
} }

View File

@ -1,58 +1,112 @@
use std::sync::Arc;
use crate::api::types::*; use crate::api::types::*;
use crate::api::WebSocketEvent; use crate::api::WebSocketEvent;
use crate::errors::ObserverError; use crate::errors::ObserverError;
use crate::gateway::events::Events; use crate::gateway::events::Events;
use futures_util::SinkExt;
use futures_util::StreamExt;
use futures_util::stream::SplitSink;
use native_tls::TlsConnector;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tokio::task;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{WebSocketStream, Connector, connect_async_tls_with_config};
#[derive(Debug)]
/** /**
Represents a Gateway connection. A Gateway connection will create observable Represents a handle to a Gateway connection. A Gateway connection will create observable
[`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
implemented [Types] with the trait [`WebSocketEvent`] implemented [Types] with the trait [`WebSocketEvent`]
Using this handle you can also send Gateway Events directly.
*/ */
pub struct Gateway<'a> { pub struct GatewayHandle {
pub url: String, pub url: String,
pub token: String, pub events: Arc<Mutex<Events>>,
pub events: Events<'a>, pub websocket_tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>,
} }
impl<'a> Gateway<'a> { impl GatewayHandle {
/// Sends json to the gateway with an opcode
async fn send_json_event(&self, op: u8, to_send: serde_json::Value) {
let gateway_payload = GatewayPayload { op, d: Some(to_send), s: None, t: None };
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
self.websocket_tx.lock().await.send(message).await.unwrap();
}
/// Sends an identify event to the gateway
pub async fn send_identify(&self, to_send: GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Identify..");
self.send_json_event(2, to_send_value).await;
}
/// Sends a resume event to the gateway
pub async fn send_resume(&self, to_send: GatewayResume) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Resume..");
self.send_json_event(6, to_send_value).await;
}
/// Sends an update presence event to the gateway
pub async fn send_update_presence(&self, to_send: PresenceUpdate) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Presence Update..");
self.send_json_event(3, to_send_value).await;
}
/// Sends a Request Guild Members to the server
pub async fn send_request_guild_members(&self, to_send: GatewayRequestGuildMembers) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Request Guild Members..");
self.send_json_event(8, to_send_value).await;
}
/// Sends a Request Guild Members to the server
pub async fn send_update_voice_state(&self, to_send: GatewayVoiceStateUpdate) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Voice State Update..");
self.send_json_event(4, to_send_value).await;
}
}
pub struct Gateway {
pub events: Arc<Mutex<Events>>,
heartbeat_handler: Option<HeartbeatHandler>,
pub websocket_tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>
}
impl Gateway {
pub async fn new( pub async fn new(
websocket_url: String, websocket_url: String,
token: String, ) -> Result<GatewayHandle, tokio_tungstenite::tungstenite::Error> {
) -> Result<Gateway<'a>, tokio_tungstenite::tungstenite::Error> {
Ok(Gateway {
url: websocket_url,
token,
events: Events::default(),
})
}
}
/*struct WebSocketConnection { let (ws_stream, _) = match connect_async_tls_with_config(
rx: Arc<Mutex<Receiver<tokio_tungstenite::tungstenite::Message>>>,
tx: Arc<Mutex<Sender<tokio_tungstenite::tungstenite::Message>>>,
}
impl<'a> WebSocketConnection {
async fn new(websocket_url: String) -> WebSocketConnection {
let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap();
if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" {
return Err(tokio_tungstenite::tungstenite::Error::Url(
UrlError::UnsupportedUrlScheme,
));
}
let (mut channel_write, mut channel_read): (
Sender<tokio_tungstenite::tungstenite::Message>,
Receiver<tokio_tungstenite::tungstenite::Message>,
) = channel(32);
let shared_channel_write = Arc::new(Mutex::new(channel_write));
let clone_shared_channel_write = shared_channel_write.clone();
let shared_channel_read = Arc::new(Mutex::new(channel_read));
let clone_shared_channel_read = shared_channel_read.clone();
task::spawn(async move {
let (mut ws_stream, _) = match connect_async_tls_with_config(
&websocket_url, &websocket_url,
None, None,
Some(Connector::NativeTls( Some(Connector::NativeTls(
@ -62,33 +116,281 @@ impl<'a> WebSocketConnection {
.await .await
{ {
Ok(ws_stream) => ws_stream, Ok(ws_stream) => ws_stream,
Err(_) => return, /*return Err(tokio_tungstenite::tungstenite::Error::Io( Err(e) => return Err(e),
io::ErrorKind::ConnectionAborted.into(),
))*/
}; };
let (mut write_tx, mut write_rx) = ws_stream.split(); let (ws_tx, mut ws_rx) = ws_stream.split();
while let Some(msg) = shared_channel_read.lock().await.recv().await { let shared_tx = Arc::new(Mutex::new(ws_tx));
write_tx.send(msg).await.unwrap();
let mut gateway = Gateway { events: Arc::new(Mutex::new(Events::default())), heartbeat_handler: None, websocket_tx: shared_tx.clone() };
let shared_events = gateway.events.clone();
// Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread
let msg = ws_rx.next().await.unwrap().unwrap();
let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap();
if gateway_payload.op != 10 {
println!("Recieved non hello on gateway init, what is happening?");
return Err(tokio_tungstenite::tungstenite::Error::Protocol(tokio_tungstenite::tungstenite::error::ProtocolError::InvalidOpcode(gateway_payload.op)))
} }
println!("GW: Received Hello");
let gateway_hello: HelloData = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
gateway.heartbeat_handler = Some(HeartbeatHandler::new(gateway_hello.heartbeat_interval, shared_tx.clone()));
// Now we can continously check for messages in a different task, since we aren't going to receive another hello
task::spawn(async move {
loop {
let msg = ws_rx.next().await;
if msg.as_ref().is_some() {
let msg_unwrapped = msg.unwrap().unwrap();
gateway.handle_event(msg_unwrapped).await;
};
}
});
return Ok(GatewayHandle {
url: websocket_url.clone(),
events: shared_events,
websocket_tx: shared_tx.clone(),
});
}
/// This handles a message as a websocket event and updates its events along with the events' observers
pub async fn handle_event(&mut self, msg: tokio_tungstenite::tungstenite::Message) {
if msg.to_string() == String::new() {
return;
}
let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap();
// See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes
match gateway_payload.op {
// Dispatch
// An event was dispatched, we need to look at the gateway event name t
0 => {
let gateway_payload_t = gateway_payload.t.unwrap();
println!("GW: Received {}..", gateway_payload_t);
// See https://discord.com/developers/docs/topics/gateway-events#receive-events
match gateway_payload_t.as_str() {
"READY" => {
let _data: GatewayReady = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
}
"RESUMED" => {}
"APPLICATION_COMMAND_PERMISSIONS_UPDATE" => {}
"AUTO_MODERATION_RULE_CREATE" => {}
"AUTO_MODERATION_RULE_UPDATE" => {}
"AUTO_MODERATION_RULE_DELETE" => {}
"AUTO_MODERATION_ACTION_EXECUTION" => {}
"CHANNEL_CREATE" => {}
"CHANNEL_UPDATE" => {}
"CHANNEL_DELETE" => {}
"CHANNEL_PINS_UPDATE" => {}
"THREAD_CREATE" => {}
"THREAD_UPDATE" => {}
"THREAD_DELETE" => {}
"THREAD_LIST_SYNC" => {}
"THREAD_MEMBER_UPDATE" => {}
"THREAD_MEMBERS_UPDATE" => {}
"GUILD_CREATE" => {}
"GUILD_UPDATE" => {}
"GUILD_DELETE" => {
let _new_data: UnavailableGuild = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
}
"GUILD_AUDIT_LOG_ENTRY_CREATE" => {}
"GUILD_BAN_ADD" => {
let _new_data: GuildBanAdd = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
}
"GUILD_BAN_REMOVE" => {
let _new_data: GuildBanRemove = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
}
"GUILD_EMOJIS_UPDATE" => {}
"GUILD_STICKERS_UPDATE" => {}
"GUILD_INTEGRATIONS_UPDATE" => {}
"GUILD_MEMBER_ADD" => {}
"GUILD_MEMBER_REMOVE" => {}
"GUILD_MEMBER_UPDATE" => {}
"GUILD_MEMBERS_CHUNK" => {}
"GUILD_ROLE_CREATE" => {}
"GUILD_ROLE_UPDATE" => {}
"GUILD_ROLE_DELETE" => {}
"GUILD_SCHEDULED_EVENT_CREATE" => {}
"GUILD_SCHEDULED_EVENT_UPDATE" => {}
"GUILD_SCHEDULED_EVENT_DELETE" => {}
"GUILD_SCHEDULED_EVENT_USER_ADD" => {}
"GUILD_SCHEDULED_EVENT_USER_REMOVE" => {}
"INTEGRATION_CREATE" => {}
"INTEGRATION_UPDATE" => {}
"INTEGRATION_DELETE" => {}
"INTERACTION_CREATE" => {}
"INVITE_CREATE" => {}
"INVITE_DELETE" => {}
"MESSAGE_CREATE" => {
let new_data: MessageCreate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.create.update_data(new_data).await;
}
"MESSAGE_UPDATE" => {
let new_data: MessageUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.update.update_data(new_data).await;
}
"MESSAGE_DELETE" => {
let new_data: MessageDelete = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.delete.update_data(new_data).await;
}
"MESSAGE_DELETE_BULK" => {
let new_data: MessageDeleteBulk = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.delete_bulk.update_data(new_data).await;
}
"MESSAGE_REACTION_ADD" => {
let new_data: MessageReactionAdd = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.reaction_add.update_data(new_data).await;
}
"MESSAGE_REACTION_REMOVE" => {
let new_data: MessageReactionRemove = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.reaction_remove.update_data(new_data).await;
}
"MESSAGE_REACTION_REMOVE_ALL" => {
let new_data: MessageReactionRemoveAll = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.reaction_remove_all.update_data(new_data).await;
}
"MESSAGE_REACTION_REMOVE_EMOJI" => {
let new_data: MessageReactionRemoveEmoji= serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.message.reaction_remove_emoji.update_data(new_data).await;
}
"PRESENCE_UPDATE" => {
let new_data: PresenceUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.user.presence_update.update_data(new_data).await;
}
"STAGE_INSTANCE_CREATE" => {}
"STAGE_INSTANCE_UPDATE" => {}
"STAGE_INSTANCE_DELETE" => {}
// Not documented in discord docs, I assume this isnt for bots / apps but is for users?
"SESSIONS_REPLACE" => {}
"TYPING_START" => {
let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.lock().await.user.typing_start_event.update_data(new_data).await;
}
"USER_UPDATE" => {
let user: UserObject = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
let new_data = UserUpdate {user};
self.events.lock().await.user.user_update.update_data(new_data).await;
}
"VOICE_STATE_UPDATE" => {}
"VOICE_SERVER_UPDATE" => {}
"WEBHOOKS_UPDATE" => {}
_ => {panic!("Invalid gateway event ({})", &gateway_payload_t)}
}
}
// Heartbeat
// We received a heartbeat from the server
1 => {}
// Reconnect
7 => {todo!()}
// Invalid Session
9 => {todo!()}
// Hello
// Starts our heartbeat
// We should have already handled this in gateway init
10 => {
panic!("Recieved hello when it was unexpected");
}
// Heartbeat ACK
11 => {
println!("GW: Received Heartbeat ACK");
}
2 | 3 | 4 | 6 | 8 => {panic!("Received Gateway op code that's meant to be sent, not received ({})", gateway_payload.op)}
_ => {panic!("Received Invalid Gateway op code ({})", gateway_payload.op)}
}
// If we have an active heartbeat thread and we received a seq number we should let it know
if gateway_payload.s.is_some() {
if self.heartbeat_handler.is_some() {
let heartbeat_communication = HeartbeatThreadCommunication { op: gateway_payload.op, d: gateway_payload.s.unwrap() };
self.heartbeat_handler.as_mut().unwrap().tx.send(heartbeat_communication).await.unwrap();
}
}
}
}
/**
Handles sending heartbeats to the gateway in another thread
*/
struct HeartbeatHandler {
/// The heartbeat interval in milliseconds
heartbeat_interval: u128,
tx: Sender<HeartbeatThreadCommunication>,
}
impl HeartbeatHandler {
pub fn new(heartbeat_interval: u128, websocket_tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler {
let (tx, mut rx) = mpsc::channel(32);
task::spawn(async move {
let mut last_heartbeat: Instant = time::Instant::now();
let mut last_seq_number: Option<u64> = None;
loop {
// If we received a seq number update, use that as the last seq number
let hb_communication: Result<HeartbeatThreadCommunication, TryRecvError> = rx.try_recv();
if hb_communication.is_ok() {
last_seq_number = Some(hb_communication.unwrap().d);
}
if last_heartbeat.elapsed().as_millis() > heartbeat_interval {
println!("GW: Sending Heartbeat..");
let heartbeat = GatewayHeartbeat {
op: 1,
d: last_seq_number
}; };
Ok(Gateway { let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
url: websocket_url,
token, let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
events: Events::default(),
socket: ws_stream, websocket_tx.lock().await
}) .send(msg)
.await
.unwrap();
last_heartbeat = time::Instant::now();
}
}
});
Self { heartbeat_interval, tx }
}
}
/**
Used to communicate with the main thread.
Either signifies a sequence number update or a received heartbeat ack
*/
#[derive(Clone, Copy, Debug)]
struct HeartbeatThreadCommunication {
/// An opcode for the communication we received
op: u8,
/// The sequence number we got from discord
d: u64
} }
}*/
/** /**
Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to
an Observable. The Observer is notified when the Observable's data changes. an Observable. The Observer is notified when the Observable's data changes.
In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
*/ */
pub trait Observer<T: WebSocketEvent> { pub trait Observer<T: WebSocketEvent>: std::fmt::Debug {
fn update(&self, data: &T); fn update(&self, data: &T);
} }
@ -96,14 +398,14 @@ pub trait Observer<T: WebSocketEvent> {
change in the WebSocketEvent. GatewayEvents are observable. change in the WebSocketEvent. GatewayEvents are observable.
*/ */
#[derive(Default)] #[derive(Default, Debug)]
pub struct GatewayEvent<'a, T: WebSocketEvent> { pub struct GatewayEvent<T: WebSocketEvent> {
observers: Vec<&'a dyn Observer<T>>, observers: Vec<Arc<Mutex<dyn Observer<T> + Sync + Send>>>,
pub event_data: T, pub event_data: T,
pub is_observed: bool, pub is_observed: bool,
} }
impl<'a, T: WebSocketEvent> GatewayEvent<'a, T> { impl<T: WebSocketEvent> GatewayEvent<T> {
fn new(event_data: T) -> Self { fn new(event_data: T) -> Self {
Self { Self {
is_observed: false, is_observed: false,
@ -126,7 +428,7 @@ impl<'a, T: WebSocketEvent> GatewayEvent<'a, T> {
Returns an error if the GatewayEvent is already observed. Returns an error if the GatewayEvent is already observed.
Error type: [`ObserverError::AlreadySubscribedError`] Error type: [`ObserverError::AlreadySubscribedError`]
*/ */
pub fn subscribe(&mut self, observable: &'a dyn Observer<T>) -> Option<ObserverError> { pub fn subscribe(&mut self, observable: Arc<Mutex<dyn Observer<T> + Sync + Send>>) -> Option<ObserverError> {
if self.is_observed { if self.is_observed {
return Some(ObserverError::AlreadySubscribedError); return Some(ObserverError::AlreadySubscribedError);
} }
@ -138,57 +440,60 @@ impl<'a, T: WebSocketEvent> GatewayEvent<'a, T> {
/** /**
Unsubscribes an Observer from the GatewayEvent. Unsubscribes an Observer from the GatewayEvent.
*/ */
pub fn unsubscribe(&mut self, observable: &'a dyn Observer<T>) { pub fn unsubscribe(&mut self, observable: Arc<Mutex<dyn Observer<T> + Sync + Send>>) {
// .retain()'s closure retains only those elements of the vector, which have a different // .retain()'s closure retains only those elements of the vector, which have a different
// pointer value than observable. // pointer value than observable.
self.observers.retain(|obs| !std::ptr::eq(*obs, observable)); // The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
// anddd there is no way to do that without using format
self.observers.retain(|obs| !(format!("{:?}", obs) == format!("{:?}", &observable)));
self.is_observed = !self.observers.is_empty(); self.is_observed = !self.observers.is_empty();
} }
/** /**
Updates the GatewayEvent's data and notifies the observers. Updates the GatewayEvent's data and notifies the observers.
*/ */
fn update_data(&mut self, new_event_data: T) { async fn update_data(&mut self, new_event_data: T) {
self.event_data = new_event_data; self.event_data = new_event_data;
self.notify(); self.notify().await;
} }
/** /**
Notifies the observers of the GatewayEvent. Notifies the observers of the GatewayEvent.
*/ */
fn notify(&self) { async fn notify(&self) {
for observer in &self.observers { for observer in &self.observers {
observer.update(&self.event_data); observer.lock().await.update(&self.event_data);
} }
} }
} }
mod events { mod events {
use super::*; use super::*;
#[derive(Default)] #[derive(Default, Debug)]
pub struct Events<'a> { pub struct Events {
pub message: Message<'a>, pub message: Message,
pub user: User<'a>, pub user: User,
pub gateway_identify_payload: GatewayEvent<'a, GatewayIdentifyPayload>, pub gateway_identify_payload: GatewayEvent<GatewayIdentifyPayload>,
pub gateway_resume: GatewayEvent<'a, GatewayResume>, pub gateway_resume: GatewayEvent<GatewayResume>,
} }
#[derive(Default)] #[derive(Default, Debug)]
pub struct Message<'a> { pub struct Message {
pub create: GatewayEvent<'a, MessageCreate>, pub create: GatewayEvent<MessageCreate>,
pub update: GatewayEvent<'a, MessageUpdate>, pub update: GatewayEvent<MessageUpdate>,
pub delete: GatewayEvent<'a, MessageDelete>, pub delete: GatewayEvent<MessageDelete>,
pub delete_bulk: GatewayEvent<'a, MessageDeleteBulk>, pub delete_bulk: GatewayEvent<MessageDeleteBulk>,
pub reaction_add: GatewayEvent<'a, MessageReactionAdd>, pub reaction_add: GatewayEvent<MessageReactionAdd>,
pub reaction_remove: GatewayEvent<'a, MessageReactionRemove>, pub reaction_remove: GatewayEvent<MessageReactionRemove>,
pub reaction_remove_all: GatewayEvent<'a, MessageReactionRemoveAll>, pub reaction_remove_all: GatewayEvent<MessageReactionRemoveAll>,
pub reaction_remove_emoji: GatewayEvent<'a, MessageReactionRemoveEmoji>, pub reaction_remove_emoji: GatewayEvent<MessageReactionRemoveEmoji>,
} }
#[derive(Default)] #[derive(Default, Debug)]
pub struct User<'a> { pub struct User {
pub presence_update: GatewayEvent<'a, PresenceUpdate>, pub user_update: GatewayEvent<UserUpdate>,
pub typing_start_event: GatewayEvent<'a, TypingStartEvent>, pub presence_update: GatewayEvent<PresenceUpdate>,
pub typing_start_event: GatewayEvent<TypingStartEvent>,
} }
} }
@ -197,6 +502,7 @@ mod example {
use super::*; use super::*;
use crate::api::types::GatewayResume; use crate::api::types::GatewayResume;
#[derive(Debug)]
struct Consumer; struct Consumer;
impl Observer<GatewayResume> for Consumer { impl Observer<GatewayResume> for Consumer {
fn update(&self, data: &GatewayResume) { fn update(&self, data: &GatewayResume) {
@ -204,8 +510,8 @@ mod example {
} }
} }
#[test] #[tokio::test]
fn test_observer_behaviour() { async fn test_observer_behaviour() {
let mut event = GatewayEvent::new(GatewayResume { let mut event = GatewayEvent::new(GatewayResume {
token: "start".to_string(), token: "start".to_string(),
session_id: "start".to_string(), session_id: "start".to_string(),
@ -219,24 +525,34 @@ mod example {
}; };
let consumer = Consumer; let consumer = Consumer;
let arc_mut_consumer = Arc::new(Mutex::new(consumer));
event.subscribe(&consumer); event.subscribe(arc_mut_consumer.clone());
event.notify(); event.notify().await;
event.update_data(new_data); event.update_data(new_data).await;
let second_consumer = Consumer; let second_consumer = Consumer;
let arc_mut_second_consumer = Arc::new(Mutex::new(second_consumer));
match event.subscribe(&second_consumer) { match event.subscribe(arc_mut_second_consumer.clone()) {
None => assert!(false), None => assert!(false),
Some(err) => println!("You cannot subscribe twice: {}", err), Some(err) => println!("You cannot subscribe twice: {}", err),
} }
event.unsubscribe(arc_mut_consumer.clone());
match event.subscribe(arc_mut_second_consumer.clone()) {
None => assert!(true),
Some(_) => assert!(false),
}
} }
#[tokio::test] #[tokio::test]
async fn test_gateway() { async fn test_gateway_establish() {
let _gateway = Gateway::new("ws://localhost:3001/".to_string(), "none".to_string()) let _gateway = Gateway::new("ws://localhost:3001/".to_string())
.await .await
.unwrap(); .unwrap();
} }

View File

@ -16,7 +16,6 @@ pub struct Instance {
pub instance_info: InstancePolicies, pub instance_info: InstancePolicies,
pub requester: LimitedRequester, pub requester: LimitedRequester,
pub limits: Limits, pub limits: Limits,
//pub gateway: Gateway,
} }
impl Instance { impl Instance {