Remove tokio watch channels

This commit is contained in:
bitfl0wer 2023-08-16 21:26:19 +02:00
parent e9f94267c6
commit fbf72b74d0
2 changed files with 33 additions and 94 deletions

View File

@ -10,7 +10,6 @@ use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch;
use tokio::time::sleep_until; use tokio::time::sleep_until;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
@ -150,14 +149,12 @@ impl GatewayMessage {
} }
} }
pub type ObservableObject = dyn Send + Sync + Any;
/// Represents a handle to a Gateway connection. A Gateway connection will create observable /// Represents a handle to a Gateway connection. A Gateway connection will create observable
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently /// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
/// implemented types with the trait [`WebSocketEvent`] /// implemented types with the trait [`WebSocketEvent`]
/// Using this handle you can also send Gateway Events directly. /// Using this handle you can also send Gateway Events directly.
///
/// # Store
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
#[derive(Debug)] #[derive(Debug)]
pub struct GatewayHandle { pub struct GatewayHandle {
pub url: String, pub url: String,
@ -173,7 +170,7 @@ pub struct GatewayHandle {
pub handle: JoinHandle<()>, pub handle: JoinHandle<()>,
/// Tells gateway tasks to close /// Tells gateway tasks to close
kill_send: tokio::sync::broadcast::Sender<()>, kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>, store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
} }
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. /// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
@ -202,55 +199,42 @@ impl GatewayHandle {
.unwrap(); .unwrap();
} }
pub async fn observer_channel<T: Updateable + Clone + Composite<T>>( pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
&self, &self,
object: Arc<RwLock<T>>, object: Arc<RwLock<T>>,
) -> watch::Receiver<Arc<RwLock<T>>> { ) -> Arc<RwLock<T>> {
let mut store = self.store.lock().await; let mut store = self.store.lock().await;
let id = object.read().unwrap().id(); let id = object.read().unwrap().id();
if let Some(channel) = store.get(&id) { if let Some(channel) = store.get(&id) {
let (_, rx) = channel let object = channel.clone();
.downcast_ref::<( let inner_object = object.read().unwrap();
watch::Sender<Arc<RwLock<T>>>, inner_object.downcast_ref::<T>().unwrap_or_else(|| {
watch::Receiver<Arc<RwLock<T>>>, panic!(
)>() "Snowflake {} already exists in the store, but it is not of type T.",
.unwrap_or_else(|| { id
panic!( )
"Snowflake {} already exists in the store, but it is not of type T.", });
object.read().unwrap().id() let ptr = Arc::into_raw(object.clone());
) unsafe { println!("{:?}", Arc::from_raw(ptr as *const RwLock<T>).clone()) };
}); unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() }
rx.clone()
} else { } else {
let id = object.read().unwrap().id(); let id = object.read().unwrap().id();
let object = object.read().unwrap().clone(); let object = object.read().unwrap().clone();
let object = object.clone().watch_whole(self).await; let object = object.clone().watch_whole(self).await;
let channel = watch::channel(Arc::new(RwLock::new(object))); let wrapped = Arc::new(RwLock::new(object));
let receiver = channel.1.clone(); store.insert(id, wrapped.clone());
store.insert(id, Box::new(channel)); wrapped
receiver
} }
} }
/// Recursively observes and updates all updateable fields on the struct T. Returns an object
/// with all of its observable fields being observed.
pub async fn observe_and_get<T: Updateable + Clone + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> Arc<RwLock<T>> {
let channel = self.observer_channel(object.clone()).await;
let object = channel.borrow().clone();
object
}
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
/// with all of its observable fields being observed. /// with all of its observable fields being observed.
pub async fn observe_and_into_inner<T: Updateable + Clone + Composite<T>>( pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
&self, &self,
object: Arc<RwLock<T>>, object: Arc<RwLock<T>>,
) -> T { ) -> T {
let channel = self.observer_channel(object.clone()).await; let channel = self.observe(object.clone()).await;
let object = channel.borrow().clone().read().unwrap().clone(); let object = channel.read().unwrap().clone();
object object
} }
@ -330,8 +314,6 @@ impl GatewayHandle {
} }
} }
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
pub struct Gateway { pub struct Gateway {
events: Arc<Mutex<Events>>, events: Arc<Mutex<Events>>,
heartbeat_handler: HeartbeatHandler, heartbeat_handler: HeartbeatHandler,
@ -345,7 +327,7 @@ pub struct Gateway {
>, >,
websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
kill_send: tokio::sync::broadcast::Sender<()>, kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>, store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
} }
impl Gateway { impl Gateway {
@ -518,12 +500,14 @@ impl Gateway {
$( $(
let mut message: $message_type = message; let mut message: $message_type = message;
if let Some(to_update) = self.store.lock().await.get(&message.id()) { if let Some(to_update) = self.store.lock().await.get(&message.id()) {
if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<Arc<RwLock<$update_type>>>, watch::Receiver<Arc<RwLock<$update_type>>>)>() { let object = to_update.clone();
// `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified let inner_object = object.read().unwrap();
// within the closure function. Then, this closure is passed to the `tx.send_modify()` method which applies the if let Some(_) = inner_object.downcast_ref::<$update_type>() {
// modification to the current value of the watch channel. let ptr = Arc::into_raw(object.clone());
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() };
drop(inner_object);
message.set_json(json.to_string()); message.set_json(json.to_string());
tx.send_modify(|object| message.update(object.clone())); message.update(downcasted.clone());
} else { } else {
warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id()) warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id())
} }

View File

@ -1,11 +1,7 @@
mod common; mod common;
use std::sync::{Arc, RwLock};
use chorus::gateway::*; use chorus::gateway::*;
use chorus::types::{ use chorus::types::{self, ChannelModifySchema};
self, ChannelModifySchema, Composite, PermissionFlags, RoleCreateModifySchema, RoleObject,
};
#[tokio::test] #[tokio::test]
/// Tests establishing a connection (hello and heartbeats) on the local gateway; /// Tests establishing a connection (hello and heartbeats) on the local gateway;
@ -65,48 +61,7 @@ async fn test_self_updating_structs() {
#[tokio::test] #[tokio::test]
async fn test_recursive_self_updating_structs() { async fn test_recursive_self_updating_structs() {
let mut bundle = common::setup().await; // Setup
let bundle = common::setup().await;
let guild = bundle
.user
.gateway
.observe_and_into_inner(bundle.guild.clone())
.await;
assert!(guild.roles.is_none());
let id = guild.id;
let permissions = PermissionFlags::CONNECT | PermissionFlags::MANAGE_EVENTS;
let permissions = Some(permissions.to_string());
let mut role_create_schema = RoleCreateModifySchema {
name: Some("among us".to_string()),
permissions,
hoist: Some(true),
icon: None,
unicode_emoji: Some("".to_string()),
mentionable: Some(true),
position: None,
color: None,
};
let role = RoleObject::create(&mut bundle.user, id, role_create_schema.clone())
.await
.unwrap();
let role_watch = bundle
.user
.gateway
.observer_channel(Arc::new(RwLock::new(role.clone())))
.await;
let guild = bundle
.user
.gateway
.observe_and_into_inner(bundle.guild.clone())
.await;
assert!(guild.roles.is_some());
role_create_schema.name = Some("enbyenvy".to_string());
RoleObject::modify(&mut bundle.user, id, role.id, role_create_schema)
.await
.unwrap();
let newrole = role_watch.borrow().read().unwrap().clone();
assert_eq!(newrole.name, "enbyenvy".to_string());
let guild_role = role_watch.borrow().read().unwrap().clone();
assert_eq!(guild_role.name, "enbyenvy".to_string());
common::teardown(bundle).await; common::teardown(bundle).await;
} }