diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index 72e5ddd..baac40b 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -18,6 +18,25 @@ pub fn updateable_macro_derive(input: TokenStream) -> TokenStream { .into() } +#[proc_macro_derive(JsonField)] +pub fn jsonfield_macro_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + // No need for macro hygiene, we're only using this in chorus + quote! { + impl JsonField for #name { + fn get_json(&self) -> String { + self.json.clone() + } + fn set_json(&mut self, json: String) { + self.json = json; + } + } + } + .into() +} + #[proc_macro_attribute] pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream { input diff --git a/src/gateway.rs b/src/gateway.rs index 886bd3a..3fb56e7 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -152,6 +152,10 @@ impl GatewayMessage { /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// implemented types with the trait [`WebSocketEvent`] /// Using this handle you can also send Gateway Events directly. +/// +/// # Store +/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel`]. See the +/// [`Updateable`] trait for more information. #[derive(Debug)] pub struct GatewayHandle { pub url: String, @@ -311,6 +315,8 @@ impl GatewayHandle { } } +/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel`]. See the +/// [`Updateable`] trait for more information. pub struct Gateway { events: Arc>, heartbeat_handler: HeartbeatHandler, @@ -494,7 +500,7 @@ impl Gateway { Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), Ok(message) => { $( - let message: $message_type = message; + let mut message: $message_type = message; if let Some(to_update) = self.store.lock().await.get(&message.id()) { if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender>>, watch::Receiver>>)>() { // `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified diff --git a/src/types/events/channel.rs b/src/types/events/channel.rs index 000bc17..017c50e 100644 --- a/src/types/events/channel.rs +++ b/src/types/events/channel.rs @@ -1,7 +1,8 @@ use std::sync::{Arc, RwLock}; use crate::types::events::WebSocketEvent; -use crate::types::{entities::Channel, Snowflake}; +use crate::types::{entities::Channel, JsonField, Snowflake}; +use chorus_macros::JsonField; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -26,17 +27,19 @@ pub struct ChannelCreate { impl WebSocketEvent for ChannelCreate {} -#[derive(Debug, Default, Deserialize, Serialize, Clone)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)] /// See pub struct ChannelUpdate { #[serde(flatten)] pub channel: Channel, + #[serde(skip)] + pub json: String, } impl WebSocketEvent for ChannelUpdate {} impl UpdateMessage for ChannelUpdate { - fn update(&self, object_to_update: Arc>) { + fn update(&mut self, object_to_update: Arc>) { let mut write = object_to_update.write().unwrap(); *write = self.channel.clone(); } diff --git a/src/types/events/mod.rs b/src/types/events/mod.rs index 0413ac9..0a296fb 100644 --- a/src/types/events/mod.rs +++ b/src/types/events/mod.rs @@ -1,5 +1,8 @@ use std::sync::{Arc, RwLock}; +use std::collections::HashMap; + +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; pub use application::*; @@ -21,6 +24,7 @@ pub use ready::*; pub use relationship::*; pub use request_members::*; pub use resume::*; +use serde_json::{from_str, from_value, to_value, Value}; pub use session::*; pub use stage_instance::*; pub use thread::*; @@ -111,10 +115,31 @@ impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {} /// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information /// about a [`Channel`]. The update method describes how this new information will be turned into /// a [`Channel`] object. -pub(crate) trait UpdateMessage: Clone +pub(crate) trait UpdateMessage: Clone + JsonField where - T: Updateable, + T: Updateable + Serialize + DeserializeOwned + Clone, { - fn update(&self, object_to_update: Arc>); + fn update(&mut self, object_to_update: Arc>) { + update_object(self.get_json(), object_to_update) + } fn id(&self) -> Snowflake; } + +pub(crate) trait JsonField: Clone { + fn set_json(&mut self, json: String); + fn get_json(&self) -> String; +} + +/// Only applicable for events where the Update struct is the same as the Entity struct +pub(crate) fn update_object( + value: String, + object: Arc>, +) { + let data_from_event: HashMap = from_str(&value).unwrap(); + let mut original_data: HashMap = + from_value(to_value(object.clone()).unwrap()).unwrap(); + for (updated_entry_key, updated_entry_value) in data_from_event.into_iter() { + original_data.insert(updated_entry_key.clone(), updated_entry_value); + } + *object.write().unwrap() = from_value(to_value(original_data).unwrap()).unwrap(); +} diff --git a/tests/channels.rs b/tests/channels.rs index 3043c00..94129f6 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -51,9 +51,7 @@ async fn modify_channel() { default_thread_rate_limit_per_user: None, video_quality_mode: None, }; - let modified_channel = Channel::modify(channel, modify_data, &mut bundle.user) - .await - .unwrap(); + let modified_channel = channel.modify(modify_data, &mut bundle.user).await.unwrap(); assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); let permission_override = PermissionFlags::from_vec(Vec::from([