Auto updating structs (#163)

* Gateway fields don't need to be pub

* Add store to Gateway

* Add UpdateMessage trait

* Proof of concept: impl UpdateMessage for Channel

* Start working on auto updating structs

* Send entity updates over watch channel

* Add id to UpdateMessage

* Create trait Updateable

* Add documentation

* add gateway test

* Complete test

* Impl UpdateMessage::update() for ChannelUpdate

* Impl UpdateMessage::update() for ChannelUpdate
Co-authored by: SpecificProtagonist <specificprotagonist@posteo.org>

* make channel::modify no longer mutate Channel

* change modify call

* remove unused imports

* Allow dead code with TODO to remove it

* fix channel::modify test

* Update src/gateway.rs

Co-authored-by: SpecificProtagonist <vincentjunge@posteo.net>

---------

Co-authored-by: SpecificProtagonist <vincentjunge@posteo.net>
This commit is contained in:
Flori 2023-07-21 15:35:31 +02:00 committed by GitHub
parent 776efce4dc
commit ccf44a5375
7 changed files with 133 additions and 23 deletions

View File

@ -64,11 +64,11 @@ impl Channel {
///
/// A `Result` that contains a `Channel` object if the request was successful, or an `ChorusLibError` if an error occurred during the request.
pub async fn modify(
&mut self,
&self,
modify_data: ChannelModifySchema,
channel_id: Snowflake,
user: &mut UserMeta,
) -> ChorusResult<()> {
) -> ChorusResult<Channel> {
let chorus_request = ChorusRequest {
request: Client::new()
.patch(format!(
@ -80,9 +80,7 @@ impl Channel {
.body(to_string(&modify_data).unwrap()),
limit_type: LimitType::Channel(channel_id),
};
let new_channel = chorus_request.deserialize_response::<Channel>(user).await?;
let _ = std::mem::replace(self, new_channel);
Ok(())
chorus_request.deserialize_response::<Channel>(user).await
}
pub async fn messages(

View File

@ -1,10 +1,14 @@
use crate::errors::GatewayError;
use crate::gateway::events::Events;
use crate::types;
use crate::types::WebSocketEvent;
use crate::types::{self, Channel, ChannelUpdate, Snowflake};
use crate::types::{UpdateMessage, WebSocketEvent};
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink;
@ -163,6 +167,12 @@ pub struct GatewayHandle {
pub handle: JoinHandle<()>,
/// Tells gateway tasks to close
kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
}
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
pub trait Updateable: 'static + Send + Sync {
fn id(&self) -> Snowflake;
}
impl GatewayHandle {
@ -186,6 +196,27 @@ impl GatewayHandle {
.unwrap();
}
pub async fn observe<T: Updateable>(&self, object: T) -> watch::Receiver<T> {
let mut store = self.store.lock().await;
if let Some(channel) = store.get(&object.id()) {
let (_, rx) = channel
.downcast_ref::<(watch::Sender<T>, watch::Receiver<T>)>()
.unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
object.id()
)
});
rx.clone()
} else {
let id = object.id();
let channel = watch::channel(object);
let receiver = channel.1.clone();
store.insert(id, Box::new(channel));
receiver
}
}
/// Sends an identify event to the gateway
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
@ -263,9 +294,9 @@ impl GatewayHandle {
}
pub struct Gateway {
pub events: Arc<Mutex<Events>>,
events: Arc<Mutex<Events>>,
heartbeat_handler: HeartbeatHandler,
pub websocket_send: Arc<
websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
@ -273,8 +304,9 @@ pub struct Gateway {
>,
>,
>,
pub websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
}
impl Gateway {
@ -325,6 +357,8 @@ impl Gateway {
let events = Events::default();
let shared_events = Arc::new(Mutex::new(events));
let store = Arc::new(Mutex::new(HashMap::new()));
let mut gateway = Gateway {
events: shared_events.clone(),
heartbeat_handler: HeartbeatHandler::new(
@ -335,6 +369,7 @@ impl Gateway {
websocket_send: shared_websocket_send.clone(),
websocket_receive,
kill_send: kill_send.clone(),
store: store.clone(),
};
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
@ -348,6 +383,7 @@ impl Gateway {
websocket_send: shared_websocket_send.clone(),
handle,
kill_send: kill_send.clone(),
store,
})
}
@ -379,6 +415,7 @@ impl Gateway {
/// Deserializes and updates a dispatched event, when we already know its type;
/// (Called for every event in handle_message)
#[allow(dead_code)] // TODO: Remove this allow annotation
async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>(
data: &'a str,
event: &mut GatewayEvent<T>,
@ -431,17 +468,25 @@ impl Gateway {
trace!("Gateway: Received {event_name}");
macro_rules! handle {
($($name:literal => $($path:ident).+),*) => {
($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => {
match event_name.as_str() {
$($name => {
let event = &mut self.events.lock().await.$($path).+;
let result =
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if let Err(err) = result {
warn!("Failed to parse gateway event {event_name} ({err})");
match serde_json::from_str(gateway_payload.event_data.unwrap().get()) {
Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"),
Ok(message) => {
$(
let message: $message_type = message;
if let Some(to_update) = self.store.lock().await.get(&message.id()) {
if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<$update_type>, watch::Receiver<$update_type>)>() {
tx.send_modify(|object| message.update(object));
} else {
warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id())
}
}
)?
event.notify(message).await;
}
}
},)*
"RESUMED" => (),
@ -482,7 +527,7 @@ impl Gateway {
"AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete,
"AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution,
"CHANNEL_CREATE" => channel.create,
"CHANNEL_UPDATE" => channel.update,
"CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel,
"CHANNEL_UNREAD_UPDATE" => channel.unread_update,
"CHANNEL_DELETE" => channel.delete,
"CHANNEL_PINS_UPDATE" => channel.pins_update,

View File

@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
use serde_aux::prelude::deserialize_string_from_number;
use serde_repr::{Deserialize_repr, Serialize_repr};
use crate::gateway::Updateable;
use crate::types::{
entities::{GuildMember, User},
utils::Snowflake,
@ -65,6 +66,12 @@ pub struct Channel {
pub video_quality_mode: Option<i32>,
}
impl Updateable for Channel {
fn id(&self) -> Snowflake {
self.id
}
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Tag {
pub id: Snowflake,

View File

@ -3,6 +3,8 @@ use crate::types::{entities::Channel, Snowflake};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use super::UpdateMessage;
#[derive(Debug, Default, Deserialize, Serialize)]
/// See https://discord.com/developers/docs/topics/gateway-events#channel-pins-update
pub struct ChannelPinsUpdate {
@ -31,6 +33,15 @@ pub struct ChannelUpdate {
impl WebSocketEvent for ChannelUpdate {}
impl UpdateMessage<Channel> for ChannelUpdate {
fn update(&self, object_to_update: &mut Channel) {
*object_to_update = self.channel.clone();
}
fn id(&self) -> Snowflake {
self.channel.id
}
}
#[derive(Debug, Default, Deserialize, Serialize, Clone)]
/// Officially undocumented.
/// Sends updates to client about a new message with its id

View File

@ -26,6 +26,10 @@ pub use user::*;
pub use voice::*;
pub use webhooks::*;
use crate::gateway::Updateable;
use super::Snowflake;
mod application;
mod auto_moderation;
mod call;
@ -95,3 +99,23 @@ pub struct GatewayReceivePayload<'a> {
}
impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {}
/// An [`UpdateMessage<T>`] represents a received Gateway Message which contains updated
/// information for an [`Updateable`] of Type T.
/// # Example:
/// ```rs
/// impl UpdateMessage<Channel> for ChannelUpdate {
/// fn update(...) {...}
/// fn id(...) {...}
/// }
/// ```
/// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information
/// about a [`Channel`]. The update method describes how this new information will be turned into
/// a [`Channel`] object.
pub(crate) trait UpdateMessage<T>: Clone
where
T: Updateable,
{
fn update(&self, object_to_update: &mut T);
fn id(&self) -> Snowflake;
}

View File

@ -28,10 +28,11 @@ async fn delete_channel() {
#[tokio::test]
async fn modify_channel() {
const CHANNEL_NAME: &str = "beepboop";
let mut bundle = common::setup().await;
let channel = &mut bundle.channel;
let modify_data: types::ChannelModifySchema = types::ChannelModifySchema {
name: Some("beepboop".to_string()),
name: Some(CHANNEL_NAME.to_string()),
channel_type: None,
topic: None,
icon: None,
@ -49,10 +50,10 @@ async fn modify_channel() {
default_thread_rate_limit_per_user: None,
video_quality_mode: None,
};
Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
let modified_channel = Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
.await
.unwrap();
assert_eq!(channel.name, Some("beepboop".to_string()));
assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string()));
let permission_override = PermissionFlags::from_vec(Vec::from([
PermissionFlags::MANAGE_CHANNELS,

View File

@ -1,6 +1,7 @@
mod common;
use chorus::gateway::*;
use chorus::types;
use chorus::types::{self, Channel};
#[tokio::test]
/// Tests establishing a connection (hello and heartbeats) on the local gateway;
@ -22,3 +23,26 @@ async fn test_gateway_authenticate() {
gateway.send_identify(identify).await;
}
#[tokio::test]
async fn test_self_updating_structs() {
let mut bundle = common::setup().await;
let gateway = Gateway::new(bundle.urls.wss).await.unwrap();
let mut identify = types::GatewayIdentifyPayload::common();
identify.token = bundle.user.token.clone();
gateway.send_identify(identify).await;
let channel_receiver = gateway.observe(bundle.channel.clone()).await;
let received_channel = channel_receiver.borrow();
assert_eq!(*received_channel, bundle.channel);
drop(received_channel);
let channel = &mut bundle.channel;
let modify_data = types::ChannelModifySchema {
name: Some("beepboop".to_string()),
..Default::default()
};
Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
.await
.unwrap();
let received_channel = channel_receiver.borrow();
assert_eq!(received_channel.name.as_ref().unwrap(), "beepboop");
}