Remove tokio watch channels

This commit is contained in:
bitfl0wer 2023-08-16 21:26:19 +02:00
parent e9f94267c6
commit fbf72b74d0
2 changed files with 33 additions and 94 deletions

View File

@ -10,7 +10,6 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink;
@ -150,14 +149,12 @@ impl GatewayMessage {
}
}
pub type ObservableObject = dyn Send + Sync + Any;
/// Represents a handle to a Gateway connection. A Gateway connection will create observable
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
/// implemented types with the trait [`WebSocketEvent`]
/// Using this handle you can also send Gateway Events directly.
///
/// # Store
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
#[derive(Debug)]
pub struct GatewayHandle {
pub url: String,
@ -173,7 +170,7 @@ pub struct GatewayHandle {
pub handle: JoinHandle<()>,
/// Tells gateway tasks to close
kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
}
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
@ -202,55 +199,42 @@ impl GatewayHandle {
.unwrap();
}
pub async fn observer_channel<T: Updateable + Clone + Composite<T>>(
pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> watch::Receiver<Arc<RwLock<T>>> {
) -> Arc<RwLock<T>> {
let mut store = self.store.lock().await;
let id = object.read().unwrap().id();
if let Some(channel) = store.get(&id) {
let (_, rx) = channel
.downcast_ref::<(
watch::Sender<Arc<RwLock<T>>>,
watch::Receiver<Arc<RwLock<T>>>,
)>()
.unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
object.read().unwrap().id()
)
});
rx.clone()
let object = channel.clone();
let inner_object = object.read().unwrap();
inner_object.downcast_ref::<T>().unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
id
)
});
let ptr = Arc::into_raw(object.clone());
unsafe { println!("{:?}", Arc::from_raw(ptr as *const RwLock<T>).clone()) };
unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() }
} else {
let id = object.read().unwrap().id();
let object = object.read().unwrap().clone();
let object = object.clone().watch_whole(self).await;
let channel = watch::channel(Arc::new(RwLock::new(object)));
let receiver = channel.1.clone();
store.insert(id, Box::new(channel));
receiver
let wrapped = Arc::new(RwLock::new(object));
store.insert(id, wrapped.clone());
wrapped
}
}
/// Recursively observes and updates all updateable fields on the struct T. Returns an object
/// with all of its observable fields being observed.
pub async fn observe_and_get<T: Updateable + Clone + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> Arc<RwLock<T>> {
let channel = self.observer_channel(object.clone()).await;
let object = channel.borrow().clone();
object
}
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
/// with all of its observable fields being observed.
pub async fn observe_and_into_inner<T: Updateable + Clone + Composite<T>>(
pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> T {
let channel = self.observer_channel(object.clone()).await;
let object = channel.borrow().clone().read().unwrap().clone();
let channel = self.observe(object.clone()).await;
let object = channel.read().unwrap().clone();
object
}
@ -330,8 +314,6 @@ impl GatewayHandle {
}
}
/// The value of `store`s [`HashMap`] is a [`tokio::sync::watch::channel<T: Updateable>`]. See the
/// [`Updateable`] trait for more information.
pub struct Gateway {
events: Arc<Mutex<Events>>,
heartbeat_handler: HeartbeatHandler,
@ -345,7 +327,7 @@ pub struct Gateway {
>,
websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
}
impl Gateway {
@ -518,12 +500,14 @@ impl Gateway {
$(
let mut message: $message_type = message;
if let Some(to_update) = self.store.lock().await.get(&message.id()) {
if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<Arc<RwLock<$update_type>>>, watch::Receiver<Arc<RwLock<$update_type>>>)>() {
// `object` is the current value of the `watch::channel`. It's being passed into `message.update()` to be modified
// within the closure function. Then, this closure is passed to the `tx.send_modify()` method which applies the
// modification to the current value of the watch channel.
let object = to_update.clone();
let inner_object = object.read().unwrap();
if let Some(_) = inner_object.downcast_ref::<$update_type>() {
let ptr = Arc::into_raw(object.clone());
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<$update_type>).clone() };
drop(inner_object);
message.set_json(json.to_string());
tx.send_modify(|object| message.update(object.clone()));
message.update(downcasted.clone());
} else {
warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id())
}

View File

@ -1,11 +1,7 @@
mod common;
use std::sync::{Arc, RwLock};
use chorus::gateway::*;
use chorus::types::{
self, ChannelModifySchema, Composite, PermissionFlags, RoleCreateModifySchema, RoleObject,
};
use chorus::types::{self, ChannelModifySchema};
#[tokio::test]
/// Tests establishing a connection (hello and heartbeats) on the local gateway;
@ -65,48 +61,7 @@ async fn test_self_updating_structs() {
#[tokio::test]
async fn test_recursive_self_updating_structs() {
let mut bundle = common::setup().await;
let guild = bundle
.user
.gateway
.observe_and_into_inner(bundle.guild.clone())
.await;
assert!(guild.roles.is_none());
let id = guild.id;
let permissions = PermissionFlags::CONNECT | PermissionFlags::MANAGE_EVENTS;
let permissions = Some(permissions.to_string());
let mut role_create_schema = RoleCreateModifySchema {
name: Some("among us".to_string()),
permissions,
hoist: Some(true),
icon: None,
unicode_emoji: Some("".to_string()),
mentionable: Some(true),
position: None,
color: None,
};
let role = RoleObject::create(&mut bundle.user, id, role_create_schema.clone())
.await
.unwrap();
let role_watch = bundle
.user
.gateway
.observer_channel(Arc::new(RwLock::new(role.clone())))
.await;
let guild = bundle
.user
.gateway
.observe_and_into_inner(bundle.guild.clone())
.await;
assert!(guild.roles.is_some());
role_create_schema.name = Some("enbyenvy".to_string());
RoleObject::modify(&mut bundle.user, id, role.id, role_create_schema)
.await
.unwrap();
let newrole = role_watch.borrow().read().unwrap().clone();
assert_eq!(newrole.name, "enbyenvy".to_string());
let guild_role = role_watch.borrow().read().unwrap().clone();
assert_eq!(guild_role.name, "enbyenvy".to_string());
// Setup
let bundle = common::setup().await;
common::teardown(bundle).await;
}