diff --git a/Cargo.lock b/Cargo.lock index d193cf4..1e8be43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,7 +77,7 @@ checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -220,8 +220,9 @@ dependencies = [ name = "chorus-macros" version = "0.1.0" dependencies = [ + "async-trait", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -361,7 +362,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -372,7 +373,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -569,7 +570,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1186,7 +1187,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1302,7 +1303,7 @@ checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1683,7 +1684,7 @@ checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1705,7 +1706,7 @@ checksum = "1d89a8107374290037607734c0b73a85db7ed80cae314b3c5791f192a496e731" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1745,7 +1746,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -1978,9 +1979,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.26" +version = "2.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45c3457aacde3c65315de5031ec191ce46604304d2446e803d71ade03308d970" +checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" dependencies = [ "proc-macro2", "quote", @@ -2018,7 +2019,7 @@ checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -2102,7 +2103,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -2206,7 +2207,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", ] [[package]] @@ -2384,7 +2385,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", "wasm-bindgen-shared", ] @@ -2418,7 +2419,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.28", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/chorus-macros/Cargo.lock b/chorus-macros/Cargo.lock index 4aa96d3..0a84f90 100644 --- a/chorus-macros/Cargo.lock +++ b/chorus-macros/Cargo.lock @@ -2,10 +2,22 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "async-trait" +version = "0.1.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "chorus-macros" version = "0.1.0" dependencies = [ + "async-trait", "quote", "syn", ] diff --git a/chorus-macros/Cargo.toml b/chorus-macros/Cargo.toml index cffffe1..dc04a10 100644 --- a/chorus-macros/Cargo.toml +++ b/chorus-macros/Cargo.toml @@ -9,3 +9,4 @@ proc-macro = true [dependencies] quote = "1" syn = "2" +async-trait = "0.1.71" diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index c8d5e87..f825568 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, FieldsNamed}; #[proc_macro_derive(Updateable)] pub fn updateable_macro_derive(input: TokenStream) -> TokenStream { @@ -16,3 +17,106 @@ pub fn updateable_macro_derive(input: TokenStream) -> TokenStream { } .into() } + +#[proc_macro_derive(JsonField)] +pub fn jsonfield_macro_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + // No need for macro hygiene, we're only using this in chorus + quote! { + impl JsonField for #name { + fn get_json(&self) -> String { + self.json.clone() + } + fn set_json(&mut self, json: String) { + self.json = json; + } + } + } + .into() +} + +#[proc_macro_attribute] +pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream { + input +} + +#[proc_macro_attribute] +pub fn observe_option_vec(_args: TokenStream, input: TokenStream) -> TokenStream { + input +} + +#[proc_macro_attribute] +pub fn observe(_args: TokenStream, input: TokenStream) -> TokenStream { + input +} + +#[proc_macro_attribute] +pub fn observe_vec(_args: TokenStream, input: TokenStream) -> TokenStream { + input +} + +#[proc_macro_derive( + Composite, + attributes(observe_option_vec, observe_option, observe, observe_vec) +)] +pub fn composite_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let process_field = |field: &Field| { + let field_name = &field.ident; + let attrs = &field.attrs; + + let observe_option = attrs + .iter() + .any(|attr| attr.path().is_ident("observe_option")); + let observe_option_vec = attrs + .iter() + .any(|attr| attr.path().is_ident("observe_option_vec")); + let observe = attrs.iter().any(|attr| attr.path().is_ident("observe")); + let observe_vec = attrs.iter().any(|attr| attr.path().is_ident("observe_vec")); + + match (observe_option, observe_option_vec, observe, observe_vec) { + (true, _, _, _) => quote! { + #field_name: Self::option_observe_fn(self.#field_name, gateway).await + }, + (_, true, _, _) => quote! { + #field_name: Self::option_vec_observe_fn(self.#field_name, gateway).await + }, + (_, _, true, _) => quote! { + #field_name: Self::value_observe_fn(self.#field_name, gateway).await + }, + (_, _, _, true) => quote! { + #field_name: Self::vec_observe_fn(self.#field_name, gateway).await + }, + _ => quote! { + #field_name: self.#field_name + }, + } + }; + + match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(FieldsNamed { named, .. }) => { + let field_exprs = named.iter().map(process_field); + + let ident = &input.ident; + let expanded = quote! { + #[async_trait::async_trait(?Send)] + impl Composite for #ident { + async fn watch_whole(self, gateway: &GatewayHandle) -> Self { + Self { + #(#field_exprs,)* + } + } + } + }; + + TokenStream::from(expanded) + } + _ => panic!("Composite derive macro only supports named fields"), + }, + _ => panic!("Composite derive macro only supports structs"), + } +} diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 447e4c0..3b3ffc6 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -1,6 +1,6 @@ use std::cell::RefCell; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use reqwest::Client; use serde_json::to_string; @@ -46,7 +46,7 @@ impl Instance { login_result.token, self.clone_limits_if_some(), login_result.settings, - Arc::new(Mutex::new(object)), + Arc::new(RwLock::new(object)), gateway, ); Ok(user) diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index ea74f29..7beaa76 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use std::{cell::RefCell, rc::Rc}; use reqwest::Client; @@ -52,8 +52,8 @@ impl Instance { Rc::new(RefCell::new(self.clone())), token.clone(), self.clone_limits_if_some(), - Arc::new(Mutex::new(settings)), - Arc::new(Mutex::new(user_object)), + Arc::new(RwLock::new(settings)), + Arc::new(RwLock::new(user_object)), gateway, ); Ok(user) diff --git a/src/api/channels/channels.rs b/src/api/channels/channels.rs index 0c6f3e3..1e8f93c 100644 --- a/src/api/channels/channels.rs +++ b/src/api/channels/channels.rs @@ -64,9 +64,9 @@ impl Channel { pub async fn modify( &self, modify_data: ChannelModifySchema, - channel_id: Snowflake, user: &mut UserMeta, ) -> ChorusResult { + let channel_id = self.id; let chorus_request = ChorusRequest { request: Client::new() .patch(format!( diff --git a/src/api/guilds/roles.rs b/src/api/guilds/roles.rs index 91b7bd4..271cf90 100644 --- a/src/api/guilds/roles.rs +++ b/src/api/guilds/roles.rs @@ -120,13 +120,13 @@ impl types::RoleObject { .await } - /// Updates a role in a guild. + /// Modifies a role in a guild. /// /// Requires the [`MANAGE_ROLES`](crate::types::PermissionFlags::MANAGE_ROLES) permission. /// /// # Reference /// See - pub async fn update( + pub async fn modify( user: &mut UserMeta, guild_id: Snowflake, role_id: Snowflake, diff --git a/src/api/users/users.rs b/src/api/users/users.rs index fd28b7a..ff0638d 100644 --- a/src/api/users/users.rs +++ b/src/api/users/users.rs @@ -1,4 +1,3 @@ -use std::sync::{Arc, Mutex}; use std::{cell::RefCell, rc::Rc}; use reqwest::Client; @@ -21,8 +20,8 @@ impl UserMeta { /// # Reference /// See and /// - pub async fn get(user: &mut UserMeta, id: Option<&String>) -> ChorusResult { - User::get(user, id).await + pub async fn get_user(&mut self, id: Option<&String>) -> ChorusResult { + User::get(self, id).await } /// Gets the user's settings. @@ -56,12 +55,7 @@ impl UserMeta { request, limit_type: LimitType::default(), }; - let user_updated = chorus_request - .deserialize_response::(self) - .await - .unwrap(); - self.object = Arc::new(Mutex::new(user_updated.clone())); - Ok(user_updated) + chorus_request.deserialize_response::(self).await } /// Deletes the user from the Instance. diff --git a/src/gateway.rs b/src/gateway.rs index eaac29f..9bd4e12 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,14 +1,15 @@ use crate::errors::GatewayError; use crate::gateway::events::Events; -use crate::types::{self, Channel, ChannelUpdate, Snowflake}; -use crate::types::{UpdateMessage, WebSocketEvent}; +use crate::types::{ + self, Channel, ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, + RoleObject, Snowflake, UpdateMessage, WebSocketEvent, +}; use async_trait::async_trait; use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::time::Duration; -use tokio::sync::watch; use tokio::time::sleep_until; use futures_util::stream::SplitSink; @@ -148,6 +149,8 @@ impl GatewayMessage { } } +pub type ObservableObject = dyn Send + Sync + Any; + /// Represents a handle to a Gateway connection. A Gateway connection will create observable /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// implemented types with the trait [`WebSocketEvent`] @@ -167,7 +170,7 @@ pub struct GatewayHandle { pub handle: JoinHandle<()>, /// Tells gateway tasks to close kill_send: tokio::sync::broadcast::Sender<()>, - store: Arc>>>, + pub(crate) store: Arc>>>>, } /// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. @@ -196,27 +199,57 @@ impl GatewayHandle { .unwrap(); } - pub async fn observe(&self, object: T) -> watch::Receiver { + pub async fn observe>( + &self, + object: Arc>, + ) -> Arc> { let mut store = self.store.lock().await; - if let Some(channel) = store.get(&object.id()) { - let (_, rx) = channel - .downcast_ref::<(watch::Sender, watch::Receiver)>() + let id = object.read().unwrap().id(); + if let Some(channel) = store.get(&id) { + let object = channel.clone(); + drop(store); + object + .read() + .unwrap() + .downcast_ref::() .unwrap_or_else(|| { panic!( "Snowflake {} already exists in the store, but it is not of type T.", - object.id() + id ) }); - rx.clone() + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; + let object = downcasted.read().unwrap().clone(); + + let watched_object = object.watch_whole(self).await; + *downcasted.write().unwrap() = watched_object; + downcasted } else { - let id = object.id(); - let channel = watch::channel(object); - let receiver = channel.1.clone(); - store.insert(id, Box::new(channel)); - receiver + let id = object.read().unwrap().id(); + let object = object.read().unwrap().clone(); + let object = object.clone().watch_whole(self).await; + let wrapped = Arc::new(RwLock::new(object)); + store.insert(id, wrapped.clone()); + wrapped } } + /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` + /// with all of its observable fields being observed. + pub async fn observe_and_into_inner>( + &self, + object: Arc>, + ) -> T { + let channel = self.observe(object.clone()).await; + let object = channel.read().unwrap().clone(); + object + } + /// Sends an identify event to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); @@ -306,7 +339,7 @@ pub struct Gateway { >, websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, - store: Arc>>>, + store: Arc>>>>, } impl Gateway { @@ -472,14 +505,26 @@ impl Gateway { match event_name.as_str() { $($name => { let event = &mut self.events.lock().await.$($path).+; - match serde_json::from_str(gateway_payload.event_data.unwrap().get()) { + let json = gateway_payload.event_data.unwrap().get(); + match serde_json::from_str(json) { Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), Ok(message) => { $( - let message: $message_type = message; - if let Some(to_update) = self.store.lock().await.get(&message.id()) { - if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<$update_type>, watch::Receiver<$update_type>)>() { - tx.send_modify(|object| message.update(object)); + let mut message: $message_type = message; + let store = self.store.lock().await; + if let Some(to_update) = store.get(&message.id()) { + let object = to_update.clone(); + let inner_object = object.read().unwrap(); + if let Some(_) = inner_object.downcast_ref::<$update_type>() { + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; + drop(inner_object); + message.set_json(json.to_string()); + message.update(downcasted.clone()); } else { warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id()) } @@ -553,8 +598,8 @@ impl Gateway { "GUILD_MEMBER_REMOVE" => guild.member_remove, "GUILD_MEMBER_UPDATE" => guild.member_update, "GUILD_MEMBERS_CHUNK" => guild.members_chunk, - "GUILD_ROLE_CREATE" => guild.role_create, - "GUILD_ROLE_UPDATE" => guild.role_update, + "GUILD_ROLE_CREATE" => guild.role_create GuildRoleCreate: Guild, + "GUILD_ROLE_UPDATE" => guild.role_update GuildRoleUpdate: RoleObject, "GUILD_ROLE_DELETE" => guild.role_delete, "GUILD_SCHEDULED_EVENT_CREATE" => guild.role_scheduled_event_create, "GUILD_SCHEDULED_EVENT_UPDATE" => guild.role_scheduled_event_update, diff --git a/src/instance.rs b/src/instance.rs index 5c2e98b..4ed0f60 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::fmt; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -91,8 +91,8 @@ pub struct UserMeta { pub belongs_to: Rc>, pub token: String, pub limits: Option>, - pub settings: Arc>, - pub object: Arc>, + pub settings: Arc>, + pub object: Arc>, pub gateway: GatewayHandle, } @@ -114,8 +114,8 @@ impl UserMeta { belongs_to: Rc>, token: String, limits: Option>, - settings: Arc>, - object: Arc>, + settings: Arc>, + object: Arc>, gateway: GatewayHandle, ) -> UserMeta { UserMeta { @@ -134,8 +134,8 @@ impl UserMeta { /// need to make a RateLimited request. To use the [`GatewayHandle`], you will have to identify /// first. pub(crate) async fn shell(instance: Rc>, token: String) -> UserMeta { - let settings = Arc::new(Mutex::new(UserSettings::default())); - let object = Arc::new(Mutex::new(User::default())); + let settings = Arc::new(RwLock::new(UserSettings::default())); + let object = Arc::new(RwLock::new(User::default())); let wss_url = instance.borrow().urls.wss.clone(); // Dummy gateway object let gateway = Gateway::new(wss_url).await.unwrap(); diff --git a/src/types/entities/application.rs b/src/types/entities/application.rs index ad48ab4..0b55626 100644 --- a/src/types/entities/application.rs +++ b/src/types/entities/application.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use bitflags::bitflags; use serde::{Deserialize, Serialize}; @@ -27,7 +27,7 @@ pub struct Application { pub bot_require_code_grant: bool, pub verify_key: String, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub owner: Arc>, + pub owner: Arc>, pub flags: u64, #[cfg(feature = "sqlx")] pub redirect_uris: Option>>, @@ -49,7 +49,7 @@ pub struct Application { #[cfg(feature = "sqlx")] pub install_params: Option>, #[cfg(not(feature = "sqlx"))] - pub install_params: Option>>, + pub install_params: Option>>, pub terms_of_service_url: Option, pub privacy_policy_url: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] @@ -142,7 +142,7 @@ pub struct ApplicationCommand { pub application_id: Snowflake, pub name: String, pub description: String, - pub options: Vec>>, + pub options: Vec>>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -154,7 +154,7 @@ pub struct ApplicationCommandOption { pub description: String, pub required: bool, pub choices: Vec, - pub options: Arc>>, + pub options: Arc>>, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -190,14 +190,14 @@ pub enum ApplicationCommandOptionType { pub struct ApplicationCommandInteractionData { pub id: Snowflake, pub name: String, - pub options: Vec>>, + pub options: Vec>>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ApplicationCommandInteractionDataOption { pub name: String, pub value: Value, - pub options: Vec>>, + pub options: Vec>>, } #[derive(Debug, Default, Clone, Serialize, Deserialize)] @@ -206,7 +206,7 @@ pub struct GuildApplicationCommandPermissions { pub id: Snowflake, pub application_id: Snowflake, pub guild_id: Snowflake, - pub permissions: Vec>>, + pub permissions: Vec>>, } #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] diff --git a/src/types/entities/audit_log.rs b/src/types/entities/audit_log.rs index e0d05e3..be14f0f 100644 --- a/src/types/entities/audit_log.rs +++ b/src/types/entities/audit_log.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use serde::{Deserialize, Serialize}; @@ -8,7 +8,7 @@ use crate::types::utils::Snowflake; /// See pub struct AuditLogEntry { pub target_id: Option, - pub changes: Option>>>, + pub changes: Option>>>, pub user_id: Option, pub id: Snowflake, // to:do implement an enum for these types diff --git a/src/types/entities/auto_moderation.rs b/src/types/entities/auto_moderation.rs index afb1ec6..77f4fa2 100644 --- a/src/types/entities/auto_moderation.rs +++ b/src/types/entities/auto_moderation.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; @@ -14,8 +14,8 @@ pub struct AutoModerationRule { pub creator_id: Snowflake, pub event_type: AutoModerationRuleEventType, pub trigger_type: AutoModerationRuleTriggerType, - pub trigger_metadata: Arc>, - pub actions: Vec>>, + pub trigger_metadata: Arc>, + pub actions: Vec>>, pub enabled: bool, pub exempt_roles: Vec, pub exempt_channels: Vec, @@ -92,7 +92,7 @@ pub enum AutoModerationRuleKeywordPresetType { pub struct AutoModerationAction { #[serde(rename = "type")] pub action_type: AutoModerationActionType, - pub metadata: Option>>, + pub metadata: Option>>, } #[derive(Serialize_repr, Deserialize_repr, Debug, Clone, Default)] diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 61754f0..3866e01 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -1,18 +1,20 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; -use chorus_macros::Updateable; +use chorus_macros::{observe_option_vec, Composite, Updateable}; use chrono::Utc; use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_string_from_number; use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::fmt::Debug; -use crate::gateway::Updateable; +use crate::gateway::{GatewayHandle, Updateable}; use crate::types::{ entities::{GuildMember, User}, utils::Snowflake, + Composite, }; -#[derive(Default, Debug, Serialize, Deserialize, Clone, Updateable)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] /// Represents a guild of private channel /// @@ -48,7 +50,7 @@ pub struct Channel { pub last_pin_timestamp: Option, pub managed: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub member: Option>>, + pub member: Option, pub member_count: Option, pub message_count: Option, pub name: Option, @@ -58,12 +60,14 @@ pub struct Channel { #[cfg(feature = "sqlx")] pub permission_overwrites: Option>>, #[cfg(not(feature = "sqlx"))] - pub permission_overwrites: Option>>>, + #[observe_option_vec] + pub permission_overwrites: Option>>>, pub permissions: Option, pub position: Option, pub rate_limit_per_user: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub recipients: Option>>>, + #[observe_option_vec] + pub recipients: Option>>>, pub rtc_region: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread_metadata: Option, @@ -122,7 +126,9 @@ pub struct Tag { pub emoji_name: Option, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd)] +#[derive( + Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Updateable, Composite, +)] pub struct PermissionOverwrite { pub id: Snowflake, #[serde(rename = "type")] @@ -156,7 +162,7 @@ pub struct ThreadMember { pub user_id: Option, pub join_timestamp: Option, pub flags: Option, - pub member: Option>>, + pub member: Option>>, } #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 38d7da7..68700e6 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -1,23 +1,26 @@ -use std::sync::{Arc, Mutex}; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; +use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; +use crate::gateway::{GatewayHandle, Updateable}; use crate::types::entities::User; -use crate::types::Snowflake; +use crate::types::{Composite, Snowflake}; -#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] /// # Reference /// See pub struct Emoji { - pub id: Option, + pub id: Snowflake, pub name: Option, #[cfg(feature = "sqlx")] pub roles: Option>>, #[cfg(not(feature = "sqlx"))] pub roles: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub user: Option>>, + pub user: Option>>, pub require_colons: Option, pub managed: Option, pub animated: Option, diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index 651884f..3f59055 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -1,18 +1,22 @@ -use std::sync::{Arc, Mutex}; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; +use chorus_macros::{observe_option_vec, observe_vec, Composite, Updateable}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; +use crate::gateway::{GatewayHandle, Updateable}; use crate::types::types::guild_configuration::GuildFeaturesList; use crate::types::{ entities::{Channel, Emoji, RoleObject, Sticker, User, VoiceState, Webhook}, interfaces::WelcomeScreenObject, utils::Snowflake, + Composite, }; /// See -#[derive(Serialize, Deserialize, Debug, Default, Clone)] +#[derive(Serialize, Deserialize, Debug, Default, Clone, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Guild { pub afk_channel_id: Option, @@ -27,13 +31,15 @@ pub struct Guild { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub bans: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub channels: Option>>>, + #[observe_option_vec] + pub channels: Option>>>, pub default_message_notifications: Option, pub description: Option, pub discovery_splash: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] #[serde(default)] - pub emojis: Vec>>, + #[observe_vec] + pub emojis: Vec>>, pub explicit_content_filter: Option, //#[cfg_attr(feature = "sqlx", sqlx(try_from = "String"))] pub features: Option, @@ -42,7 +48,7 @@ pub struct Guild { pub icon_hash: Option, pub id: Snowflake, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub invites: Option>>>, + pub invites: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub joined_at: Option, pub large: Option, @@ -68,7 +74,8 @@ pub struct Guild { pub public_updates_channel_id: Option, pub region: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub roles: Option>>>, + #[observe_option_vec] + pub roles: Option>>>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub rules_channel: Option, pub rules_channel_id: Option, @@ -81,13 +88,15 @@ pub struct Guild { pub vanity_url_code: Option, pub verification_level: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub voice_states: Option>>>, + #[observe_option_vec] + pub voice_states: Option>>>, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub webhooks: Option>>>, + #[observe_option_vec] + pub webhooks: Option>>>, #[cfg(feature = "sqlx")] pub welcome_screen: Option>, #[cfg(not(feature = "sqlx"))] - pub welcome_screen: Option>>, + pub welcome_screen: Option, pub widget_channel_id: Option, pub widget_enabled: Option, } @@ -113,11 +122,11 @@ pub struct GuildInvite { pub created_at: DateTime, pub expires_at: Option>, pub guild_id: Snowflake, - pub guild: Option>>, + pub guild: Option>>, pub channel_id: Snowflake, - pub channel: Option>>, + pub channel: Option>>, pub inviter_id: Option, - pub inviter: Option>>, + pub inviter: Option>>, pub target_user_id: Option, pub target_user: Option, pub target_user_type: Option, @@ -151,7 +160,7 @@ pub struct GuildScheduledEvent { pub entity_type: GuildScheduledEventEntityType, pub entity_id: Option, pub entity_metadata: Option, - pub creator: Option>>, + pub creator: Option>>, pub user_count: Option, pub image: Option, } diff --git a/src/types/entities/guild_member.rs b/src/types/entities/guild_member.rs index af63cfb..bf2f93b 100644 --- a/src/types/entities/guild_member.rs +++ b/src/types/entities/guild_member.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use serde::{Deserialize, Serialize}; @@ -10,7 +10,7 @@ use crate::types::{entities::PublicUser, Snowflake}; /// # Reference /// See pub struct GuildMember { - pub user: Option>>, + pub user: Option>>, pub nick: Option, pub avatar: Option, pub roles: Vec, diff --git a/src/types/entities/integration.rs b/src/types/entities/integration.rs index a95c33c..0913213 100644 --- a/src/types/entities/integration.rs +++ b/src/types/entities/integration.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -23,14 +23,14 @@ pub struct Integration { pub expire_behaviour: Option, pub expire_grace_period: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub user: Option>>, + pub user: Option>>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub account: IntegrationAccount, pub synced_at: Option>, pub subscriber_count: Option, pub revoked: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub application: Option>>, + pub application: Option>>, pub scopes: Option>, } diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index 7eefd98..1e6befc 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -70,7 +70,7 @@ pub enum NSFWLevel { /// See #[derive(Debug, Serialize, Deserialize)] pub struct InviteStageInstance { - pub members: Vec>>, + pub members: Vec>>, pub participant_count: i32, pub speaker_count: i32, pub topic: String, diff --git a/src/types/entities/mod.rs b/src/types/entities/mod.rs index 6f1f58b..abdf976 100644 --- a/src/types/entities/mod.rs +++ b/src/types/entities/mod.rs @@ -22,6 +22,11 @@ pub use user_settings::*; pub use voice_state::*; pub use webhook::*; +use crate::gateway::{GatewayHandle, Updateable}; +use async_trait::async_trait; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + mod application; mod attachment; mod audit_log; @@ -45,3 +50,62 @@ mod user; mod user_settings; mod voice_state; mod webhook; + +#[async_trait(?Send)] +pub trait Composite { + async fn watch_whole(self, gateway: &GatewayHandle) -> Self; + + async fn option_observe_fn( + value: Option>>, + gateway: &GatewayHandle, + ) -> Option>> + where + T: Composite + Debug, + { + if let Some(value) = value { + let value = value.clone(); + Some(gateway.observe(value).await) + } else { + None + } + } + + async fn option_vec_observe_fn( + value: Option>>>, + gateway: &GatewayHandle, + ) -> Option>>> + where + T: Composite, + { + if let Some(value) = value { + let mut vec = Vec::new(); + for component in value.into_iter() { + vec.push(gateway.observe(component).await); + } + Some(vec) + } else { + None + } + } + + async fn value_observe_fn(value: Arc>, gateway: &GatewayHandle) -> Arc> + where + T: Composite, + { + gateway.observe(value).await + } + + async fn vec_observe_fn( + value: Vec>>, + gateway: &GatewayHandle, + ) -> Vec>> + where + T: Composite, + { + let mut vec = Vec::new(); + for component in value.into_iter() { + vec.push(gateway.observe(component).await); + } + vec + } +} diff --git a/src/types/entities/relationship.rs b/src/types/entities/relationship.rs index a5f5bcb..576d99a 100644 --- a/src/types/entities/relationship.rs +++ b/src/types/entities/relationship.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -15,7 +15,7 @@ pub struct Relationship { #[serde(rename = "type")] pub relationship_type: RelationshipType, pub nickname: Option, - pub user: Arc>, + pub user: Arc>, pub since: Option>, } diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 3ad53ce..22dceff 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -1,10 +1,13 @@ use bitflags::bitflags; +use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::{deserialize_option_number_from_string, deserialize_string_from_number}; +use std::fmt::Debug; -use crate::types::utils::Snowflake; +use crate::gateway::{GatewayHandle, Updateable}; +use crate::types::{utils::Snowflake, Composite}; -#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] /// See pub struct RoleObject { diff --git a/src/types/entities/sticker.rs b/src/types/entities/sticker.rs index 42edb7f..5413112 100644 --- a/src/types/entities/sticker.rs +++ b/src/types/entities/sticker.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use serde::{Deserialize, Serialize}; @@ -24,7 +24,7 @@ pub struct Sticker { pub available: Option, pub guild_id: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub user: Option>>, + pub user: Option>>, pub sort_value: Option, } diff --git a/src/types/entities/team.rs b/src/types/entities/team.rs index fe6562b..8e32f55 100644 --- a/src/types/entities/team.rs +++ b/src/types/entities/team.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use serde::{Deserialize, Serialize}; @@ -21,5 +21,5 @@ pub struct TeamMember { pub membership_state: u8, pub permissions: Vec, pub team_id: Snowflake, - pub user: Arc>, + pub user: Arc>, } diff --git a/src/types/entities/template.rs b/src/types/entities/template.rs index 0550e9e..1305a98 100644 --- a/src/types/entities/template.rs +++ b/src/types/entities/template.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -18,13 +18,13 @@ pub struct GuildTemplate { pub usage_count: Option, pub creator_id: Snowflake, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub creator: Arc>, + pub creator: Arc>, pub created_at: DateTime, pub updated_at: DateTime, pub source_guild_id: Snowflake, #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub source_guild: Vec>>, + pub source_guild: Vec>>, // Unsure how a {recursive: Guild} looks like, might be a Vec? #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub serialized_source_guild: Vec>>, + pub serialized_source_guild: Vec>>, } diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 5b31b6d..aadfd35 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -1,7 +1,11 @@ -use crate::types::utils::Snowflake; +use chorus_macros::{Composite, Updateable}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_option_number_from_string; +use std::fmt::Debug; + +use crate::gateway::{GatewayHandle, Updateable}; +use crate::types::{utils::Snowflake, Composite}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] @@ -15,7 +19,7 @@ impl User { PublicUser::from(self) } } -#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Eq, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct User { pub id: Snowflake, diff --git a/src/types/entities/user_settings.rs b/src/types/entities/user_settings.rs index 4705a92..1be2c0a 100644 --- a/src/types/entities/user_settings.rs +++ b/src/types/entities/user_settings.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use chrono::{serde::ts_milliseconds_option, Utc}; use serde::{Deserialize, Serialize}; @@ -75,7 +75,7 @@ pub struct UserSettings { #[cfg(not(feature = "sqlx"))] pub restricted_guilds: Vec, pub show_current_game: bool, - pub status: Arc>, + pub status: Arc>, pub stream_notifications_enabled: bool, pub theme: UserTheme, pub timezone_offset: i16, @@ -111,7 +111,7 @@ impl Default for UserSettings { render_reactions: true, restricted_guilds: Default::default(), show_current_game: true, - status: Arc::new(Mutex::new(UserStatus::Online)), + status: Arc::new(RwLock::new(UserStatus::Online)), stream_notifications_enabled: false, theme: UserTheme::Dark, timezone_offset: 0, @@ -151,5 +151,5 @@ pub struct GuildFolder { #[derive(Debug, Serialize, Deserialize)] pub struct LoginResult { pub token: String, - pub settings: Arc>, + pub settings: Arc>, } diff --git a/src/types/entities/voice_state.rs b/src/types/entities/voice_state.rs index 29b8e6e..c879c8e 100644 --- a/src/types/entities/voice_state.rs +++ b/src/types/entities/voice_state.rs @@ -1,22 +1,26 @@ -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; +use chorus_macros::{Composite, Updateable}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use crate::gateway::{GatewayHandle, Updateable}; use crate::types::{ entities::{Guild, GuildMember}, utils::Snowflake, + Composite, }; /// See -#[derive(Serialize, Deserialize, Debug, Default, Clone)] +#[derive(Serialize, Deserialize, Debug, Default, Clone, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct VoiceState { pub guild_id: Option, pub guild: Option, pub channel_id: Option, pub user_id: Snowflake, - pub member: Option>>, + pub member: Option>>, pub session_id: Snowflake, pub token: Option, pub deaf: bool, @@ -27,5 +31,5 @@ pub struct VoiceState { pub self_video: bool, pub suppress: bool, pub request_to_speak_timestamp: Option>, - pub id: Option, + pub id: Snowflake, } diff --git a/src/types/entities/webhook.rs b/src/types/entities/webhook.rs index f8b6530..7771dbf 100644 --- a/src/types/entities/webhook.rs +++ b/src/types/entities/webhook.rs @@ -1,14 +1,18 @@ -use std::sync::{Arc, Mutex}; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; +use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; +use crate::gateway::{GatewayHandle, Updateable}; use crate::types::{ entities::{Guild, User}, utils::Snowflake, + Composite, }; /// See -#[derive(Serialize, Deserialize, Debug, Default, Clone)] +#[derive(Serialize, Deserialize, Debug, Default, Clone, Updateable, Composite)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Webhook { pub id: Snowflake, @@ -22,10 +26,10 @@ pub struct Webhook { pub application_id: Snowflake, #[serde(skip_serializing_if = "Option::is_none")] #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub user: Option>>, + pub user: Option>>, #[serde(skip_serializing_if = "Option::is_none")] #[cfg_attr(feature = "sqlx", sqlx(skip))] - pub source_guild: Option>>, + pub source_guild: Option>>, #[serde(skip_serializing_if = "Option::is_none")] pub url: Option, } diff --git a/src/types/events/channel.rs b/src/types/events/channel.rs index 002230c..017c50e 100644 --- a/src/types/events/channel.rs +++ b/src/types/events/channel.rs @@ -1,5 +1,8 @@ +use std::sync::{Arc, RwLock}; + use crate::types::events::WebSocketEvent; -use crate::types::{entities::Channel, Snowflake}; +use crate::types::{entities::Channel, JsonField, Snowflake}; +use chorus_macros::JsonField; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -24,18 +27,21 @@ pub struct ChannelCreate { impl WebSocketEvent for ChannelCreate {} -#[derive(Debug, Default, Deserialize, Serialize, Clone)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)] /// See pub struct ChannelUpdate { #[serde(flatten)] pub channel: Channel, + #[serde(skip)] + pub json: String, } impl WebSocketEvent for ChannelUpdate {} impl UpdateMessage for ChannelUpdate { - fn update(&self, object_to_update: &mut Channel) { - *object_to_update = self.channel.clone(); + fn update(&mut self, object_to_update: Arc>) { + let mut write = object_to_update.write().unwrap(); + *write = self.channel.clone(); } fn id(&self) -> Snowflake { self.channel.id diff --git a/src/types/events/guild.rs b/src/types/events/guild.rs index 2cbdbd0..5fac86b 100644 --- a/src/types/events/guild.rs +++ b/src/types/events/guild.rs @@ -1,13 +1,17 @@ +use std::sync::{Arc, RwLock}; + +use chorus_macros::JsonField; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::types::entities::{Guild, PublicUser, UnavailableGuild}; use crate::types::events::WebSocketEvent; use crate::types::{ - AuditLogEntry, Emoji, GuildMember, GuildScheduledEvent, RoleObject, Snowflake, Sticker, + AuditLogEntry, Emoji, GuildMember, GuildScheduledEvent, JsonField, RoleObject, Snowflake, + Sticker, }; -use super::PresenceUpdate; +use super::{PresenceUpdate, UpdateMessage}; #[derive(Debug, Deserialize, Serialize, Default, Clone)] /// See ; @@ -164,24 +168,60 @@ pub struct GuildMembersChunk { impl WebSocketEvent for GuildMembersChunk {} -#[derive(Debug, Default, Deserialize, Serialize, Clone)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)] /// See pub struct GuildRoleCreate { pub guild_id: Snowflake, pub role: RoleObject, + #[serde(skip)] + pub json: String, } impl WebSocketEvent for GuildRoleCreate {} -#[derive(Debug, Default, Deserialize, Serialize, Clone)] +impl UpdateMessage for GuildRoleCreate { + fn id(&self) -> Snowflake { + self.guild_id + } + + fn update(&mut self, object_to_update: Arc>) { + let mut object_to_update = object_to_update.write().unwrap(); + if object_to_update.roles.is_some() { + object_to_update + .roles + .as_mut() + .unwrap() + .push(Arc::new(RwLock::new(self.role.clone()))); + } else { + object_to_update.roles = Some(Vec::from([Arc::new(RwLock::new(self.role.clone()))])); + } + } +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)] /// See pub struct GuildRoleUpdate { pub guild_id: Snowflake, pub role: RoleObject, + #[serde(skip)] + pub json: String, } impl WebSocketEvent for GuildRoleUpdate {} +impl UpdateMessage for GuildRoleUpdate { + fn id(&self) -> Snowflake { + self.role.id + } + + fn update(&mut self, object_to_update: Arc>) { + println!("Processing Role Update. Name: {}", self.role.name); + let mut write = object_to_update.write().unwrap(); + *write = self.role.clone(); + println!("Updated role: Name: {}", write.name); + } +} + #[derive(Debug, Default, Deserialize, Serialize, Clone)] /// See pub struct GuildRoleDelete { diff --git a/src/types/events/mod.rs b/src/types/events/mod.rs index d477800..0a296fb 100644 --- a/src/types/events/mod.rs +++ b/src/types/events/mod.rs @@ -1,3 +1,8 @@ +use std::sync::{Arc, RwLock}; + +use std::collections::HashMap; + +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; pub use application::*; @@ -19,6 +24,7 @@ pub use ready::*; pub use relationship::*; pub use request_members::*; pub use resume::*; +use serde_json::{from_str, from_value, to_value, Value}; pub use session::*; pub use stage_instance::*; pub use thread::*; @@ -109,10 +115,31 @@ impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {} /// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information /// about a [`Channel`]. The update method describes how this new information will be turned into /// a [`Channel`] object. -pub(crate) trait UpdateMessage: Clone +pub(crate) trait UpdateMessage: Clone + JsonField where - T: Updateable, + T: Updateable + Serialize + DeserializeOwned + Clone, { - fn update(&self, object_to_update: &mut T); + fn update(&mut self, object_to_update: Arc>) { + update_object(self.get_json(), object_to_update) + } fn id(&self) -> Snowflake; } + +pub(crate) trait JsonField: Clone { + fn set_json(&mut self, json: String); + fn get_json(&self) -> String; +} + +/// Only applicable for events where the Update struct is the same as the Entity struct +pub(crate) fn update_object( + value: String, + object: Arc>, +) { + let data_from_event: HashMap = from_str(&value).unwrap(); + let mut original_data: HashMap = + from_value(to_value(object.clone()).unwrap()).unwrap(); + for (updated_entry_key, updated_entry_value) in data_from_event.into_iter() { + original_data.insert(updated_entry_key.clone(), updated_entry_value); + } + *object.write().unwrap() = from_value(to_value(original_data).unwrap()).unwrap(); +} diff --git a/src/types/schema/role.rs b/src/types/schema/role.rs index 4ae300d..284f506 100644 --- a/src/types/schema/role.rs +++ b/src/types/schema/role.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(rename_all = "snake_case")] /// Represents the schema which needs to be sent to create or modify a Role. /// See: [https://docs.spacebar.chat/routes/#cmp--schemas-rolemodifyschema](https://docs.spacebar.chat/routes/#cmp--schemas-rolemodifyschema) diff --git a/tests/channels.rs b/tests/channels.rs index d810c10..94129f6 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -8,7 +8,7 @@ mod common; #[tokio::test] async fn get_channel() { let mut bundle = common::setup().await; - let bundle_channel = bundle.channel.clone(); + let bundle_channel = bundle.channel.read().unwrap().clone(); let bundle_user = &mut bundle.user; assert_eq!( @@ -21,7 +21,8 @@ async fn get_channel() { #[tokio::test] async fn delete_channel() { let mut bundle = common::setup().await; - let result = Channel::delete(bundle.channel.clone(), &mut bundle.user).await; + let channel_guard = bundle.channel.write().unwrap().clone(); + let result = Channel::delete(channel_guard, &mut bundle.user).await; assert!(result.is_ok()); common::teardown(bundle).await } @@ -30,7 +31,7 @@ async fn delete_channel() { async fn modify_channel() { const CHANNEL_NAME: &str = "beepboop"; let mut bundle = common::setup().await; - let channel = &mut bundle.channel; + let channel = &mut bundle.channel.read().unwrap().clone(); let modify_data: types::ChannelModifySchema = types::ChannelModifySchema { name: Some(CHANNEL_NAME.to_string()), channel_type: None, @@ -50,32 +51,26 @@ async fn modify_channel() { default_thread_rate_limit_per_user: None, video_quality_mode: None, }; - let modified_channel = Channel::modify(channel, modify_data, channel.id, &mut bundle.user) - .await - .unwrap(); + let modified_channel = channel.modify(modify_data, &mut bundle.user).await.unwrap(); assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); let permission_override = PermissionFlags::from_vec(Vec::from([ PermissionFlags::MANAGE_CHANNELS, PermissionFlags::MANAGE_MESSAGES, ])); - let user_id: types::Snowflake = bundle.user.object.lock().unwrap().id; + let user_id: types::Snowflake = bundle.user.object.read().unwrap().id; let permission_override = PermissionOverwrite { id: user_id, overwrite_type: "1".to_string(), allow: permission_override, deny: "0".to_string(), }; + let channel_id: Snowflake = bundle.channel.read().unwrap().id; + Channel::edit_permissions(&mut bundle.user, channel_id, permission_override.clone()) + .await + .unwrap(); - Channel::edit_permissions( - &mut bundle.user, - bundle.channel.id, - permission_override.clone(), - ) - .await - .unwrap(); - - Channel::delete_permission(&mut bundle.user, bundle.channel.id, permission_override.id) + Channel::delete_permission(&mut bundle.user, channel_id, permission_override.id) .await .unwrap(); @@ -85,7 +80,7 @@ async fn modify_channel() { #[tokio::test] async fn get_channel_messages() { let mut bundle = common::setup().await; - + let channel_id: Snowflake = bundle.channel.read().unwrap().id; // First create some messages to read for _ in 0..10 { let _ = bundle @@ -95,7 +90,7 @@ async fn get_channel_messages() { content: Some("A Message!".to_string()), ..Default::default() }, - bundle.channel.id, + channel_id, ) .await .unwrap(); @@ -104,7 +99,7 @@ async fn get_channel_messages() { assert_eq!( Channel::messages( GetChannelMessagesSchema::before(Snowflake::generate()), - bundle.channel.id, + channel_id, &mut bundle.user, ) .await @@ -128,7 +123,7 @@ async fn get_channel_messages() { assert!(Channel::messages( GetChannelMessagesSchema::after(Snowflake::generate()), - bundle.channel.id, + channel_id, &mut bundle.user, ) .await @@ -144,7 +139,7 @@ async fn create_dm() { let other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; let private_channel_create_schema = PrivateChannelCreateSchema { - recipients: Some(Vec::from([other_user.object.lock().unwrap().id])), + recipients: Some(Vec::from([other_user.object.read().unwrap().id])), access_tokens: None, nicks: None, }; @@ -160,11 +155,11 @@ async fn create_dm() { .unwrap() .get(0) .unwrap() - .lock() + .read() .unwrap() .id .clone(), - other_user.object.lock().unwrap().id + other_user.object.read().unwrap().id ); assert_eq!( dm_channel @@ -173,11 +168,11 @@ async fn create_dm() { .unwrap() .get(1) .unwrap() - .lock() + .read() .unwrap() .id .clone(), - user.object.lock().unwrap().id.clone() + user.object.read().unwrap().id.clone() ); common::teardown(bundle).await; } @@ -189,9 +184,9 @@ async fn remove_add_person_from_to_dm() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let mut third_user = bundle.create_user("integrationtestuser3").await; - let third_user_id = third_user.object.lock().unwrap().id; - let other_user_id = other_user.object.lock().unwrap().id; - let user_id = bundle.user.object.lock().unwrap().id; + let third_user_id = third_user.object.read().unwrap().id; + let other_user_id = other_user.object.read().unwrap().id; + let user_id = bundle.user.object.read().unwrap().id; let user = &mut bundle.user; let private_channel_create_schema = PrivateChannelCreateSchema { recipients: Some(Vec::from([other_user_id, third_user_id])), @@ -234,7 +229,7 @@ async fn remove_add_person_from_to_dm() { .unwrap() .get(0) .unwrap() - .lock() + .read() .unwrap() .id, other_user_id @@ -246,7 +241,7 @@ async fn remove_add_person_from_to_dm() { .unwrap() .get(1) .unwrap() - .lock() + .read() .unwrap() .id, user_id diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 747a8cd..029df2d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, RwLock}; + use chorus::gateway::Gateway; use chorus::{ instance::{Instance, UserMeta}, @@ -14,9 +16,9 @@ pub(crate) struct TestBundle { pub urls: UrlBundle, pub user: UserMeta, pub instance: Instance, - pub guild: Guild, - pub role: RoleObject, - pub channel: Channel, + pub guild: Arc>, + pub role: Arc>, + pub channel: Arc>, } #[allow(unused)] @@ -113,17 +115,16 @@ pub(crate) async fn setup() -> TestBundle { urls, user, instance, - guild, - role, - channel, + guild: Arc::new(RwLock::new(guild)), + role: Arc::new(RwLock::new(role)), + channel: Arc::new(RwLock::new(channel)), } } // Teardown method to clean up after a test. #[allow(dead_code)] pub(crate) async fn teardown(mut bundle: TestBundle) { - Guild::delete(&mut bundle.user, bundle.guild.id) - .await - .unwrap(); + let id = bundle.guild.read().unwrap().id; + Guild::delete(&mut bundle.user, id).await.unwrap(); bundle.user.delete().await.unwrap() } diff --git a/tests/gateway.rs b/tests/gateway.rs index 21a2018..be1cb10 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -1,7 +1,9 @@ mod common; +use std::sync::{Arc, RwLock}; + use chorus::gateway::*; -use chorus::types::{self, Channel}; +use chorus::types::{self, ChannelModifySchema, RoleCreateModifySchema, RoleObject}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; @@ -29,18 +31,88 @@ async fn test_gateway_authenticate() { #[tokio::test] async fn test_self_updating_structs() { let mut bundle = common::setup().await; - let channel_updater = bundle.user.gateway.observe(bundle.channel.clone()).await; - let received_channel = channel_updater.borrow().clone(); - assert_eq!(received_channel, bundle.channel); - let channel = &mut bundle.channel; - let modify_data = types::ChannelModifySchema { - name: Some("beepboop".to_string()), + let received_channel = bundle + .user + .gateway + .observe_and_into_inner(bundle.channel.clone()) + .await; + + assert_eq!(received_channel, bundle.channel.read().unwrap().clone()); + + let modify_schema = ChannelModifySchema { + name: Some("selfupdating".to_string()), ..Default::default() }; - Channel::modify(channel, modify_data, channel.id, &mut bundle.user) + received_channel + .modify(modify_schema, &mut bundle.user) .await .unwrap(); - let received_channel = channel_updater.borrow(); - assert_eq!(received_channel.name.as_ref().unwrap(), "beepboop"); + assert_eq!( + bundle + .user + .gateway + .observe_and_into_inner(bundle.channel.clone()) + .await + .name + .unwrap(), + "selfupdating".to_string() + ); + common::teardown(bundle).await } + +#[tokio::test] +async fn test_recursive_self_updating_structs() { + // Setup + let mut bundle = common::setup().await; + let guild = bundle.guild.clone(); + // Observe Guild, make sure it has no channels + let guild = bundle.user.gateway.observe(guild.clone()).await; + let inner_guild = guild.read().unwrap().clone(); + assert!(inner_guild.roles.is_none()); + // Create Role + let permissions = types::PermissionFlags::CONNECT | types::PermissionFlags::MANAGE_EVENTS; + let permissions = Some(permissions.to_string()); + let mut role_create_schema: types::RoleCreateModifySchema = RoleCreateModifySchema { + name: Some("cool person".to_string()), + permissions, + hoist: Some(true), + icon: None, + unicode_emoji: Some("".to_string()), + mentionable: Some(true), + position: None, + color: None, + }; + let guild_id = inner_guild.id; + let role = RoleObject::create(&mut bundle.user, guild_id, role_create_schema.clone()) + .await + .unwrap(); + // Watch role; + bundle + .user + .gateway + .observe(Arc::new(RwLock::new(role.clone()))) + .await; + // Update Guild and check for Guild + let inner_guild = guild.read().unwrap().clone(); + assert!(inner_guild.roles.is_some()); + // Update the Role + role_create_schema.name = Some("yippieee".to_string()); + RoleObject::modify(&mut bundle.user, guild_id, role.id, role_create_schema) + .await + .unwrap(); + let role_inner = bundle + .user + .gateway + .observe_and_into_inner(Arc::new(RwLock::new(role.clone()))) + .await; + assert_eq!(role_inner.name, "yippieee"); + // Check if the change propagated + let guild = bundle.user.gateway.observe(bundle.guild.clone()).await; + let inner_guild = guild.read().unwrap().clone(); + let guild_roles = inner_guild.roles; + let guild_role = guild_roles.unwrap(); + let guild_role_inner = guild_role.get(0).unwrap().read().unwrap().clone(); + assert_eq!(guild_role_inner.name, "yippieee".to_string()); + common::teardown(bundle).await; +} diff --git a/tests/guilds.rs b/tests/guilds.rs index cc33b30..3d29d43 100644 --- a/tests/guilds.rs +++ b/tests/guilds.rs @@ -27,9 +27,7 @@ async fn guild_creation_deletion() { #[tokio::test] async fn get_channels() { let mut bundle = common::setup().await; - println!( - "{:?}", - bundle.guild.channels(&mut bundle.user).await.unwrap() - ); + let guild = bundle.guild.read().unwrap().clone(); + println!("{:?}", guild.channels(&mut bundle.user).await.unwrap()); common::teardown(bundle).await; } diff --git a/tests/invites.rs b/tests/invites.rs index e25fea5..ab264d4 100644 --- a/tests/invites.rs +++ b/tests/invites.rs @@ -3,11 +3,12 @@ use chorus::types::CreateChannelInviteSchema; #[tokio::test] async fn create_accept_invite() { let mut bundle = common::setup().await; - let channel = bundle.channel.clone(); + let channel = bundle.channel.read().unwrap().clone(); let mut other_user = bundle.create_user("testuser1312").await; let user = &mut bundle.user; let create_channel_invite_schema = CreateChannelInviteSchema::default(); - assert!(chorus::types::Guild::get(bundle.guild.id, &mut other_user) + let guild = bundle.guild.read().unwrap().clone(); + assert!(chorus::types::Guild::get(guild.id, &mut other_user) .await .is_err()); let invite = user @@ -16,7 +17,7 @@ async fn create_accept_invite() { .unwrap(); other_user.accept_invite(&invite.code, None).await.unwrap(); - assert!(chorus::types::Guild::get(bundle.guild.id, &mut other_user) + assert!(chorus::types::Guild::get(guild.id, &mut other_user) .await .is_ok()); common::teardown(bundle).await; diff --git a/tests/members.rs b/tests/members.rs index a314be7..fbab772 100644 --- a/tests/members.rs +++ b/tests/members.rs @@ -5,9 +5,9 @@ mod common; #[tokio::test] async fn add_remove_role() -> ChorusResult<()> { let mut bundle = common::setup().await; - let guild = bundle.guild.id; - let role = bundle.role.id; - let member_id = bundle.user.object.lock().unwrap().id; + let guild = bundle.guild.read().unwrap().id; + let role = bundle.role.read().unwrap().id; + let member_id = bundle.user.object.read().unwrap().id; GuildMember::add_role(&mut bundle.user, guild, member_id, role).await?; let member = GuildMember::get(&mut bundle.user, guild, member_id) .await diff --git a/tests/messages.rs b/tests/messages.rs index af6ee5b..417ca54 100644 --- a/tests/messages.rs +++ b/tests/messages.rs @@ -12,11 +12,8 @@ async fn send_message() { content: Some("A Message!".to_string()), ..Default::default() }; - let _ = bundle - .user - .send_message(message, bundle.channel.id) - .await - .unwrap(); + let channel = bundle.channel.read().unwrap().clone(); + let _ = bundle.user.send_message(message, channel.id).await.unwrap(); common::teardown(bundle).await } @@ -50,13 +47,9 @@ async fn send_message_attachment() { attachments: Some(vec![attachment.clone()]), ..Default::default() }; - + let channel = bundle.channel.read().unwrap().clone(); let vec_attach = vec![attachment.clone()]; let _arg = Some(&vec_attach); - bundle - .user - .send_message(message, bundle.channel.id) - .await - .unwrap(); + bundle.user.send_message(message, channel.id).await.unwrap(); common::teardown(bundle).await } diff --git a/tests/relationships.rs b/tests/relationships.rs index 2773474..09ddab0 100644 --- a/tests/relationships.rs +++ b/tests/relationships.rs @@ -7,9 +7,9 @@ async fn test_get_mutual_relationships() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let username = user.object.lock().unwrap().username.clone(); - let discriminator = user.object.lock().unwrap().discriminator.clone(); - let other_user_id: types::Snowflake = other_user.object.lock().unwrap().id; + let username = user.object.read().unwrap().username.clone(); + let discriminator = user.object.read().unwrap().discriminator.clone(); + let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; let friend_request_schema = types::FriendRequestSendSchema { username, discriminator: Some(discriminator), @@ -28,8 +28,8 @@ async fn test_get_relationships() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let username = user.object.lock().unwrap().username.clone(); - let discriminator = user.object.lock().unwrap().discriminator.clone(); + let username = user.object.read().unwrap().username.clone(); + let discriminator = user.object.read().unwrap().discriminator.clone(); let friend_request_schema = types::FriendRequestSendSchema { username, discriminator: Some(discriminator), @@ -41,7 +41,7 @@ async fn test_get_relationships() { let relationships = user.get_relationships().await.unwrap(); assert_eq!( relationships.get(0).unwrap().id, - other_user.object.lock().unwrap().id + other_user.object.read().unwrap().id ); common::teardown(bundle).await } @@ -51,8 +51,8 @@ async fn test_modify_relationship_friends() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let user_id: types::Snowflake = user.object.lock().unwrap().id; - let other_user_id: types::Snowflake = other_user.object.lock().unwrap().id; + let user_id: types::Snowflake = user.object.read().unwrap().id; + let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; other_user .modify_user_relationship(user_id, types::RelationshipType::Friends) @@ -61,7 +61,7 @@ async fn test_modify_relationship_friends() { let relationships = user.get_relationships().await.unwrap(); assert_eq!( relationships.get(0).unwrap().id, - other_user.object.lock().unwrap().id + other_user.object.read().unwrap().id ); assert_eq!( relationships.get(0).unwrap().relationship_type, @@ -70,7 +70,7 @@ async fn test_modify_relationship_friends() { let relationships = other_user.get_relationships().await.unwrap(); assert_eq!( relationships.get(0).unwrap().id, - user.object.lock().unwrap().id + user.object.read().unwrap().id ); assert_eq!( relationships.get(0).unwrap().relationship_type, @@ -102,7 +102,7 @@ async fn test_modify_relationship_block() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let user_id: types::Snowflake = user.object.lock().unwrap().id; + let user_id: types::Snowflake = user.object.read().unwrap().id; other_user .modify_user_relationship(user_id, types::RelationshipType::Blocked) @@ -113,7 +113,7 @@ async fn test_modify_relationship_block() { let relationships = other_user.get_relationships().await.unwrap(); assert_eq!( relationships.get(0).unwrap().id, - user.object.lock().unwrap().id + user.object.read().unwrap().id ); assert_eq!( relationships.get(0).unwrap().relationship_type, diff --git a/tests/roles.rs b/tests/roles.rs index f45fb62..7876e5d 100644 --- a/tests/roles.rs +++ b/tests/roles.rs @@ -17,12 +17,12 @@ async fn create_and_get_roles() { position: None, color: None, }; - let guild = bundle.guild.id; - let role = types::RoleObject::create(&mut bundle.user, guild, role_create_schema) + let guild_id = bundle.guild.read().unwrap().id; + let role = types::RoleObject::create(&mut bundle.user, guild_id, role_create_schema) .await .unwrap(); - let expected = types::RoleObject::get_all(&mut bundle.user, guild) + let expected = types::RoleObject::get_all(&mut bundle.user, guild_id) .await .unwrap()[2] .clone(); @@ -34,9 +34,9 @@ async fn create_and_get_roles() { #[tokio::test] async fn get_singular_role() { let mut bundle = common::setup().await; - let guild_id = bundle.guild.id; - let role_id = bundle.role.id; - let role = bundle.role.clone(); + let guild_id = bundle.guild.read().unwrap().id; + let role_id = bundle.role.read().unwrap().id; + let role = bundle.role.read().unwrap().clone(); let same_role = chorus::types::RoleObject::get(&mut bundle.user, guild_id, role_id) .await .unwrap();