From d36b1de1700610b86afd1c52d20b9d0699c8be12 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Wed, 16 Aug 2023 00:18:32 +0200 Subject: [PATCH] Give `GatewayHandle` and `Gateway` common trait to call `watch_whole()` or observe() from a `Gateway` --- chorus-macros/src/lib.rs | 2 +- src/gateway.rs | 77 +++++++++++++++++++++---------- src/types/entities/channel.rs | 2 +- src/types/entities/emoji.rs | 2 +- src/types/entities/guild.rs | 2 +- src/types/entities/mod.rs | 15 +++--- src/types/entities/role.rs | 2 +- src/types/entities/user.rs | 2 +- src/types/entities/voice_state.rs | 2 +- src/types/entities/webhook.rs | 2 +- src/types/events/guild.rs | 29 ++++++++++-- tests/gateway.rs | 5 +- 12 files changed, 99 insertions(+), 43 deletions(-) diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index baac40b..50bcebf 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -105,7 +105,7 @@ pub fn composite_derive(input: TokenStream) -> TokenStream { let expanded = quote! { #[async_trait::async_trait(?Send)] impl Composite for #ident { - async fn watch_whole(self, gateway: &GatewayHandle) -> Self { + async fn watch_whole(self, gateway: &(impl GatewayObject + ?Sized)) -> Self { Self { #(#field_exprs,)* } diff --git a/src/gateway.rs b/src/gateway.rs index 9232c2d..950ec96 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -180,32 +180,16 @@ pub trait Updateable: 'static + Send + Sync { fn id(&self) -> Snowflake; } -impl GatewayHandle { - /// Sends json to the gateway with an opcode - async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { - let gateway_payload = types::GatewaySendPayload { - op_code, - event_data: Some(to_send), - sequence_number: None, - }; - - let payload_json = serde_json::to_string(&gateway_payload).unwrap(); - - let message = tokio_tungstenite::tungstenite::Message::text(payload_json); - - self.websocket_send - .lock() - .await - .send(message) - .await - .unwrap(); - } - - pub async fn observe>( +#[async_trait(?Send)] +pub trait GatewayObject { + fn store(&self) -> Arc>>>; + fn events(&self) -> Arc>; + async fn observe>( &self, object: Arc>, ) -> watch::Receiver>> { - let mut store = self.store.lock().await; + let store = self.store(); + let mut store = store.lock().await; let id = object.read().unwrap().id(); if let Some(channel) = store.get(&id) { let (_, rx) = channel @@ -230,8 +214,7 @@ impl GatewayHandle { receiver } } - - pub async fn observe_and_get>( + async fn observe_and_get>( &self, object: Arc>, ) -> Arc> { @@ -239,6 +222,50 @@ impl GatewayHandle { let object = channel.borrow().clone(); object } +} + +#[async_trait(?Send)] +impl GatewayObject for GatewayHandle { + fn store(&self) -> Arc>>> { + self.store.clone() + } + + fn events(&self) -> Arc> { + self.events.clone() + } +} + +#[async_trait(?Send)] +impl GatewayObject for Gateway { + fn store(&self) -> Arc>>> { + self.store.clone() + } + + fn events(&self) -> Arc> { + self.events.clone() + } +} + +impl GatewayHandle { + /// Sends json to the gateway with an opcode + async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { + let gateway_payload = types::GatewaySendPayload { + op_code, + event_data: Some(to_send), + sequence_number: None, + }; + + let payload_json = serde_json::to_string(&gateway_payload).unwrap(); + + let message = tokio_tungstenite::tungstenite::Message::text(payload_json); + + self.websocket_send + .lock() + .await + .send(message) + .await + .unwrap(); + } /// Sends an identify event to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index ad776d8..8f0bf8b 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_string_from_number; use serde_repr::{Deserialize_repr, Serialize_repr}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::{ entities::{GuildMember, User}, utils::Snowflake, diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 308c41b..25b7d9a 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::entities::User; use crate::types::{Composite, Snowflake}; diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index 7f73403..8f875ec 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -5,7 +5,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::types::guild_configuration::GuildFeaturesList; use crate::types::{ entities::{Channel, Emoji, RoleObject, Sticker, User, VoiceState, Webhook}, diff --git a/src/types/entities/mod.rs b/src/types/entities/mod.rs index 7f8cc56..a92463c 100644 --- a/src/types/entities/mod.rs +++ b/src/types/entities/mod.rs @@ -22,7 +22,7 @@ pub use user_settings::*; pub use voice_state::*; pub use webhook::*; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use async_trait::async_trait; use std::sync::{Arc, RwLock}; @@ -52,11 +52,11 @@ mod webhook; #[async_trait(?Send)] pub trait Composite { - async fn watch_whole(self, gateway: &GatewayHandle) -> Self; + async fn watch_whole(self, gateway: &(impl GatewayObject + ?Sized)) -> Self; async fn option_observe_fn( value: Option>>, - gateway: &GatewayHandle, + gateway: &(impl GatewayObject + ?Sized), ) -> Option>> where T: Composite, @@ -71,7 +71,7 @@ pub trait Composite { async fn option_vec_observe_fn( value: Option>>>, - gateway: &GatewayHandle, + gateway: &(impl GatewayObject + ?Sized), ) -> Option>>> where T: Composite, @@ -87,7 +87,10 @@ pub trait Composite { } } - async fn value_observe_fn(value: Arc>, gateway: &GatewayHandle) -> Arc> + async fn value_observe_fn( + value: Arc>, + gateway: &(impl GatewayObject + ?Sized), + ) -> Arc> where T: Composite, { @@ -96,7 +99,7 @@ pub trait Composite { async fn vec_observe_fn( value: Vec>>, - gateway: &GatewayHandle, + gateway: &(impl GatewayObject + ?Sized), ) -> Vec>> where T: Composite, diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 9fac835..17b9024 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -3,7 +3,7 @@ use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::{deserialize_option_number_from_string, deserialize_string_from_number}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::{utils::Snowflake, Composite}; #[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Updateable, Composite)] diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 56329b2..22e3e9c 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -3,7 +3,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_option_number_from_string; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::{utils::Snowflake, Composite}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] diff --git a/src/types/entities/voice_state.rs b/src/types/entities/voice_state.rs index e74aa7d..798fe04 100644 --- a/src/types/entities/voice_state.rs +++ b/src/types/entities/voice_state.rs @@ -4,7 +4,7 @@ use chorus_macros::{Composite, Updateable}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::{ entities::{Guild, GuildMember}, utils::Snowflake, diff --git a/src/types/entities/webhook.rs b/src/types/entities/webhook.rs index dd88ca1..c80cdbd 100644 --- a/src/types/entities/webhook.rs +++ b/src/types/entities/webhook.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use chorus_macros::{Composite, Updateable}; use serde::{Deserialize, Serialize}; -use crate::gateway::{GatewayHandle, Updateable}; +use crate::gateway::{GatewayObject, Updateable}; use crate::types::{ entities::{Guild, User}, utils::Snowflake, diff --git a/src/types/events/guild.rs b/src/types/events/guild.rs index 2cbdbd0..89a8339 100644 --- a/src/types/events/guild.rs +++ b/src/types/events/guild.rs @@ -1,13 +1,17 @@ +use std::sync::{Arc, RwLock}; + +use chorus_macros::JsonField; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::types::entities::{Guild, PublicUser, UnavailableGuild}; use crate::types::events::WebSocketEvent; use crate::types::{ - AuditLogEntry, Emoji, GuildMember, GuildScheduledEvent, RoleObject, Snowflake, Sticker, + AuditLogEntry, Emoji, GuildMember, GuildScheduledEvent, JsonField, RoleObject, Snowflake, + Sticker, }; -use super::PresenceUpdate; +use super::{PresenceUpdate, UpdateMessage}; #[derive(Debug, Deserialize, Serialize, Default, Clone)] /// See ; @@ -164,15 +168,34 @@ pub struct GuildMembersChunk { impl WebSocketEvent for GuildMembersChunk {} -#[derive(Debug, Default, Deserialize, Serialize, Clone)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, JsonField)] /// See pub struct GuildRoleCreate { pub guild_id: Snowflake, pub role: RoleObject, + #[serde(skip)] + pub json: String, } impl WebSocketEvent for GuildRoleCreate {} +impl UpdateMessage for GuildRoleCreate { + fn id(&self) -> Snowflake { + self.role.id + } + + fn update(&mut self, object_to_update: std::sync::Arc>) { + let mut write = object_to_update.write().unwrap(); + if write.roles.is_some() { + write + .roles + .as_mut() + .unwrap() + .push(Arc::new(RwLock::new(self.role.clone()))); + } + } +} + #[derive(Debug, Default, Deserialize, Serialize, Clone)] /// See pub struct GuildRoleUpdate { diff --git a/tests/gateway.rs b/tests/gateway.rs index 4cc450a..742dc4e 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -1,7 +1,7 @@ mod common; use chorus::gateway::*; -use chorus::types::{self, Channel, ChannelModifySchema}; +use chorus::types::{self, Channel, ChannelModifySchema, Guild}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; @@ -63,3 +63,6 @@ async fn test_self_updating_structs() { common::teardown(bundle).await } + +#[tokio::test] +async fn test_recursive_self_updating_structs() {}