Change observe() to take Arc<RwLock<T>>

This commit is contained in:
fowb 2023-08-12 19:47:11 +02:00
parent ff47965e99
commit a0c9f6fb14
2 changed files with 13 additions and 7 deletions

View File

@ -6,7 +6,7 @@ use async_trait::async_trait;
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch; use tokio::sync::watch;
use tokio::time::sleep_until; use tokio::time::sleep_until;
@ -196,20 +196,26 @@ impl GatewayHandle {
.unwrap(); .unwrap();
} }
pub async fn observe<T: Updateable>(&self, object: T) -> watch::Receiver<T> { pub async fn observe<T: Updateable>(
&self,
object: Arc<RwLock<T>>,
) -> watch::Receiver<Arc<RwLock<T>>> {
let mut store = self.store.lock().await; let mut store = self.store.lock().await;
if let Some(channel) = store.get(&object.id()) { if let Some(channel) = store.get(&object.clone().read().unwrap().id()) {
let (_, rx) = channel let (_, rx) = channel
.downcast_ref::<(watch::Sender<T>, watch::Receiver<T>)>() .downcast_ref::<(
watch::Sender<Arc<RwLock<T>>>,
watch::Receiver<Arc<RwLock<T>>>,
)>()
.unwrap_or_else(|| { .unwrap_or_else(|| {
panic!( panic!(
"Snowflake {} already exists in the store, but it is not of type T.", "Snowflake {} already exists in the store, but it is not of type T.",
object.id() object.read().unwrap().id()
) )
}); });
rx.clone() rx.clone()
} else { } else {
let id = object.id(); let id = object.read().unwrap().id();
let channel = watch::channel(object); let channel = watch::channel(object);
let receiver = channel.1.clone(); let receiver = channel.1.clone();
store.insert(id, Box::new(channel)); store.insert(id, Box::new(channel));

View File

@ -29,7 +29,7 @@ async fn test_gateway_authenticate() {
#[tokio::test] #[tokio::test]
async fn test_self_updating_structs() { async fn test_self_updating_structs() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let channel_updater = bundle.user.gateway.observe(bundle.channel.clone()).await; let channel_updater = bundle.user.gateway.observe(bundle.channel).await;
let received_channel = channel_updater.borrow().clone(); let received_channel = channel_updater.borrow().clone();
assert_eq!(received_channel, bundle.channel); assert_eq!(received_channel, bundle.channel);
let channel = &mut bundle.channel; let channel = &mut bundle.channel;