Merge branch 'feature/update-message-all-events' into enhancement/improved-auto-updating-structs

This commit is contained in:
bitfl0wer 2023-08-15 20:53:56 +02:00
commit 9dc86615a4
No known key found for this signature in database
GPG Key ID: 0ACD574FCF5226CF
5 changed files with 61 additions and 10 deletions

View File

@ -18,6 +18,25 @@ pub fn updateable_macro_derive(input: TokenStream) -> TokenStream {
.into() .into()
} }
#[proc_macro_derive(JsonField)]
pub fn jsonfield_macro_derive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
// No need for macro hygiene, we're only using this in chorus
quote! {
impl JsonField for #name {
fn get_json(&self) -> String {
self.json.clone()
}
fn set_json(&mut self, json: String) {
self.json = json;
}
}
}
.into()
}
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream { pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream {
input input

View File

@ -152,6 +152,10 @@ impl GatewayMessage {
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
/// implemented types with the trait [`WebSocketEvent`] /// implemented types with the trait [`WebSocketEvent`]
/// Using this handle you can also send Gateway Events directly. /// Using this handle you can also send Gateway Events directly.
///
/// # Store
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
#[derive(Debug)] #[derive(Debug)]
pub struct GatewayHandle { pub struct GatewayHandle {
pub url: String, pub url: String,
@ -311,6 +315,8 @@ impl GatewayHandle {
} }
} }
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
pub struct Gateway { pub struct Gateway {
events: Arc<Mutex<Events>>, events: Arc<Mutex<Events>>,
heartbeat_handler: HeartbeatHandler, heartbeat_handler: HeartbeatHandler,
@ -494,7 +500,7 @@ impl Gateway {
Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"),
Ok(message) => { Ok(message) => {
$( $(
let message: $message_type = message; let mut message: $message_type = message;
if let Some(to_update) = self.store.lock().await.get(&message.id()) { if let Some(to_update) = self.store.lock().await.get(&message.id()) {
if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<Arc<RwLock<$update_type>>>, watch::Receiver<Arc<RwLock<$update_type>>>)>() { if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<Arc<RwLock<$update_type>>>, watch::Receiver<Arc<RwLock<$update_type>>>)>() {
// `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified // `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified

View File

@ -1,7 +1,8 @@
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use crate::types::events::WebSocketEvent; use crate::types::events::WebSocketEvent;
use crate::types::{entities::Channel, Snowflake}; use crate::types::{entities::Channel, JsonField, Snowflake};
use chorus_macros::JsonField;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -26,17 +27,19 @@ pub struct ChannelCreate {
impl WebSocketEvent for ChannelCreate {} impl WebSocketEvent for ChannelCreate {}
#[derive(Debug, Default, Deserialize, Serialize, Clone)] #[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)]
/// See <https://discord.com/developers/docs/topics/gateway-events#channel-update> /// See <https://discord.com/developers/docs/topics/gateway-events#channel-update>
pub struct ChannelUpdate { pub struct ChannelUpdate {
#[serde(flatten)] #[serde(flatten)]
pub channel: Channel, pub channel: Channel,
#[serde(skip)]
pub json: String,
} }
impl WebSocketEvent for ChannelUpdate {} impl WebSocketEvent for ChannelUpdate {}
impl UpdateMessage<Channel> for ChannelUpdate { impl UpdateMessage<Channel> for ChannelUpdate {
fn update(&self, object_to_update: Arc<RwLock<Channel>>) { fn update(&mut self, object_to_update: Arc<RwLock<Channel>>) {
let mut write = object_to_update.write().unwrap(); let mut write = object_to_update.write().unwrap();
*write = self.channel.clone(); *write = self.channel.clone();
} }

View File

@ -1,5 +1,8 @@
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use application::*; pub use application::*;
@ -21,6 +24,7 @@ pub use ready::*;
pub use relationship::*; pub use relationship::*;
pub use request_members::*; pub use request_members::*;
pub use resume::*; pub use resume::*;
use serde_json::{from_str, from_value, to_value, Value};
pub use session::*; pub use session::*;
pub use stage_instance::*; pub use stage_instance::*;
pub use thread::*; pub use thread::*;
@ -111,10 +115,31 @@ impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {}
/// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information /// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information
/// about a [`Channel`]. The update method describes how this new information will be turned into /// about a [`Channel`]. The update method describes how this new information will be turned into
/// a [`Channel`] object. /// a [`Channel`] object.
pub(crate) trait UpdateMessage<T>: Clone pub(crate) trait UpdateMessage<T>: Clone + JsonField
where where
T: Updateable, T: Updateable + Serialize + DeserializeOwned + Clone,
{ {
fn update(&self, object_to_update: Arc<RwLock<T>>); fn update(&mut self, object_to_update: Arc<RwLock<T>>) {
update_object(self.get_json(), object_to_update)
}
fn id(&self) -> Snowflake; fn id(&self) -> Snowflake;
} }
pub(crate) trait JsonField: Clone {
fn set_json(&mut self, json: String);
fn get_json(&self) -> String;
}
/// Only applicable for events where the Update struct is the same as the Entity struct
pub(crate) fn update_object(
value: String,
object: Arc<RwLock<(impl Updateable + Serialize + DeserializeOwned + Clone)>>,
) {
let data_from_event: HashMap<String, Value> = from_str(&value).unwrap();
let mut original_data: HashMap<String, Value> =
from_value(to_value(object.clone()).unwrap()).unwrap();
for (updated_entry_key, updated_entry_value) in data_from_event.into_iter() {
original_data.insert(updated_entry_key.clone(), updated_entry_value);
}
*object.write().unwrap() = from_value(to_value(original_data).unwrap()).unwrap();
}

View File

@ -51,9 +51,7 @@ async fn modify_channel() {
default_thread_rate_limit_per_user: None, default_thread_rate_limit_per_user: None,
video_quality_mode: None, video_quality_mode: None,
}; };
let modified_channel = Channel::modify(channel, modify_data, &mut bundle.user) let modified_channel = channel.modify(modify_data, &mut bundle.user).await.unwrap();
.await
.unwrap();
assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string()));
let permission_override = PermissionFlags::from_vec(Vec::from([ let permission_override = PermissionFlags::from_vec(Vec::from([