Auto updating structs (#163)
* Gateway fields don't need to be pub * Add store to Gateway * Add UpdateMessage trait * Proof of concept: impl UpdateMessage for Channel * Start working on auto updating structs * Send entity updates over watch channel * Add id to UpdateMessage * Create trait Updateable * Add documentation * add gateway test * Complete test * Impl UpdateMessage::update() for ChannelUpdate * Impl UpdateMessage::update() for ChannelUpdate Co-authored by: SpecificProtagonist <specificprotagonist@posteo.org> * make channel::modify no longer mutate Channel * change modify call * remove unused imports * Allow dead code with TODO to remove it * fix channel::modify test * Update src/gateway.rs Co-authored-by: SpecificProtagonist <vincentjunge@posteo.net> --------- Co-authored-by: SpecificProtagonist <vincentjunge@posteo.net>
This commit is contained in:
parent
776efce4dc
commit
ccf44a5375
|
@ -64,11 +64,11 @@ impl Channel {
|
||||||
///
|
///
|
||||||
/// A `Result` that contains a `Channel` object if the request was successful, or an `ChorusLibError` if an error occurred during the request.
|
/// A `Result` that contains a `Channel` object if the request was successful, or an `ChorusLibError` if an error occurred during the request.
|
||||||
pub async fn modify(
|
pub async fn modify(
|
||||||
&mut self,
|
&self,
|
||||||
modify_data: ChannelModifySchema,
|
modify_data: ChannelModifySchema,
|
||||||
channel_id: Snowflake,
|
channel_id: Snowflake,
|
||||||
user: &mut UserMeta,
|
user: &mut UserMeta,
|
||||||
) -> ChorusResult<()> {
|
) -> ChorusResult<Channel> {
|
||||||
let chorus_request = ChorusRequest {
|
let chorus_request = ChorusRequest {
|
||||||
request: Client::new()
|
request: Client::new()
|
||||||
.patch(format!(
|
.patch(format!(
|
||||||
|
@ -80,9 +80,7 @@ impl Channel {
|
||||||
.body(to_string(&modify_data).unwrap()),
|
.body(to_string(&modify_data).unwrap()),
|
||||||
limit_type: LimitType::Channel(channel_id),
|
limit_type: LimitType::Channel(channel_id),
|
||||||
};
|
};
|
||||||
let new_channel = chorus_request.deserialize_response::<Channel>(user).await?;
|
chorus_request.deserialize_response::<Channel>(user).await
|
||||||
let _ = std::mem::replace(self, new_channel);
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn messages(
|
pub async fn messages(
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
use crate::errors::GatewayError;
|
use crate::errors::GatewayError;
|
||||||
use crate::gateway::events::Events;
|
use crate::gateway::events::Events;
|
||||||
use crate::types;
|
use crate::types::{self, Channel, ChannelUpdate, Snowflake};
|
||||||
use crate::types::WebSocketEvent;
|
use crate::types::{UpdateMessage, WebSocketEvent};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use std::any::Any;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Debug;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::sync::watch;
|
||||||
use tokio::time::sleep_until;
|
use tokio::time::sleep_until;
|
||||||
|
|
||||||
use futures_util::stream::SplitSink;
|
use futures_util::stream::SplitSink;
|
||||||
|
@ -163,6 +167,12 @@ pub struct GatewayHandle {
|
||||||
pub handle: JoinHandle<()>,
|
pub handle: JoinHandle<()>,
|
||||||
/// Tells gateway tasks to close
|
/// Tells gateway tasks to close
|
||||||
kill_send: tokio::sync::broadcast::Sender<()>,
|
kill_send: tokio::sync::broadcast::Sender<()>,
|
||||||
|
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
|
||||||
|
pub trait Updateable: 'static + Send + Sync {
|
||||||
|
fn id(&self) -> Snowflake;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayHandle {
|
impl GatewayHandle {
|
||||||
|
@ -186,6 +196,27 @@ impl GatewayHandle {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn observe<T: Updateable>(&self, object: T) -> watch::Receiver<T> {
|
||||||
|
let mut store = self.store.lock().await;
|
||||||
|
if let Some(channel) = store.get(&object.id()) {
|
||||||
|
let (_, rx) = channel
|
||||||
|
.downcast_ref::<(watch::Sender<T>, watch::Receiver<T>)>()
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
panic!(
|
||||||
|
"Snowflake {} already exists in the store, but it is not of type T.",
|
||||||
|
object.id()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
rx.clone()
|
||||||
|
} else {
|
||||||
|
let id = object.id();
|
||||||
|
let channel = watch::channel(object);
|
||||||
|
let receiver = channel.1.clone();
|
||||||
|
store.insert(id, Box::new(channel));
|
||||||
|
receiver
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Sends an identify event to the gateway
|
/// Sends an identify event to the gateway
|
||||||
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
|
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
@ -263,9 +294,9 @@ impl GatewayHandle {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Gateway {
|
pub struct Gateway {
|
||||||
pub events: Arc<Mutex<Events>>,
|
events: Arc<Mutex<Events>>,
|
||||||
heartbeat_handler: HeartbeatHandler,
|
heartbeat_handler: HeartbeatHandler,
|
||||||
pub websocket_send: Arc<
|
websocket_send: Arc<
|
||||||
Mutex<
|
Mutex<
|
||||||
SplitSink<
|
SplitSink<
|
||||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
@ -273,8 +304,9 @@ pub struct Gateway {
|
||||||
>,
|
>,
|
||||||
>,
|
>,
|
||||||
>,
|
>,
|
||||||
pub websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||||
kill_send: tokio::sync::broadcast::Sender<()>,
|
kill_send: tokio::sync::broadcast::Sender<()>,
|
||||||
|
store: Arc<Mutex<HashMap<Snowflake, Box<dyn Send + Any>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Gateway {
|
impl Gateway {
|
||||||
|
@ -325,6 +357,8 @@ impl Gateway {
|
||||||
let events = Events::default();
|
let events = Events::default();
|
||||||
let shared_events = Arc::new(Mutex::new(events));
|
let shared_events = Arc::new(Mutex::new(events));
|
||||||
|
|
||||||
|
let store = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
let mut gateway = Gateway {
|
let mut gateway = Gateway {
|
||||||
events: shared_events.clone(),
|
events: shared_events.clone(),
|
||||||
heartbeat_handler: HeartbeatHandler::new(
|
heartbeat_handler: HeartbeatHandler::new(
|
||||||
|
@ -335,6 +369,7 @@ impl Gateway {
|
||||||
websocket_send: shared_websocket_send.clone(),
|
websocket_send: shared_websocket_send.clone(),
|
||||||
websocket_receive,
|
websocket_receive,
|
||||||
kill_send: kill_send.clone(),
|
kill_send: kill_send.clone(),
|
||||||
|
store: store.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
|
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
|
||||||
|
@ -348,6 +383,7 @@ impl Gateway {
|
||||||
websocket_send: shared_websocket_send.clone(),
|
websocket_send: shared_websocket_send.clone(),
|
||||||
handle,
|
handle,
|
||||||
kill_send: kill_send.clone(),
|
kill_send: kill_send.clone(),
|
||||||
|
store,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,6 +415,7 @@ impl Gateway {
|
||||||
|
|
||||||
/// Deserializes and updates a dispatched event, when we already know its type;
|
/// Deserializes and updates a dispatched event, when we already know its type;
|
||||||
/// (Called for every event in handle_message)
|
/// (Called for every event in handle_message)
|
||||||
|
#[allow(dead_code)] // TODO: Remove this allow annotation
|
||||||
async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>(
|
async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>(
|
||||||
data: &'a str,
|
data: &'a str,
|
||||||
event: &mut GatewayEvent<T>,
|
event: &mut GatewayEvent<T>,
|
||||||
|
@ -431,17 +468,25 @@ impl Gateway {
|
||||||
trace!("Gateway: Received {event_name}");
|
trace!("Gateway: Received {event_name}");
|
||||||
|
|
||||||
macro_rules! handle {
|
macro_rules! handle {
|
||||||
($($name:literal => $($path:ident).+),*) => {
|
($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => {
|
||||||
match event_name.as_str() {
|
match event_name.as_str() {
|
||||||
$($name => {
|
$($name => {
|
||||||
let event = &mut self.events.lock().await.$($path).+;
|
let event = &mut self.events.lock().await.$($path).+;
|
||||||
|
match serde_json::from_str(gateway_payload.event_data.unwrap().get()) {
|
||||||
let result =
|
Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"),
|
||||||
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
|
Ok(message) => {
|
||||||
.await;
|
$(
|
||||||
|
let message: $message_type = message;
|
||||||
if let Err(err) = result {
|
if let Some(to_update) = self.store.lock().await.get(&message.id()) {
|
||||||
warn!("Failed to parse gateway event {event_name} ({err})");
|
if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<$update_type>, watch::Receiver<$update_type>)>() {
|
||||||
|
tx.send_modify(|object| message.update(object));
|
||||||
|
} else {
|
||||||
|
warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)?
|
||||||
|
event.notify(message).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},)*
|
},)*
|
||||||
"RESUMED" => (),
|
"RESUMED" => (),
|
||||||
|
@ -482,7 +527,7 @@ impl Gateway {
|
||||||
"AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete,
|
"AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete,
|
||||||
"AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution,
|
"AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution,
|
||||||
"CHANNEL_CREATE" => channel.create,
|
"CHANNEL_CREATE" => channel.create,
|
||||||
"CHANNEL_UPDATE" => channel.update,
|
"CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel,
|
||||||
"CHANNEL_UNREAD_UPDATE" => channel.unread_update,
|
"CHANNEL_UNREAD_UPDATE" => channel.unread_update,
|
||||||
"CHANNEL_DELETE" => channel.delete,
|
"CHANNEL_DELETE" => channel.delete,
|
||||||
"CHANNEL_PINS_UPDATE" => channel.pins_update,
|
"CHANNEL_PINS_UPDATE" => channel.pins_update,
|
||||||
|
|
|
@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use serde_aux::prelude::deserialize_string_from_number;
|
use serde_aux::prelude::deserialize_string_from_number;
|
||||||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||||
|
|
||||||
|
use crate::gateway::Updateable;
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
entities::{GuildMember, User},
|
entities::{GuildMember, User},
|
||||||
utils::Snowflake,
|
utils::Snowflake,
|
||||||
|
@ -65,6 +66,12 @@ pub struct Channel {
|
||||||
pub video_quality_mode: Option<i32>,
|
pub video_quality_mode: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Updateable for Channel {
|
||||||
|
fn id(&self) -> Snowflake {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||||
pub struct Tag {
|
pub struct Tag {
|
||||||
pub id: Snowflake,
|
pub id: Snowflake,
|
||||||
|
|
|
@ -3,6 +3,8 @@ use crate::types::{entities::Channel, Snowflake};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::UpdateMessage;
|
||||||
|
|
||||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||||
/// See https://discord.com/developers/docs/topics/gateway-events#channel-pins-update
|
/// See https://discord.com/developers/docs/topics/gateway-events#channel-pins-update
|
||||||
pub struct ChannelPinsUpdate {
|
pub struct ChannelPinsUpdate {
|
||||||
|
@ -31,6 +33,15 @@ pub struct ChannelUpdate {
|
||||||
|
|
||||||
impl WebSocketEvent for ChannelUpdate {}
|
impl WebSocketEvent for ChannelUpdate {}
|
||||||
|
|
||||||
|
impl UpdateMessage<Channel> for ChannelUpdate {
|
||||||
|
fn update(&self, object_to_update: &mut Channel) {
|
||||||
|
*object_to_update = self.channel.clone();
|
||||||
|
}
|
||||||
|
fn id(&self) -> Snowflake {
|
||||||
|
self.channel.id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Default, Deserialize, Serialize, Clone)]
|
||||||
/// Officially undocumented.
|
/// Officially undocumented.
|
||||||
/// Sends updates to client about a new message with its id
|
/// Sends updates to client about a new message with its id
|
||||||
|
|
|
@ -26,6 +26,10 @@ pub use user::*;
|
||||||
pub use voice::*;
|
pub use voice::*;
|
||||||
pub use webhooks::*;
|
pub use webhooks::*;
|
||||||
|
|
||||||
|
use crate::gateway::Updateable;
|
||||||
|
|
||||||
|
use super::Snowflake;
|
||||||
|
|
||||||
mod application;
|
mod application;
|
||||||
mod auto_moderation;
|
mod auto_moderation;
|
||||||
mod call;
|
mod call;
|
||||||
|
@ -95,3 +99,23 @@ pub struct GatewayReceivePayload<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {}
|
impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {}
|
||||||
|
|
||||||
|
/// An [`UpdateMessage<T>`] represents a received Gateway Message which contains updated
|
||||||
|
/// information for an [`Updateable`] of Type T.
|
||||||
|
/// # Example:
|
||||||
|
/// ```rs
|
||||||
|
/// impl UpdateMessage<Channel> for ChannelUpdate {
|
||||||
|
/// fn update(...) {...}
|
||||||
|
/// fn id(...) {...}
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
/// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information
|
||||||
|
/// about a [`Channel`]. The update method describes how this new information will be turned into
|
||||||
|
/// a [`Channel`] object.
|
||||||
|
pub(crate) trait UpdateMessage<T>: Clone
|
||||||
|
where
|
||||||
|
T: Updateable,
|
||||||
|
{
|
||||||
|
fn update(&self, object_to_update: &mut T);
|
||||||
|
fn id(&self) -> Snowflake;
|
||||||
|
}
|
||||||
|
|
|
@ -28,10 +28,11 @@ async fn delete_channel() {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn modify_channel() {
|
async fn modify_channel() {
|
||||||
|
const CHANNEL_NAME: &str = "beepboop";
|
||||||
let mut bundle = common::setup().await;
|
let mut bundle = common::setup().await;
|
||||||
let channel = &mut bundle.channel;
|
let channel = &mut bundle.channel;
|
||||||
let modify_data: types::ChannelModifySchema = types::ChannelModifySchema {
|
let modify_data: types::ChannelModifySchema = types::ChannelModifySchema {
|
||||||
name: Some("beepboop".to_string()),
|
name: Some(CHANNEL_NAME.to_string()),
|
||||||
channel_type: None,
|
channel_type: None,
|
||||||
topic: None,
|
topic: None,
|
||||||
icon: None,
|
icon: None,
|
||||||
|
@ -49,10 +50,10 @@ async fn modify_channel() {
|
||||||
default_thread_rate_limit_per_user: None,
|
default_thread_rate_limit_per_user: None,
|
||||||
video_quality_mode: None,
|
video_quality_mode: None,
|
||||||
};
|
};
|
||||||
Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
|
let modified_channel = Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(channel.name, Some("beepboop".to_string()));
|
assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string()));
|
||||||
|
|
||||||
let permission_override = PermissionFlags::from_vec(Vec::from([
|
let permission_override = PermissionFlags::from_vec(Vec::from([
|
||||||
PermissionFlags::MANAGE_CHANNELS,
|
PermissionFlags::MANAGE_CHANNELS,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use chorus::gateway::*;
|
use chorus::gateway::*;
|
||||||
use chorus::types;
|
use chorus::types::{self, Channel};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
/// Tests establishing a connection (hello and heartbeats) on the local gateway;
|
/// Tests establishing a connection (hello and heartbeats) on the local gateway;
|
||||||
|
@ -22,3 +23,26 @@ async fn test_gateway_authenticate() {
|
||||||
|
|
||||||
gateway.send_identify(identify).await;
|
gateway.send_identify(identify).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_self_updating_structs() {
|
||||||
|
let mut bundle = common::setup().await;
|
||||||
|
let gateway = Gateway::new(bundle.urls.wss).await.unwrap();
|
||||||
|
let mut identify = types::GatewayIdentifyPayload::common();
|
||||||
|
identify.token = bundle.user.token.clone();
|
||||||
|
gateway.send_identify(identify).await;
|
||||||
|
let channel_receiver = gateway.observe(bundle.channel.clone()).await;
|
||||||
|
let received_channel = channel_receiver.borrow();
|
||||||
|
assert_eq!(*received_channel, bundle.channel);
|
||||||
|
drop(received_channel);
|
||||||
|
let channel = &mut bundle.channel;
|
||||||
|
let modify_data = types::ChannelModifySchema {
|
||||||
|
name: Some("beepboop".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
Channel::modify(channel, modify_data, channel.id, &mut bundle.user)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let received_channel = channel_receiver.borrow();
|
||||||
|
assert_eq!(received_channel.name.as_ref().unwrap(), "beepboop");
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue