Change observe() to take Arc<RwLock<T>>

This commit is contained in:
fowb 2023-08-12 19:47:11 +02:00
parent b672dd221c
commit d28f19d8ca
2 changed files with 13 additions and 7 deletions

View File

@ -6,7 +6,7 @@ use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::sleep_until;
@ -196,20 +196,26 @@ impl GatewayHandle {
.unwrap();
}
pub async fn observe<T: Updateable>(&self, object: T) -> watch::Receiver<T> {
pub async fn observe<T: Updateable>(
&self,
object: Arc<RwLock<T>>,
) -> watch::Receiver<Arc<RwLock<T>>> {
let mut store = self.store.lock().await;
if let Some(channel) = store.get(&object.id()) {
if let Some(channel) = store.get(&object.clone().read().unwrap().id()) {
let (_, rx) = channel
.downcast_ref::<(watch::Sender<T>, watch::Receiver<T>)>()
.downcast_ref::<(
watch::Sender<Arc<RwLock<T>>>,
watch::Receiver<Arc<RwLock<T>>>,
)>()
.unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
object.id()
object.read().unwrap().id()
)
});
rx.clone()
} else {
let id = object.id();
let id = object.read().unwrap().id();
let channel = watch::channel(object);
let receiver = channel.1.clone();
store.insert(id, Box::new(channel));

View File

@ -29,7 +29,7 @@ async fn test_gateway_authenticate() {
#[tokio::test]
async fn test_self_updating_structs() {
let mut bundle = common::setup().await;
let channel_updater = bundle.user.gateway.observe(bundle.channel.clone()).await;
let channel_updater = bundle.user.gateway.observe(bundle.channel).await;
let received_channel = channel_updater.borrow().clone();
assert_eq!(received_channel, bundle.channel);
let channel = &mut bundle.channel;