From fbf72b74d0634ac9f11eb797877e4f74b2734ab8 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Wed, 16 Aug 2023 21:26:19 +0200 Subject: [PATCH] Remove tokio watch channels --- src/gateway.rs | 76 +++++++++++++++++++----------------------------- tests/gateway.rs | 51 ++------------------------------ 2 files changed, 33 insertions(+), 94 deletions(-) diff --git a/src/gateway.rs b/src/gateway.rs index 695e893..e1c4836 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -10,7 +10,6 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::{Arc, RwLock}; use std::time::Duration; -use tokio::sync::watch; use tokio::time::sleep_until; use futures_util::stream::SplitSink; @@ -150,14 +149,12 @@ impl GatewayMessage { } } +pub type ObservableObject = dyn Send + Sync + Any; + /// Represents a handle to a Gateway connection. A Gateway connection will create observable /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// implemented types with the trait [`WebSocketEvent`] /// Using this handle you can also send Gateway Events directly. -/// -/// # Store -/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel`]. See the -/// [`Updateable`] trait for more information. #[derive(Debug)] pub struct GatewayHandle { pub url: String, @@ -173,7 +170,7 @@ pub struct GatewayHandle { pub handle: JoinHandle<()>, /// Tells gateway tasks to close kill_send: tokio::sync::broadcast::Sender<()>, - store: Arc>>>, + store: Arc>>>>, } /// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. @@ -202,55 +199,42 @@ impl GatewayHandle { .unwrap(); } - pub async fn observer_channel>( + pub async fn observe>( &self, object: Arc>, - ) -> watch::Receiver>> { + ) -> Arc> { let mut store = self.store.lock().await; let id = object.read().unwrap().id(); if let Some(channel) = store.get(&id) { - let (_, rx) = channel - .downcast_ref::<( - watch::Sender>>, - watch::Receiver>>, - )>() - .unwrap_or_else(|| { - panic!( - "Snowflake {} already exists in the store, but it is not of type T.", - object.read().unwrap().id() - ) - }); - rx.clone() + let object = channel.clone(); + let inner_object = object.read().unwrap(); + inner_object.downcast_ref::().unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + id + ) + }); + let ptr = Arc::into_raw(object.clone()); + unsafe { println!("{:?}", Arc::from_raw(ptr as *const RwLock).clone()) }; + unsafe { Arc::from_raw(ptr as *const RwLock).clone() } } else { let id = object.read().unwrap().id(); let object = object.read().unwrap().clone(); let object = object.clone().watch_whole(self).await; - let channel = watch::channel(Arc::new(RwLock::new(object))); - let receiver = channel.1.clone(); - store.insert(id, Box::new(channel)); - receiver + let wrapped = Arc::new(RwLock::new(object)); + store.insert(id, wrapped.clone()); + wrapped } } - /// Recursively observes and updates all updateable fields on the struct T. Returns an object - /// with all of its observable fields being observed. - pub async fn observe_and_get>( - &self, - object: Arc>, - ) -> Arc> { - let channel = self.observer_channel(object.clone()).await; - let object = channel.borrow().clone(); - object - } - /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` /// with all of its observable fields being observed. - pub async fn observe_and_into_inner>( + pub async fn observe_and_into_inner>( &self, object: Arc>, ) -> T { - let channel = self.observer_channel(object.clone()).await; - let object = channel.borrow().clone().read().unwrap().clone(); + let channel = self.observe(object.clone()).await; + let object = channel.read().unwrap().clone(); object } @@ -330,8 +314,6 @@ impl GatewayHandle { } } -/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel`]. See the -/// [`Updateable`] trait for more information. pub struct Gateway { events: Arc>, heartbeat_handler: HeartbeatHandler, @@ -345,7 +327,7 @@ pub struct Gateway { >, websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, - store: Arc>>>, + store: Arc>>>>, } impl Gateway { @@ -518,12 +500,14 @@ impl Gateway { $( let mut message: $message_type = message; if let Some(to_update) = self.store.lock().await.get(&message.id()) { - if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender>>, watch::Receiver>>)>() { - // `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified - // within the closure function. Then, this closure is passed to the `tx.send_modify()` method which applies the - // modification to the current value of the watch channel. + let object = to_update.clone(); + let inner_object = object.read().unwrap(); + if let Some(_) = inner_object.downcast_ref::<$update_type>() { + let ptr = Arc::into_raw(object.clone()); + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() }; + drop(inner_object); message.set_json(json.to_string()); - tx.send_modify(|object| message.update(object.clone())); + message.update(downcasted.clone()); } else { warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id()) } diff --git a/tests/gateway.rs b/tests/gateway.rs index 3876ef6..7c15bf6 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -1,11 +1,7 @@ mod common; -use std::sync::{Arc, RwLock}; - use chorus::gateway::*; -use chorus::types::{ - self, ChannelModifySchema, Composite, PermissionFlags, RoleCreateModifySchema, RoleObject, -}; +use chorus::types::{self, ChannelModifySchema}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; @@ -65,48 +61,7 @@ async fn test_self_updating_structs() { #[tokio::test] async fn test_recursive_self_updating_structs() { - let mut bundle = common::setup().await; - - let guild = bundle - .user - .gateway - .observe_and_into_inner(bundle.guild.clone()) - .await; - assert!(guild.roles.is_none()); - let id = guild.id; - let permissions = PermissionFlags::CONNECT | PermissionFlags::MANAGE_EVENTS; - let permissions = Some(permissions.to_string()); - let mut role_create_schema = RoleCreateModifySchema { - name: Some("among us".to_string()), - permissions, - hoist: Some(true), - icon: None, - unicode_emoji: Some("".to_string()), - mentionable: Some(true), - position: None, - color: None, - }; - let role = RoleObject::create(&mut bundle.user, id, role_create_schema.clone()) - .await - .unwrap(); - let role_watch = bundle - .user - .gateway - .observer_channel(Arc::new(RwLock::new(role.clone()))) - .await; - let guild = bundle - .user - .gateway - .observe_and_into_inner(bundle.guild.clone()) - .await; - assert!(guild.roles.is_some()); - role_create_schema.name = Some("enbyenvy".to_string()); - RoleObject::modify(&mut bundle.user, id, role.id, role_create_schema) - .await - .unwrap(); - let newrole = role_watch.borrow().read().unwrap().clone(); - assert_eq!(newrole.name, "enbyenvy".to_string()); - let guild_role = role_watch.borrow().read().unwrap().clone(); - assert_eq!(guild_role.name, "enbyenvy".to_string()); + // Setup + let bundle = common::setup().await; common::teardown(bundle).await; }