diff --git a/Cargo.toml b/Cargo.toml index b441eb6..1769377 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ license = "AGPL-3" edition = "2021" [dependencies] -tokio = {version = "1.27.0", features = ["rt", "macros"]} +tokio = {version = "1.27.0", features = ["rt", "macros", "rt-multi-thread"]} serde = {version = "1.0.159", features = ["derive"]} serde_json = "1.0.95" reqwest = "0.11.16" diff --git a/src/api/types.rs b/src/api/types.rs index c50520f..43ee964 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -2,7 +2,7 @@ To learn more about the types implemented here, please visit https://discord.com/developers/docs . I do not feel like re-documenting all of this, as everything is already perfectly explained there. - */ +*/ use std::fmt; @@ -10,6 +10,8 @@ use serde::{Deserialize, Serialize}; use crate::{api::limits::Limits, URLBundle}; +pub trait WebSocketEvent {} + #[derive(Debug, Serialize, Deserialize)] pub struct LoginResult { token: String, @@ -132,7 +134,7 @@ pub struct Error { pub code: String, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct UserObject { id: String, username: String, @@ -201,8 +203,8 @@ impl User { } } -#[derive(Debug, Serialize, Deserialize)] -struct Message { +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Message { id: String, channel_id: String, author: UserObject, @@ -237,8 +239,8 @@ struct Message { role_subscription_data: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageCreate { +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageCreate { #[serde(flatten)] message: Message, guild_id: Option, @@ -246,7 +248,9 @@ struct MessageCreate { mentions: Vec<(UserObject, GuildMember)>, // Not sure if this is correct: https://discord.com/developers/docs/topics/gateway-events#message-create } -#[derive(Debug, Serialize, Deserialize)] +impl WebSocketEvent for MessageCreate {} + +#[derive(Debug, Serialize, Deserialize, Default)] struct PartialMessage { id: Option, channel_id: Option, @@ -284,8 +288,8 @@ struct PartialMessage { member: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageUpdate { +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageUpdate { #[serde(flatten)] message: PartialMessage, guild_id: Option, @@ -293,22 +297,28 @@ struct MessageUpdate { mentions: Vec<(UserObject, GuildMember)>, // Not sure if this is correct: https://discord.com/developers/docs/topics/gateway-events#message-create } -#[derive(Debug, Serialize, Deserialize)] -struct MessageDelete { +impl WebSocketEvent for MessageUpdate {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageDelete { id: String, channel_id: String, guild_id: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageDeleteBulk { +impl WebSocketEvent for MessageDelete {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageDeleteBulk { ids: Vec, channel_id: String, guild_id: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageReactionAdd { +impl WebSocketEvent for MessageDeleteBulk {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageReactionAdd { user_id: String, channel_id: String, message_id: String, @@ -317,8 +327,10 @@ struct MessageReactionAdd { emoji: Emoji, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageReactionRemove { +impl WebSocketEvent for MessageReactionAdd {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageReactionRemove { user_id: String, channel_id: String, message_id: String, @@ -326,21 +338,27 @@ struct MessageReactionRemove { emoji: Emoji, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageReactionRemoveAll { +impl WebSocketEvent for MessageReactionRemove {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageReactionRemoveAll { channel_id: String, message_id: String, guild_id: Option, } -#[derive(Debug, Serialize, Deserialize)] -struct MessageReactionRemoveEmoji { +impl WebSocketEvent for MessageReactionRemoveAll {} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct MessageReactionRemoveEmoji { channel_id: String, message_id: String, guild_id: Option, emoji: Emoji, } +impl WebSocketEvent for MessageReactionRemoveEmoji {} + #[derive(Debug, Serialize, Deserialize)] struct ChannelMention { id: String, @@ -445,7 +463,7 @@ struct Reaction { emoji: Emoji, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, Default)] struct Emoji { id: Option, name: Option, @@ -680,8 +698,8 @@ struct RoleSubscriptionData { is_renewal: bool, } -#[derive(Debug, Deserialize, Serialize)] -struct TypingStartEvent { +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct TypingStartEvent { channel_id: String, guild_id: Option, user_id: String, @@ -689,32 +707,38 @@ struct TypingStartEvent { member: Option, } -#[derive(Debug, Deserialize, Serialize)] -struct GatewayIdentifyPayload { - token: String, - properties: GatewayIdentifyConnectionProps, - compress: Option, - large_threshold: Option, //default: 50 - shard: Option>, - presence: Option, - intents: i32, +impl WebSocketEvent for TypingStartEvent {} + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct GatewayIdentifyPayload { + pub token: String, + pub properties: GatewayIdentifyConnectionProps, + pub compress: Option, + pub large_threshold: Option, //default: 50 + pub shard: Option>, + pub presence: Option, + pub intents: i32, } -#[derive(Debug, Deserialize, Serialize)] -struct GatewayIdentifyConnectionProps { - os: String, - browser: String, - device: String, +impl WebSocketEvent for GatewayIdentifyPayload {} + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct GatewayIdentifyConnectionProps { + pub os: String, + pub browser: String, + pub device: String, } -#[derive(Debug, Deserialize, Serialize)] -struct PresenceUpdate { +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct PresenceUpdate { since: Option, activities: Vec, status: String, afk: bool, } +impl WebSocketEvent for PresenceUpdate {} + #[derive(Debug, Deserialize, Serialize)] struct Activity { name: String, @@ -769,9 +793,11 @@ struct ActivityButton { url: String, } -#[derive(Debug, Deserialize, Serialize)] -struct GatewayResume { - token: String, - session_id: String, - seq: String, +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct GatewayResume { + pub token: String, + pub session_id: String, + pub seq: String, } + +impl WebSocketEvent for GatewayResume {} diff --git a/src/errors.rs b/src/errors.rs index 58dc856..f74c7cd 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -19,3 +19,9 @@ custom_error! { InvalidFormBodyError{error_type: String, error:String} = "The server responded with: {error_type}: {error}", RateLimited = "Ratelimited.", } + +custom_error! { + #[derive(PartialEq, Eq)] + pub ObserverError + AlreadySubscribedError = "Each event can only be subscribed to once." +} diff --git a/src/gateway.rs b/src/gateway.rs index d68be22..f56e0d4 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,4 +1,203 @@ -#[derive(Debug)] -pub struct Gateway { - url: String, +use std::sync::Arc; +use std::sync::Mutex; +use std::thread::JoinHandle; + +use crate::api::types::*; +use crate::api::WebSocketEvent; +use crate::errors::ObserverError; +use crate::gateway::events::Events; +use crate::URLBundle; +use reqwest::Url; +use serde_json::to_string; +use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite::Error; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; + +/** +Represents a Gateway connection. A Gateway connection will create observable +[`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently +implemented [Types] with the trait [`WebSocketEvent`] +*/ +pub struct Gateway<'a> { + pub url: String, + pub events: Events<'a>, + socket: Arc>>>>, + thread_handle: Option>, +} + +impl<'a> Gateway<'a> { + pub async fn new(websocket_url: String, token: String) { + let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap(); + if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" { + //return Err(Error::Url(UrlError::UnsupportedUrlScheme)); + } + let payload = GatewayIdentifyPayload { + token: token, + properties: GatewayIdentifyConnectionProps { + os: "any".to_string(), + browser: "chorus-polyphony".to_string(), + device: "chorus-lib".to_string(), + }, + compress: Some(true), + large_threshold: None, + shard: None, + presence: None, + intents: 3276799, + }; + let payload_string = to_string(&payload).unwrap(); + } +} + +/** +Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to +an Observable. The Observer is notified when the Observable's data changes. +In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. + */ +pub trait Observer { + fn update(&self, data: &T); +} + +/** GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a +change in the WebSocketEvent. GatewayEvents are observable. +*/ + +#[derive(Default)] +pub struct GatewayEvent<'a, T: WebSocketEvent> { + observers: Vec<&'a dyn Observer>, + pub event_data: T, + pub is_observed: bool, +} + +impl<'a, T: WebSocketEvent> GatewayEvent<'a, T> { + fn new(event_data: T) -> Self { + Self { + is_observed: false, + observers: Vec::new(), + event_data, + } + } + + /** + Returns true if the GatewayEvent is observed by at least one Observer. + */ + pub fn is_observed(&self) -> bool { + self.is_observed + } + + /** + Subscribes an Observer to the GatewayEvent. Returns an error if the GatewayEvent is already + observed. + # Errors + Returns an error if the GatewayEvent is already observed. + Error type: [`ObserverError::AlreadySubscribedError`] + */ + pub fn subscribe(&mut self, observable: &'a dyn Observer) -> Option { + if self.is_observed { + return Some(ObserverError::AlreadySubscribedError); + } + self.is_observed = true; + self.observers.push(observable); + None + } + + /** + Unsubscribes an Observer from the GatewayEvent. + */ + pub fn unsubscribe(&mut self, observable: &'a dyn Observer) { + // .retain()'s closure retains only those elements of the vector, which have a different + // pointer value than observable. + self.observers.retain(|obs| !std::ptr::eq(*obs, observable)); + self.is_observed = !self.observers.is_empty(); + return; + } + + /** + Updates the GatewayEvent's data and notifies the observers. + */ + fn update_data(&mut self, new_event_data: T) { + self.event_data = new_event_data; + self.notify(); + } + + /** + Notifies the observers of the GatewayEvent. + */ + fn notify(&self) { + for observer in &self.observers { + observer.update(&self.event_data); + } + } +} + +mod events { + use super::*; + #[derive(Default)] + pub struct Events<'a> { + pub message: Message<'a>, + pub user: User<'a>, + pub gateway_identify_payload: GatewayEvent<'a, GatewayIdentifyPayload>, + pub gateway_resume: GatewayEvent<'a, GatewayResume>, + } + + #[derive(Default)] + pub struct Message<'a> { + pub create: GatewayEvent<'a, MessageCreate>, + pub update: GatewayEvent<'a, MessageUpdate>, + pub delete: GatewayEvent<'a, MessageDelete>, + pub delete_bulk: GatewayEvent<'a, MessageDeleteBulk>, + pub reaction_add: GatewayEvent<'a, MessageReactionAdd>, + pub reaction_remove: GatewayEvent<'a, MessageReactionRemove>, + pub reaction_remove_all: GatewayEvent<'a, MessageReactionRemoveAll>, + pub reaction_remove_emoji: GatewayEvent<'a, MessageReactionRemoveEmoji>, + } + + #[derive(Default)] + pub struct User<'a> { + pub presence_update: GatewayEvent<'a, PresenceUpdate>, + pub typing_start_event: GatewayEvent<'a, TypingStartEvent>, + } +} + +#[cfg(test)] +mod example { + use super::*; + use crate::api::types::GatewayResume; + + struct Consumer; + impl Observer for Consumer { + fn update(&self, data: &GatewayResume) { + println!("{}", data.token) + } + } + + #[test] + fn test_observer_behaviour() { + let mut event = GatewayEvent::new(GatewayResume { + token: "start".to_string(), + session_id: "start".to_string(), + seq: "start".to_string(), + }); + + let new_data = GatewayResume { + token: "token_3276ha37am3".to_string(), + session_id: "89346671230".to_string(), + seq: "3".to_string(), + }; + + let consumer = Consumer; + + event.subscribe(&consumer); + + event.notify(); + + event.update_data(new_data); + + let second_consumer = Consumer; + + match event.subscribe(&second_consumer) { + None => assert!(false), + Some(err) => println!("You cannot subscribe twice: {}", err), + } + } }