diff --git a/Cargo.lock b/Cargo.lock index c90b667..d193cf4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,7 @@ dependencies = [ "async-trait", "base64 0.21.2", "bitflags 2.3.3", + "chorus-macros", "chrono", "custom_error", "futures-util", @@ -215,6 +216,14 @@ dependencies = [ "url", ] +[[package]] +name = "chorus-macros" +version = "0.1.0" +dependencies = [ + "quote", + "syn 2.0.26", +] + [[package]] name = "chrono" version = "0.4.26" diff --git a/Cargo.toml b/Cargo.toml index 4126146..8dbb172 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ thiserror = "1.0.43" jsonwebtoken = "8.3.0" log = "0.4.19" async-trait = "0.1.71" +chorus-macros = {path = "chorus-macros"} [dev-dependencies] tokio = {version = "1.29.1", features = ["full"]} diff --git a/chorus-macros/Cargo.lock b/chorus-macros/Cargo.lock new file mode 100644 index 0000000..4aa96d3 --- /dev/null +++ b/chorus-macros/Cargo.lock @@ -0,0 +1,46 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "chorus-macros" +version = "0.1.0" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fe8a65d69dd0808184ebb5f836ab526bb259db23c657efa38711b1072ee47f0" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" diff --git a/chorus-macros/Cargo.toml b/chorus-macros/Cargo.toml new file mode 100644 index 0000000..cffffe1 --- /dev/null +++ b/chorus-macros/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "chorus-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +quote = "1" +syn = "2" diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs new file mode 100644 index 0000000..c8d5e87 --- /dev/null +++ b/chorus-macros/src/lib.rs @@ -0,0 +1,18 @@ +use proc_macro::TokenStream; +use quote::quote; + +#[proc_macro_derive(Updateable)] +pub fn updateable_macro_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + // No need for macro hygiene, we're only using this in chorus + quote! { + impl Updateable for #name { + fn id(&self) -> Snowflake { + self.id + } + } + } + .into() +} diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 1109fb8..60ee52b 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -6,9 +6,10 @@ use serde_json::to_string; use crate::api::LimitType; use crate::errors::ChorusResult; +use crate::gateway::Gateway; use crate::instance::{Instance, UserMeta}; use crate::ratelimiter::ChorusRequest; -use crate::types::{LoginResult, LoginSchema}; +use crate::types::{GatewayIdentifyPayload, LoginResult, LoginSchema}; impl Instance { /// Logs into an existing account on the spacebar server. @@ -23,7 +24,8 @@ impl Instance { // We do not have a user yet, and the UserRateLimits will not be affected by a login // request (since login is an instance wide limit), which is why we are just cloning the // instances' limits to pass them on as user_rate_limits later. - let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()); + let mut shell = + UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()).await; let login_result = chorus_request .deserialize_response::(&mut shell) .await?; @@ -31,12 +33,17 @@ impl Instance { if self.limits_information.is_some() { self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); } + let mut identify = GatewayIdentifyPayload::common(); + let gateway = Gateway::new(self.urls.wss.clone()).await.unwrap(); + identify.token = login_result.token.clone(); + gateway.send_identify(identify).await; let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), login_result.token, self.clone_limits_if_some(), login_result.settings, object, + gateway, ); Ok(user) } diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index fa3d6c9..630ad9d 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -3,6 +3,8 @@ use std::{cell::RefCell, rc::Rc}; use reqwest::Client; use serde_json::to_string; +use crate::gateway::Gateway; +use crate::types::GatewayIdentifyPayload; use crate::{ api::policies::instance::LimitType, errors::ChorusResult, @@ -27,7 +29,8 @@ impl Instance { // We do not have a user yet, and the UserRateLimits will not be affected by a login // request (since register is an instance wide limit), which is why we are just cloning // the instances' limits to pass them on as user_rate_limits later. - let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()); + let mut shell = + UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()).await; let token = chorus_request .deserialize_response::(&mut shell) .await? @@ -37,12 +40,17 @@ impl Instance { } let user_object = self.get_user(token.clone(), None).await.unwrap(); let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self).await?; + let mut identify = GatewayIdentifyPayload::common(); + let gateway = Gateway::new(self.urls.wss.clone()).await.unwrap(); + identify.token = token.clone(); + gateway.send_identify(identify).await; let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), token.clone(), self.clone_limits_if_some(), settings, user_object, + gateway, ); Ok(user) } diff --git a/src/api/channels/channels.rs b/src/api/channels/channels.rs index a3f7884..a8bc7ee 100644 --- a/src/api/channels/channels.rs +++ b/src/api/channels/channels.rs @@ -41,11 +41,11 @@ impl Channel { /// Modifies a channel with the provided data. /// Replaces self with the new channel object. pub async fn modify( - &mut self, + &self, modify_data: ChannelModifySchema, channel_id: Snowflake, user: &mut UserMeta, - ) -> ChorusResult<()> { + ) -> ChorusResult { let chorus_request = ChorusRequest { request: Client::new() .patch(format!( @@ -57,9 +57,7 @@ impl Channel { .body(to_string(&modify_data).unwrap()), limit_type: LimitType::Channel(channel_id), }; - let new_channel = chorus_request.deserialize_response::(user).await?; - let _ = std::mem::replace(self, new_channel); - Ok(()) + chorus_request.deserialize_response::(user).await } /// Fetches recent messages from a channel. diff --git a/src/api/users/users.rs b/src/api/users/users.rs index 0f6a5ee..9e7d2af 100644 --- a/src/api/users/users.rs +++ b/src/api/users/users.rs @@ -121,7 +121,8 @@ impl User { let request: reqwest::RequestBuilder = Client::new() .get(format!("{}/users/@me/settings/", url_api)) .bearer_auth(token); - let mut user = UserMeta::shell(Rc::new(RefCell::new(instance.clone())), token.clone()); + let mut user = + UserMeta::shell(Rc::new(RefCell::new(instance.clone())), token.clone()).await; let chorus_request = ChorusRequest { request, limit_type: LimitType::Global, @@ -148,7 +149,7 @@ impl Instance { // # Notes // This function is a wrapper around [`User::get`]. pub async fn get_user(&mut self, token: String, id: Option<&String>) -> ChorusResult { - let mut user = UserMeta::shell(Rc::new(RefCell::new(self.clone())), token); + let mut user = UserMeta::shell(Rc::new(RefCell::new(self.clone())), token).await; let result = User::get(&mut user, id).await; if self.limits_information.is_some() { self.limits_information.as_mut().unwrap().ratelimits = diff --git a/src/gateway.rs b/src/gateway.rs index 01968a4..eaac29f 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,10 +1,14 @@ use crate::errors::GatewayError; use crate::gateway::events::Events; -use crate::types; -use crate::types::WebSocketEvent; +use crate::types::{self, Channel, ChannelUpdate, Snowflake}; +use crate::types::{UpdateMessage, WebSocketEvent}; use async_trait::async_trait; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; use std::sync::Arc; use std::time::Duration; +use tokio::sync::watch; use tokio::time::sleep_until; use futures_util::stream::SplitSink; @@ -163,6 +167,12 @@ pub struct GatewayHandle { pub handle: JoinHandle<()>, /// Tells gateway tasks to close kill_send: tokio::sync::broadcast::Sender<()>, + store: Arc>>>, +} + +/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. +pub trait Updateable: 'static + Send + Sync { + fn id(&self) -> Snowflake; } impl GatewayHandle { @@ -186,6 +196,27 @@ impl GatewayHandle { .unwrap(); } + pub async fn observe(&self, object: T) -> watch::Receiver { + let mut store = self.store.lock().await; + if let Some(channel) = store.get(&object.id()) { + let (_, rx) = channel + .downcast_ref::<(watch::Sender, watch::Receiver)>() + .unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + object.id() + ) + }); + rx.clone() + } else { + let id = object.id(); + let channel = watch::channel(object); + let receiver = channel.1.clone(); + store.insert(id, Box::new(channel)); + receiver + } + } + /// Sends an identify event to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); @@ -263,9 +294,9 @@ impl GatewayHandle { } pub struct Gateway { - pub events: Arc>, + events: Arc>, heartbeat_handler: HeartbeatHandler, - pub websocket_send: Arc< + websocket_send: Arc< Mutex< SplitSink< WebSocketStream>, @@ -273,8 +304,9 @@ pub struct Gateway { >, >, >, - pub websocket_receive: SplitStream>>, + websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, + store: Arc>>>, } impl Gateway { @@ -325,6 +357,8 @@ impl Gateway { let events = Events::default(); let shared_events = Arc::new(Mutex::new(events)); + let store = Arc::new(Mutex::new(HashMap::new())); + let mut gateway = Gateway { events: shared_events.clone(), heartbeat_handler: HeartbeatHandler::new( @@ -335,6 +369,7 @@ impl Gateway { websocket_send: shared_websocket_send.clone(), websocket_receive, kill_send: kill_send.clone(), + store: store.clone(), }; // Now we can continuously check for messages in a different task, since we aren't going to receive another hello @@ -348,6 +383,7 @@ impl Gateway { websocket_send: shared_websocket_send.clone(), handle, kill_send: kill_send.clone(), + store, }) } @@ -379,6 +415,7 @@ impl Gateway { /// Deserializes and updates a dispatched event, when we already know its type; /// (Called for every event in handle_message) + #[allow(dead_code)] // TODO: Remove this allow annotation async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( data: &'a str, event: &mut GatewayEvent, @@ -431,17 +468,25 @@ impl Gateway { trace!("Gateway: Received {event_name}"); macro_rules! handle { - ($($name:literal => $($path:ident).+),*) => { + ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { match event_name.as_str() { $($name => { let event = &mut self.events.lock().await.$($path).+; - - let result = - Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) - .await; - - if let Err(err) = result { - warn!("Failed to parse gateway event {event_name} ({err})"); + match serde_json::from_str(gateway_payload.event_data.unwrap().get()) { + Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), + Ok(message) => { + $( + let message: $message_type = message; + if let Some(to_update) = self.store.lock().await.get(&message.id()) { + if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<$update_type>, watch::Receiver<$update_type>)>() { + tx.send_modify(|object| message.update(object)); + } else { + warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id()) + } + } + )? + event.notify(message).await; + } } },)* "RESUMED" => (), @@ -482,7 +527,7 @@ impl Gateway { "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, "CHANNEL_CREATE" => channel.create, - "CHANNEL_UPDATE" => channel.update, + "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, "CHANNEL_UNREAD_UPDATE" => channel.unread_update, "CHANNEL_DELETE" => channel.delete, "CHANNEL_PINS_UPDATE" => channel.pins_update, diff --git a/src/instance.rs b/src/instance.rs index eb4ead3..2360390 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use crate::api::{Limit, LimitType}; use crate::errors::ChorusResult; +use crate::gateway::{Gateway, GatewayHandle}; use crate::ratelimiter::ChorusRequest; use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::{GeneralConfiguration, User, UserSettings}; @@ -81,13 +82,14 @@ impl fmt::Display for Token { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct UserMeta { pub belongs_to: Rc>, pub token: String, pub limits: Option>, pub settings: UserSettings, pub object: User, + pub gateway: GatewayHandle, } impl UserMeta { @@ -105,6 +107,7 @@ impl UserMeta { limits: Option>, settings: UserSettings, object: User, + gateway: GatewayHandle, ) -> UserMeta { UserMeta { belongs_to, @@ -112,19 +115,24 @@ impl UserMeta { limits, settings, object, + gateway, } } /// Creates a new 'shell' of a user. The user does not exist as an object, and exists so that you have /// a UserMeta object to make Rate Limited requests with. This is useful in scenarios like /// registering or logging in to the Instance, where you do not yet have a User object, but still - /// need to make a RateLimited request. - pub(crate) fn shell(instance: Rc>, token: String) -> UserMeta { + /// need to make a RateLimited request. To use the [`GatewayHandle`], you will have to identify + /// first. + pub(crate) async fn shell(instance: Rc>, token: String) -> UserMeta { let settings = UserSettings::default(); let object = User::default(); + let wss_url = instance.borrow().urls.wss.clone(); + // Dummy gateway object + let gateway = Gateway::new(wss_url).await.unwrap(); UserMeta { - belongs_to: instance.clone(), token, + belongs_to: instance.clone(), limits: instance .borrow() .limits_information @@ -132,6 +140,7 @@ impl UserMeta { .map(|info| info.ratelimits.clone()), settings, object, + gateway, } } } diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 01e9752..154b83c 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -1,14 +1,16 @@ +use chorus_macros::Updateable; use chrono::Utc; use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_string_from_number; use serde_repr::{Deserialize_repr, Serialize_repr}; +use crate::gateway::Updateable; use crate::types::{ entities::{GuildMember, User}, utils::Snowflake, }; -#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Updateable)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Channel { pub application_id: Option, diff --git a/src/types/events/channel.rs b/src/types/events/channel.rs index 488bc47..002230c 100644 --- a/src/types/events/channel.rs +++ b/src/types/events/channel.rs @@ -3,6 +3,8 @@ use crate::types::{entities::Channel, Snowflake}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use super::UpdateMessage; + #[derive(Debug, Default, Deserialize, Serialize)] /// See pub struct ChannelPinsUpdate { @@ -31,6 +33,15 @@ pub struct ChannelUpdate { impl WebSocketEvent for ChannelUpdate {} +impl UpdateMessage for ChannelUpdate { + fn update(&self, object_to_update: &mut Channel) { + *object_to_update = self.channel.clone(); + } + fn id(&self) -> Snowflake { + self.channel.id + } +} + #[derive(Debug, Default, Deserialize, Serialize, Clone)] /// Officially undocumented. /// Sends updates to client about a new message with its id diff --git a/src/types/events/mod.rs b/src/types/events/mod.rs index ed22a70..d477800 100644 --- a/src/types/events/mod.rs +++ b/src/types/events/mod.rs @@ -26,6 +26,10 @@ pub use user::*; pub use voice::*; pub use webhooks::*; +use crate::gateway::Updateable; + +use super::Snowflake; + mod application; mod auto_moderation; mod call; @@ -92,3 +96,23 @@ pub struct GatewayReceivePayload<'a> { } impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {} + +/// An [`UpdateMessage`] represents a received Gateway Message which contains updated +/// information for an [`Updateable`] of Type T. +/// # Example: +/// ```rs +/// impl UpdateMessage for ChannelUpdate { +/// fn update(...) {...} +/// fn id(...) {...} +/// } +/// ``` +/// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information +/// about a [`Channel`]. The update method describes how this new information will be turned into +/// a [`Channel`] object. +pub(crate) trait UpdateMessage: Clone +where + T: Updateable, +{ + fn update(&self, object_to_update: &mut T); + fn id(&self) -> Snowflake; +} diff --git a/tests/channels.rs b/tests/channels.rs index e2838be..c8564d7 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -28,10 +28,11 @@ async fn delete_channel() { #[tokio::test] async fn modify_channel() { + const CHANNEL_NAME: &str = "beepboop"; let mut bundle = common::setup().await; let channel = &mut bundle.channel; let modify_data: types::ChannelModifySchema = types::ChannelModifySchema { - name: Some("beepboop".to_string()), + name: Some(CHANNEL_NAME.to_string()), channel_type: None, topic: None, icon: None, @@ -49,10 +50,10 @@ async fn modify_channel() { default_thread_rate_limit_per_user: None, video_quality_mode: None, }; - Channel::modify(channel, modify_data, channel.id, &mut bundle.user) + let modified_channel = Channel::modify(channel, modify_data, channel.id, &mut bundle.user) .await .unwrap(); - assert_eq!(channel.name, Some("beepboop".to_string())); + assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); let permission_override = PermissionFlags::from_vec(Vec::from([ PermissionFlags::MANAGE_CHANNELS, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index fc4a19c..747a8cd 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,4 @@ +use chorus::gateway::Gateway; use chorus::{ instance::{Instance, UserMeta}, types::{ @@ -18,8 +19,8 @@ pub(crate) struct TestBundle { pub channel: Channel, } +#[allow(unused)] impl TestBundle { - #[allow(unused)] pub(crate) async fn create_user(&mut self, username: &str) -> UserMeta { let register_schema = RegisterSchema { username: username.to_string(), @@ -32,6 +33,16 @@ impl TestBundle { .await .unwrap() } + pub(crate) async fn clone_user_without_gateway(&self) -> UserMeta { + UserMeta { + belongs_to: self.user.belongs_to.clone(), + token: self.user.token.clone(), + limits: self.user.limits.clone(), + settings: self.user.settings.clone(), + object: self.user.object.clone(), + gateway: Gateway::new(self.instance.urls.wss.clone()).await.unwrap(), + } + } } // Set up a test by creating an Instance and a User. Reduces Test boilerplate. diff --git a/tests/gateway.rs b/tests/gateway.rs index c6f46dd..21a2018 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -1,13 +1,15 @@ mod common; + use chorus::gateway::*; -use chorus::types; +use chorus::types::{self, Channel}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; async fn test_gateway_establish() { let bundle = common::setup().await; - Gateway::new(bundle.urls.wss).await.unwrap(); + Gateway::new(bundle.urls.wss.clone()).await.unwrap(); + common::teardown(bundle).await } #[tokio::test] @@ -15,10 +17,30 @@ async fn test_gateway_establish() { async fn test_gateway_authenticate() { let bundle = common::setup().await; - let gateway = Gateway::new(bundle.urls.wss).await.unwrap(); + let gateway = Gateway::new(bundle.urls.wss.clone()).await.unwrap(); let mut identify = types::GatewayIdentifyPayload::common(); - identify.token = bundle.user.token; + identify.token = bundle.user.token.clone(); gateway.send_identify(identify).await; + common::teardown(bundle).await +} + +#[tokio::test] +async fn test_self_updating_structs() { + let mut bundle = common::setup().await; + let channel_updater = bundle.user.gateway.observe(bundle.channel.clone()).await; + let received_channel = channel_updater.borrow().clone(); + assert_eq!(received_channel, bundle.channel); + let channel = &mut bundle.channel; + let modify_data = types::ChannelModifySchema { + name: Some("beepboop".to_string()), + ..Default::default() + }; + Channel::modify(channel, modify_data, channel.id, &mut bundle.user) + .await + .unwrap(); + let received_channel = channel_updater.borrow(); + assert_eq!(received_channel.name.as_ref().unwrap(), "beepboop"); + common::teardown(bundle).await } diff --git a/tests/invites.rs b/tests/invites.rs index b6206b9..d19be61 100644 --- a/tests/invites.rs +++ b/tests/invites.rs @@ -1,14 +1,12 @@ -use chorus::types::CreateChannelInviteSchema; - mod common; - +use chorus::types::CreateChannelInviteSchema; #[tokio::test] async fn create_accept_invite() { let mut bundle = common::setup().await; let channel = bundle.channel.clone(); - let mut user = bundle.user.clone(); - let create_channel_invite_schema = CreateChannelInviteSchema::default(); let mut other_user = bundle.create_user("testuser1312").await; + let user = &mut bundle.user; + let create_channel_invite_schema = CreateChannelInviteSchema::default(); assert!(chorus::types::Guild::get(bundle.guild.id, &mut other_user) .await .is_err());