Compare commits

...

16 Commits

Author SHA1 Message Date
xystrive ea62c35b3e
Merge 317dbe1ed1 into b0667a33fb 2024-07-13 18:41:15 +02:00
Flori b0667a33fb
Replace `Observer` and `GatewayEvent` with `pubserve` crate (#524)
This PR replaces our internal `Observer` trait with a more generic
version, as offered by our `pubserve` crate.

This also has the added benefit of removing the... creative logic used
in `Observer::unsubscribe()` which did not check the memory the internal
`Arc` points to, but rather some weird

```rs
self.observers
            .retain(|obs| format!("{:?}", obs) != to_remove);
```

which was written ages ago by a version of me that was still at the very
beginning of learning Rust
2024-07-13 18:03:02 +02:00
bitfl0wer ebcb6b65e4
Fix voice, voice_udp features 2024-07-13 14:53:04 +02:00
bitfl0wer 7554f90187
Replace `Observer` and `GatewayEvent` with `pubserve` crate 2024-07-13 14:28:22 +02:00
xystrive 317dbe1ed1 fix: according to changes made to `ChorusUser` object field 4920c91e52 2024-07-05 17:15:24 +01:00
xystrive ee19cb762f fix: concatenate `Instance`'s api url with endpoint 2024-07-05 17:06:16 +01:00
xystrive c3c506bc1b fix: typo in `ChorusError::MfaRequired` `Display` message 2024-07-05 17:04:05 +01:00
xystrive 8f995a9f63 refactor: change `ChorusUser` new method calls according to changes done 2024-07-04 19:09:28 +01:00
xystrive 6ef33c01c7 refactor: change `User` object assignements according to changes done in `ChorusUser` 2024-07-04 19:08:38 +01:00
xystrive a3aa4625f9 feat: add handling for MFA Required response errors 2024-07-04 18:37:33 +01:00
xystrive 2c4e069269 refactor: remove mfa_token argument from `ChorusRequest` `new` method 2024-07-04 18:34:50 +01:00
xystrive 85a2878f9a feat: add `complete_mfa_challenge` method to ChorusUser 2024-07-04 18:28:23 +01:00
xystrive 4920c91e52 refactor: wrap object field type with `Option` and change `shell()` method accordingly
This change allows to exactly know when a ChorusUser is authenticated or not
2024-07-04 18:27:37 +01:00
xystrive 538781502c feat: add MFARequired error to ChorusError 2024-07-04 18:03:04 +01:00
xystrive c3eec16c29 refactor: remove old Response object involved in MFA implementation 2024-07-04 18:02:10 +01:00
xystrive 2dcee5cff0 feat: add Request/Response and other necessary objects for MFA implementation 2024-07-04 17:58:46 +01:00
36 changed files with 312 additions and 252 deletions

10
Cargo.lock generated
View File

@ -233,6 +233,7 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"poem", "poem",
"pubserve",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -1594,6 +1595,15 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "pubserve"
version = "1.1.0-alpha.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1781b2a51798c98a381e839e61bc5ce6426bd89bb9c3f9142de2086a80591cd"
dependencies = [
"async-trait",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.36" version = "1.0.36"

View File

@ -66,6 +66,7 @@ crypto_secretbox = { version = "0.1.1", optional = true }
rand = "0.8.5" rand = "0.8.5"
flate2 = { version = "1.0.30", optional = true } flate2 = { version = "1.0.30", optional = true }
webpki-roots = "0.26.3" webpki-roots = "0.26.3"
pubserve = { version = "1.1.0-alpha.1", features = ["async", "send"] }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
rustls = "0.21.10" rustls = "0.21.10"

View File

@ -17,9 +17,9 @@ use async_trait::async_trait;
use chorus::gateway::{Gateway, GatewayOptions}; use chorus::gateway::{Gateway, GatewayOptions};
use chorus::{ use chorus::{
self, self,
gateway::Observer,
types::{GatewayIdentifyPayload, GatewayReady}, types::{GatewayIdentifyPayload, GatewayReady},
}; };
use pubserve::Subscriber;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio::{self}; use tokio::{self};
@ -38,7 +38,7 @@ pub struct ExampleObserver {}
// The Observer trait can be implemented for a struct for a given websocketevent to handle observing it // The Observer trait can be implemented for a struct for a given websocketevent to handle observing it
// One struct can be an observer of multiple websocketevents, if needed // One struct can be an observer of multiple websocketevents, if needed
#[async_trait] #[async_trait]
impl Observer<GatewayReady> for ExampleObserver { impl Subscriber<GatewayReady> for ExampleObserver {
// After we subscribe to an event this function is called every time we receive it // After we subscribe to an event this function is called every time we receive it
async fn update(&self, _data: &GatewayReady) { async fn update(&self, _data: &GatewayReady) {
println!("Observed Ready!"); println!("Observed Ready!");
@ -56,7 +56,9 @@ async fn main() {
let options = GatewayOptions::default(); let options = GatewayOptions::default();
// Initiate the gateway connection // Initiate the gateway connection
let gateway = Gateway::spawn(gateway_websocket_url, options).await.unwrap(); let gateway = Gateway::spawn(gateway_websocket_url, options)
.await
.unwrap();
// Create an instance of our observer // Create an instance of our observer
let observer = ExampleObserver {}; let observer = ExampleObserver {};

View File

@ -24,5 +24,5 @@ async fn main() {
.await .await
.expect("An error occurred during the login process"); .expect("An error occurred during the login process");
dbg!(user.belongs_to); dbg!(user.belongs_to);
dbg!(&user.object.read().unwrap().username); dbg!(&user.object.unwrap().as_ref().read().unwrap().username);
} }

View File

@ -40,7 +40,7 @@ impl Instance {
user.settings = login_result.settings; user.settings = login_result.settings;
let object = User::get(&mut user, None).await?; let object = User::get(&mut user, None).await?;
*user.object.write().unwrap() = object; user.object = Some(Arc::new(RwLock::new(object)));
let mut identify = GatewayIdentifyPayload::common(); let mut identify = GatewayIdentifyPayload::common();
identify.token = user.token(); identify.token = user.token();

View File

@ -29,7 +29,7 @@ impl Instance {
let object = User::get(&mut user, None).await?; let object = User::get(&mut user, None).await?;
let settings = User::get_settings(&mut user).await?; let settings = User::get_settings(&mut user).await?;
*user.object.write().unwrap() = object; user.object = Some(Arc::new(RwLock::new(object)));
*user.settings.write().unwrap() = settings; *user.settings.write().unwrap() = settings;
let mut identify = GatewayIdentifyPayload::common(); let mut identify = GatewayIdentifyPayload::common();

View File

@ -49,7 +49,7 @@ impl Instance {
let object = User::get(&mut user, None).await?; let object = User::get(&mut user, None).await?;
let settings = User::get_settings(&mut user).await?; let settings = User::get_settings(&mut user).await?;
*user.object.write().unwrap() = object; user.object = Some(Arc::new(RwLock::new(object)));
*user.settings.write().unwrap() = settings; *user.settings.write().unwrap() = settings;
let mut identify = GatewayIdentifyPayload::common(); let mut identify = GatewayIdentifyPayload::common();

View File

@ -30,7 +30,6 @@ impl Channel {
), ),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -61,7 +60,6 @@ impl Channel {
&url, &url,
None, None,
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Channel(self.id), LimitType::Channel(self.id),
); );
@ -101,7 +99,6 @@ impl Channel {
&url, &url,
Some(to_string(&modify_data).unwrap()), Some(to_string(&modify_data).unwrap()),
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -134,7 +131,6 @@ impl Channel {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
Default::default(), Default::default(),
); );
@ -196,7 +192,6 @@ impl Channel {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.id), LimitType::Channel(self.id),
); );
@ -225,7 +220,6 @@ impl Channel {
&url, &url,
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );

View File

@ -151,7 +151,6 @@ impl Message {
.as_str(), .as_str(),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -182,7 +181,6 @@ impl Message {
.as_str(), .as_str(),
None, None,
audit_log_reason, audit_log_reason,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -209,7 +207,6 @@ impl Message {
.as_str(), .as_str(),
None, None,
audit_log_reason, audit_log_reason,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -258,7 +255,6 @@ impl Message {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -292,7 +288,6 @@ impl Message {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -321,7 +316,6 @@ impl Message {
.as_str(), .as_str(),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -348,7 +342,6 @@ impl Message {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -382,7 +375,6 @@ impl Message {
&url, &url,
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -409,7 +401,6 @@ impl Message {
&url, &url,
None, None,
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -447,7 +438,6 @@ impl Message {
.as_str(), .as_str(),
Some(to_string(&messages).unwrap()), Some(to_string(&messages).unwrap()),
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );
@ -472,7 +462,6 @@ impl Message {
.as_str(), .as_str(),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );

View File

@ -83,7 +83,6 @@ impl types::Channel {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(channel_id), LimitType::Channel(channel_id),
); );

View File

@ -36,7 +36,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );
@ -65,7 +64,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );
@ -96,7 +94,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );
@ -130,7 +127,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );
@ -159,7 +155,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );
@ -196,7 +191,6 @@ impl ReactionMeta {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Channel(self.channel_id), LimitType::Channel(self.channel_id),
); );

View File

@ -220,7 +220,6 @@ impl Guild {
.as_str(), .as_str(),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -246,7 +245,6 @@ impl Guild {
.as_str(), .as_str(),
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -279,7 +277,6 @@ impl Guild {
.as_str(), .as_str(),
None, None,
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -309,7 +306,6 @@ impl Guild {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -336,7 +332,6 @@ impl Guild {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -362,7 +357,6 @@ impl Guild {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -393,7 +387,6 @@ impl Guild {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -426,7 +419,6 @@ impl Guild {
&url, &url,
None, None,
None, None,
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -456,7 +448,6 @@ impl Guild {
.as_str(), .as_str(),
Some(to_string(&schema).unwrap()), Some(to_string(&schema).unwrap()),
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );
@ -487,7 +478,6 @@ impl Guild {
&url, &url,
None, None,
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );

View File

@ -188,7 +188,6 @@ impl types::RoleObject {
&url, &url,
None, None,
audit_log_reason.as_deref(), audit_log_reason.as_deref(),
None,
Some(user), Some(user),
LimitType::Guild(guild_id), LimitType::Guild(guild_id),
); );

View File

@ -5,7 +5,7 @@
//! Contains all the errors that can be returned by the library. //! Contains all the errors that can be returned by the library.
use custom_error::custom_error; use custom_error::custom_error;
use crate::types::WebSocketEvent; use crate::types::{MfaRequiredSchema, WebSocketEvent};
use chorus_macros::WebSocketEvent; use chorus_macros::WebSocketEvent;
custom_error! { custom_error! {
@ -46,7 +46,9 @@ custom_error! {
/// Malformed or unexpected response. /// Malformed or unexpected response.
InvalidResponse{error: String} = "The response is malformed and cannot be processed. Error: {error}", InvalidResponse{error: String} = "The response is malformed and cannot be processed. Error: {error}",
/// Invalid, insufficient or too many arguments provided. /// Invalid, insufficient or too many arguments provided.
InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}" InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}",
/// The request requires MFA verification
MfaRequired {error: MfaRequiredSchema} = "Mfa verification is required to perform this action"
} }
impl From<reqwest::Error> for ChorusError { impl From<reqwest::Error> for ChorusError {

View File

@ -2,6 +2,8 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use pubserve::Publisher;
use super::*; use super::*;
use crate::types; use crate::types;
@ -23,144 +25,144 @@ pub struct Events {
pub call: Call, pub call: Call,
pub voice: Voice, pub voice: Voice,
pub webhooks: Webhooks, pub webhooks: Webhooks,
pub gateway_identify_payload: GatewayEvent<types::GatewayIdentifyPayload>, pub gateway_identify_payload: Publisher<types::GatewayIdentifyPayload>,
pub gateway_resume: GatewayEvent<types::GatewayResume>, pub gateway_resume: Publisher<types::GatewayResume>,
pub error: GatewayEvent<GatewayError>, pub error: Publisher<GatewayError>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Application { pub struct Application {
pub command_permissions_update: GatewayEvent<types::ApplicationCommandPermissionsUpdate>, pub command_permissions_update: Publisher<types::ApplicationCommandPermissionsUpdate>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct AutoModeration { pub struct AutoModeration {
pub rule_create: GatewayEvent<types::AutoModerationRuleCreate>, pub rule_create: Publisher<types::AutoModerationRuleCreate>,
pub rule_update: GatewayEvent<types::AutoModerationRuleUpdate>, pub rule_update: Publisher<types::AutoModerationRuleUpdate>,
pub rule_delete: GatewayEvent<types::AutoModerationRuleDelete>, pub rule_delete: Publisher<types::AutoModerationRuleDelete>,
pub action_execution: GatewayEvent<types::AutoModerationActionExecution>, pub action_execution: Publisher<types::AutoModerationActionExecution>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Session { pub struct Session {
pub ready: GatewayEvent<types::GatewayReady>, pub ready: Publisher<types::GatewayReady>,
pub ready_supplemental: GatewayEvent<types::GatewayReadySupplemental>, pub ready_supplemental: Publisher<types::GatewayReadySupplemental>,
pub replace: GatewayEvent<types::SessionsReplace>, pub replace: Publisher<types::SessionsReplace>,
pub reconnect: GatewayEvent<types::GatewayReconnect>, pub reconnect: Publisher<types::GatewayReconnect>,
pub invalid: GatewayEvent<types::GatewayInvalidSession>, pub invalid: Publisher<types::GatewayInvalidSession>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct StageInstance { pub struct StageInstance {
pub create: GatewayEvent<types::StageInstanceCreate>, pub create: Publisher<types::StageInstanceCreate>,
pub update: GatewayEvent<types::StageInstanceUpdate>, pub update: Publisher<types::StageInstanceUpdate>,
pub delete: GatewayEvent<types::StageInstanceDelete>, pub delete: Publisher<types::StageInstanceDelete>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Message { pub struct Message {
pub create: GatewayEvent<types::MessageCreate>, pub create: Publisher<types::MessageCreate>,
pub update: GatewayEvent<types::MessageUpdate>, pub update: Publisher<types::MessageUpdate>,
pub delete: GatewayEvent<types::MessageDelete>, pub delete: Publisher<types::MessageDelete>,
pub delete_bulk: GatewayEvent<types::MessageDeleteBulk>, pub delete_bulk: Publisher<types::MessageDeleteBulk>,
pub reaction_add: GatewayEvent<types::MessageReactionAdd>, pub reaction_add: Publisher<types::MessageReactionAdd>,
pub reaction_remove: GatewayEvent<types::MessageReactionRemove>, pub reaction_remove: Publisher<types::MessageReactionRemove>,
pub reaction_remove_all: GatewayEvent<types::MessageReactionRemoveAll>, pub reaction_remove_all: Publisher<types::MessageReactionRemoveAll>,
pub reaction_remove_emoji: GatewayEvent<types::MessageReactionRemoveEmoji>, pub reaction_remove_emoji: Publisher<types::MessageReactionRemoveEmoji>,
pub ack: GatewayEvent<types::MessageACK>, pub ack: Publisher<types::MessageACK>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct User { pub struct User {
pub update: GatewayEvent<types::UserUpdate>, pub update: Publisher<types::UserUpdate>,
pub guild_settings_update: GatewayEvent<types::UserGuildSettingsUpdate>, pub guild_settings_update: Publisher<types::UserGuildSettingsUpdate>,
pub presence_update: GatewayEvent<types::PresenceUpdate>, pub presence_update: Publisher<types::PresenceUpdate>,
pub typing_start: GatewayEvent<types::TypingStartEvent>, pub typing_start: Publisher<types::TypingStartEvent>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Relationship { pub struct Relationship {
pub add: GatewayEvent<types::RelationshipAdd>, pub add: Publisher<types::RelationshipAdd>,
pub remove: GatewayEvent<types::RelationshipRemove>, pub remove: Publisher<types::RelationshipRemove>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Channel { pub struct Channel {
pub create: GatewayEvent<types::ChannelCreate>, pub create: Publisher<types::ChannelCreate>,
pub update: GatewayEvent<types::ChannelUpdate>, pub update: Publisher<types::ChannelUpdate>,
pub unread_update: GatewayEvent<types::ChannelUnreadUpdate>, pub unread_update: Publisher<types::ChannelUnreadUpdate>,
pub delete: GatewayEvent<types::ChannelDelete>, pub delete: Publisher<types::ChannelDelete>,
pub pins_update: GatewayEvent<types::ChannelPinsUpdate>, pub pins_update: Publisher<types::ChannelPinsUpdate>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Thread { pub struct Thread {
pub create: GatewayEvent<types::ThreadCreate>, pub create: Publisher<types::ThreadCreate>,
pub update: GatewayEvent<types::ThreadUpdate>, pub update: Publisher<types::ThreadUpdate>,
pub delete: GatewayEvent<types::ThreadDelete>, pub delete: Publisher<types::ThreadDelete>,
pub list_sync: GatewayEvent<types::ThreadListSync>, pub list_sync: Publisher<types::ThreadListSync>,
pub member_update: GatewayEvent<types::ThreadMemberUpdate>, pub member_update: Publisher<types::ThreadMemberUpdate>,
pub members_update: GatewayEvent<types::ThreadMembersUpdate>, pub members_update: Publisher<types::ThreadMembersUpdate>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Guild { pub struct Guild {
pub create: GatewayEvent<types::GuildCreate>, pub create: Publisher<types::GuildCreate>,
pub update: GatewayEvent<types::GuildUpdate>, pub update: Publisher<types::GuildUpdate>,
pub delete: GatewayEvent<types::GuildDelete>, pub delete: Publisher<types::GuildDelete>,
pub audit_log_entry_create: GatewayEvent<types::GuildAuditLogEntryCreate>, pub audit_log_entry_create: Publisher<types::GuildAuditLogEntryCreate>,
pub ban_add: GatewayEvent<types::GuildBanAdd>, pub ban_add: Publisher<types::GuildBanAdd>,
pub ban_remove: GatewayEvent<types::GuildBanRemove>, pub ban_remove: Publisher<types::GuildBanRemove>,
pub emojis_update: GatewayEvent<types::GuildEmojisUpdate>, pub emojis_update: Publisher<types::GuildEmojisUpdate>,
pub stickers_update: GatewayEvent<types::GuildStickersUpdate>, pub stickers_update: Publisher<types::GuildStickersUpdate>,
pub integrations_update: GatewayEvent<types::GuildIntegrationsUpdate>, pub integrations_update: Publisher<types::GuildIntegrationsUpdate>,
pub member_add: GatewayEvent<types::GuildMemberAdd>, pub member_add: Publisher<types::GuildMemberAdd>,
pub member_remove: GatewayEvent<types::GuildMemberRemove>, pub member_remove: Publisher<types::GuildMemberRemove>,
pub member_update: GatewayEvent<types::GuildMemberUpdate>, pub member_update: Publisher<types::GuildMemberUpdate>,
pub members_chunk: GatewayEvent<types::GuildMembersChunk>, pub members_chunk: Publisher<types::GuildMembersChunk>,
pub role_create: GatewayEvent<types::GuildRoleCreate>, pub role_create: Publisher<types::GuildRoleCreate>,
pub role_update: GatewayEvent<types::GuildRoleUpdate>, pub role_update: Publisher<types::GuildRoleUpdate>,
pub role_delete: GatewayEvent<types::GuildRoleDelete>, pub role_delete: Publisher<types::GuildRoleDelete>,
pub role_scheduled_event_create: GatewayEvent<types::GuildScheduledEventCreate>, pub role_scheduled_event_create: Publisher<types::GuildScheduledEventCreate>,
pub role_scheduled_event_update: GatewayEvent<types::GuildScheduledEventUpdate>, pub role_scheduled_event_update: Publisher<types::GuildScheduledEventUpdate>,
pub role_scheduled_event_delete: GatewayEvent<types::GuildScheduledEventDelete>, pub role_scheduled_event_delete: Publisher<types::GuildScheduledEventDelete>,
pub role_scheduled_event_user_add: GatewayEvent<types::GuildScheduledEventUserAdd>, pub role_scheduled_event_user_add: Publisher<types::GuildScheduledEventUserAdd>,
pub role_scheduled_event_user_remove: GatewayEvent<types::GuildScheduledEventUserRemove>, pub role_scheduled_event_user_remove: Publisher<types::GuildScheduledEventUserRemove>,
pub passive_update_v1: GatewayEvent<types::PassiveUpdateV1>, pub passive_update_v1: Publisher<types::PassiveUpdateV1>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Invite { pub struct Invite {
pub create: GatewayEvent<types::InviteCreate>, pub create: Publisher<types::InviteCreate>,
pub delete: GatewayEvent<types::InviteDelete>, pub delete: Publisher<types::InviteDelete>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Integration { pub struct Integration {
pub create: GatewayEvent<types::IntegrationCreate>, pub create: Publisher<types::IntegrationCreate>,
pub update: GatewayEvent<types::IntegrationUpdate>, pub update: Publisher<types::IntegrationUpdate>,
pub delete: GatewayEvent<types::IntegrationDelete>, pub delete: Publisher<types::IntegrationDelete>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Interaction { pub struct Interaction {
pub create: GatewayEvent<types::InteractionCreate>, pub create: Publisher<types::InteractionCreate>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Call { pub struct Call {
pub create: GatewayEvent<types::CallCreate>, pub create: Publisher<types::CallCreate>,
pub update: GatewayEvent<types::CallUpdate>, pub update: Publisher<types::CallUpdate>,
pub delete: GatewayEvent<types::CallDelete>, pub delete: Publisher<types::CallDelete>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Voice { pub struct Voice {
pub state_update: GatewayEvent<types::VoiceStateUpdate>, pub state_update: Publisher<types::VoiceStateUpdate>,
pub server_update: GatewayEvent<types::VoiceServerUpdate>, pub server_update: Publisher<types::VoiceServerUpdate>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Webhooks { pub struct Webhooks {
pub update: GatewayEvent<types::WebhooksUpdate>, pub update: Publisher<types::WebhooksUpdate>,
} }

View File

@ -7,6 +7,7 @@ use std::time::Duration;
use flate2::Decompress; use flate2::Decompress;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use log::*; use log::*;
use pubserve::Publisher;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use tokio::task; use tokio::task;
@ -197,7 +198,7 @@ impl Gateway {
#[allow(dead_code)] // TODO: Remove this allow annotation #[allow(dead_code)] // TODO: Remove this allow annotation
async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>(
data: &'a str, data: &'a str,
event: &mut GatewayEvent<T>, event: &mut Publisher<T>,
) -> Result<(), serde_json::Error> { ) -> Result<(), serde_json::Error> {
let data_deserialize_result: Result<T, serde_json::Error> = serde_json::from_str(data); let data_deserialize_result: Result<T, serde_json::Error> = serde_json::from_str(data);
@ -205,7 +206,7 @@ impl Gateway {
return Err(data_deserialize_result.err().unwrap()); return Err(data_deserialize_result.err().unwrap());
} }
event.notify(data_deserialize_result.unwrap()).await; event.publish(data_deserialize_result.unwrap()).await;
Ok(()) Ok(())
} }
@ -253,7 +254,7 @@ impl Gateway {
if let Some(error) = msg.error() { if let Some(error) = msg.error() {
warn!("GW: Received error {:?}, connection will close..", error); warn!("GW: Received error {:?}, connection will close..", error);
self.close().await; self.close().await;
self.events.lock().await.error.notify(error).await; self.events.lock().await.error.publish(error).await;
} else { } else {
warn!( warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github", "Message unrecognised: {:?}, please open an issue on the chorus github",
@ -292,7 +293,7 @@ impl Gateway {
let id = if message.id().is_some() { let id = if message.id().is_some() {
message.id().unwrap() message.id().unwrap()
} else { } else {
event.notify(message).await; event.publish(message).await;
return; return;
}; };
if let Some(to_update) = store.get(&id) { if let Some(to_update) = store.get(&id) {
@ -314,7 +315,7 @@ impl Gateway {
} }
} }
)? )?
event.notify(message).await; event.publish(message).await;
} }
} }
},)* },)*
@ -329,7 +330,7 @@ impl Gateway {
return; return;
} }
Ok(sessions) => { Ok(sessions) => {
self.events.lock().await.session.replace.notify( self.events.lock().await.session.replace.publish(
types::SessionsReplace {sessions} types::SessionsReplace {sessions}
).await; ).await;
} }
@ -446,7 +447,7 @@ impl Gateway {
.await .await
.session .session
.reconnect .reconnect
.notify(reconnect) .publish(reconnect)
.await; .await;
} }
GATEWAY_INVALID_SESSION => { GATEWAY_INVALID_SESSION => {
@ -471,7 +472,7 @@ impl Gateway {
.await .await
.session .session
.invalid .invalid
.notify(invalid_session) .publish(invalid_session)
.await; .await;
} }
// Starts our heartbeat // Starts our heartbeat

View File

@ -82,56 +82,3 @@ pub type ObservableObject = dyn Send + Sync + Any;
pub trait Updateable: 'static + Send + Sync { pub trait Updateable: 'static + Send + Sync {
fn id(&self) -> Snowflake; fn id(&self) -> Snowflake;
} }
/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
/// an Observable. The Observer is notified when the Observable's data changes.
/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing.
#[async_trait]
pub trait Observer<T>: Sync + Send + std::fmt::Debug {
async fn update(&self, data: &T);
}
/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
/// change in the WebSocketEvent. GatewayEvents are observable.
#[derive(Default, Debug)]
pub struct GatewayEvent<T: WebSocketEvent> {
observers: Vec<Arc<dyn Observer<T>>>,
}
impl<T: WebSocketEvent> GatewayEvent<T> {
pub fn new() -> Self {
Self {
observers: Vec::new(),
}
}
/// Returns true if the GatewayEvent is observed by at least one Observer.
pub fn is_observed(&self) -> bool {
!self.observers.is_empty()
}
/// Subscribes an Observer to the GatewayEvent.
pub fn subscribe(&mut self, observable: Arc<dyn Observer<T>>) {
self.observers.push(observable);
}
/// Unsubscribes an Observer from the GatewayEvent.
pub fn unsubscribe(&mut self, observable: &dyn Observer<T>) {
// .retain()'s closure retains only those elements of the vector, which have a different
// pointer value than observable.
// The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
// anddd there is no way to do that without using format
let to_remove = format!("{:?}", observable);
self.observers
.retain(|obs| format!("{:?}", obs) != to_remove);
}
/// Notifies the observers of the GatewayEvent.
pub(crate) async fn notify(&self, new_event_data: T) {
for observer in &self.observers {
observer.update(&new_event_data).await;
}
}
}

View File

@ -8,16 +8,18 @@ use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use chrono::Utc;
use crate::errors::ChorusResult; use crate::errors::ChorusResult;
use crate::gateway::{Gateway, GatewayHandle, GatewayOptions}; use crate::gateway::{Gateway, GatewayHandle, GatewayOptions};
use crate::ratelimiter::ChorusRequest; use crate::ratelimiter::ChorusRequest;
use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::types::subconfigs::limits::rates::RateLimits;
use crate::types::{ use crate::types::{
GeneralConfiguration, Limit, LimitType, LimitsConfiguration, Shared, User, UserSettings, GeneralConfiguration, Limit, LimitType, LimitsConfiguration, MfaTokenSchema, MfaVerifySchema, Shared, User, UserSettings, MfaToken
}; };
use crate::UrlBundle; use crate::UrlBundle;
@ -169,9 +171,10 @@ impl fmt::Display for Token {
pub struct ChorusUser { pub struct ChorusUser {
pub belongs_to: Shared<Instance>, pub belongs_to: Shared<Instance>,
pub token: String, pub token: String,
pub mfa_token: Option<MfaToken>,
pub limits: Option<HashMap<LimitType, Limit>>, pub limits: Option<HashMap<LimitType, Limit>>,
pub settings: Shared<UserSettings>, pub settings: Shared<UserSettings>,
pub object: Shared<User>, pub object: Option<Shared<User>>,
pub gateway: GatewayHandle, pub gateway: GatewayHandle,
} }
@ -202,12 +205,13 @@ impl ChorusUser {
token: String, token: String,
limits: Option<HashMap<LimitType, Limit>>, limits: Option<HashMap<LimitType, Limit>>,
settings: Shared<UserSettings>, settings: Shared<UserSettings>,
object: Shared<User>, object: Option<Shared<User>>,
gateway: GatewayHandle, gateway: GatewayHandle,
) -> ChorusUser { ) -> ChorusUser {
ChorusUser { ChorusUser {
belongs_to, belongs_to,
token, token,
mfa_token: None,
limits, limits,
settings, settings,
object, object,
@ -222,7 +226,6 @@ impl ChorusUser {
/// first. /// first.
pub(crate) async fn shell(instance: Shared<Instance>, token: String) -> ChorusUser { pub(crate) async fn shell(instance: Shared<Instance>, token: String) -> ChorusUser {
let settings = Arc::new(RwLock::new(UserSettings::default())); let settings = Arc::new(RwLock::new(UserSettings::default()));
let object = Arc::new(RwLock::new(User::default()));
let wss_url = instance.read().unwrap().urls.wss.clone(); let wss_url = instance.read().unwrap().urls.wss.clone();
// Dummy gateway object // Dummy gateway object
let gateway = Gateway::spawn(wss_url, GatewayOptions::default()) let gateway = Gateway::spawn(wss_url, GatewayOptions::default())
@ -230,6 +233,7 @@ impl ChorusUser {
.unwrap(); .unwrap();
ChorusUser { ChorusUser {
token, token,
mfa_token: None,
belongs_to: instance.clone(), belongs_to: instance.clone(),
limits: instance limits: instance
.read() .read()
@ -238,8 +242,40 @@ impl ChorusUser {
.as_ref() .as_ref()
.map(|info| info.ratelimits.clone()), .map(|info| info.ratelimits.clone()),
settings, settings,
object, object: None,
gateway, gateway,
} }
} }
/// Sends a request to complete an MFA challenge.
/// # Reference
/// See <https://docs.discord.sex/authentication#verify-mfa>
///
/// If successful, the MFA verification JWT returned is set on the current [ChorusUser] executing the
/// request.
///
/// The JWT token expires after 5 minutes.
pub async fn complete_mfa_challenge(&mut self, mfa_verify_schema: MfaVerifySchema) -> ChorusResult<()> {
let endpoint_url = self.belongs_to.read().unwrap().urls.api.clone() + "/mfa/finish";
let chorus_request = ChorusRequest {
request: Client::new()
.post(endpoint_url)
.header("Authorization", self.token())
.json(&mfa_verify_schema),
limit_type: match self.object.is_some() {
true => LimitType::Global,
false => LimitType::Ip,
},
};
let mfa_token_schema = chorus_request
.deserialize_response::<MfaTokenSchema>(self).await?;
self.mfa_token = Some(MfaToken {
token: mfa_token_schema.token,
expires_at: Utc::now() + Duration::from_secs(60 * 5),
});
Ok(())
}
} }

View File

@ -6,7 +6,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use log::{self, debug};
use reqwest::{Client, RequestBuilder, Response}; use reqwest::{Client, RequestBuilder, Response};
use serde::Deserialize; use serde::Deserialize;
use serde_json::from_str; use serde_json::from_str;
@ -14,7 +13,7 @@ use serde_json::from_str;
use crate::{ use crate::{
errors::{ChorusError, ChorusResult}, errors::{ChorusError, ChorusResult},
instance::ChorusUser, instance::ChorusUser,
types::{types::subconfigs::limits::rates::RateLimits, Limit, LimitType, LimitsConfiguration}, types::{types::subconfigs::limits::rates::RateLimits, Limit, LimitType, LimitsConfiguration, MfaRequiredSchema},
}; };
/// Chorus' request struct. This struct is used to send rate-limited requests to the Spacebar server. /// Chorus' request struct. This struct is used to send rate-limited requests to the Spacebar server.
@ -35,13 +34,12 @@ impl ChorusRequest {
/// * [`http::Method::DELETE`] /// * [`http::Method::DELETE`]
/// * [`http::Method::PATCH`] /// * [`http::Method::PATCH`]
/// * [`http::Method::HEAD`] /// * [`http::Method::HEAD`]
#[allow(unused_variables)] // TODO: Add mfa_token to request, once we figure out *how* to do so correctly #[allow(unused_variables)]
pub fn new( pub fn new(
method: http::Method, method: http::Method,
url: &str, url: &str,
body: Option<String>, body: Option<String>,
audit_log_reason: Option<&str>, audit_log_reason: Option<&str>,
mfa_token: Option<&str>,
chorus_user: Option<&mut ChorusUser>, chorus_user: Option<&mut ChorusUser>,
limit_type: LimitType, limit_type: LimitType,
) -> ChorusRequest { ) -> ChorusRequest {
@ -267,7 +265,14 @@ impl ChorusRequest {
async fn interpret_error(response: reqwest::Response) -> ChorusError { async fn interpret_error(response: reqwest::Response) -> ChorusError {
match response.status().as_u16() { match response.status().as_u16() {
401..=403 | 407 => ChorusError::NoPermission, 401 => {
let response = response.text().await.unwrap();
match serde_json::from_str::<MfaRequiredSchema>(&response) {
Ok(response) => ChorusError::MfaRequired { error: response },
Err(_) => ChorusError::NoPermission,
}
}
402..=403 | 407 => ChorusError::NoPermission,
404 => ChorusError::NotFound { 404 => ChorusError::NotFound {
error: response.text().await.unwrap(), error: response.text().await.unwrap(),
}, },

View File

@ -0,0 +1,7 @@
use chrono::{DateTime, Utc};
#[derive(Debug, Clone)]
pub struct MfaToken {
pub token: String,
pub expires_at: DateTime<Utc>,
}

View File

@ -26,6 +26,7 @@ pub use user::*;
pub use user_settings::*; pub use user_settings::*;
pub use voice_state::*; pub use voice_state::*;
pub use webhook::*; pub use webhook::*;
pub use mfa_token::*;
use crate::types::Shared; use crate::types::Shared;
#[cfg(feature = "client")] #[cfg(feature = "client")]
@ -67,6 +68,7 @@ mod user;
mod user_settings; mod user_settings;
mod voice_state; mod voice_state;
mod webhook; mod webhook;
mod mfa_token;
#[cfg(feature = "client")] #[cfg(feature = "client")]
#[async_trait(?Send)] #[async_trait(?Send)]

View File

@ -35,12 +35,3 @@ pub struct LoginSchema {
pub login_source: Option<String>, pub login_source: Option<String>,
pub gift_code_sku_id: Option<String>, pub gift_code_sku_id: Option<String>,
} }
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TotpSchema {
code: String,
ticket: String,
gift_code_sku_id: Option<String>,
login_source: Option<String>,
}

62
src/types/schema/mfa.rs Normal file
View File

@ -0,0 +1,62 @@
use std::fmt::Display;
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub struct MfaRequiredSchema {
pub message: String,
pub code: i32,
pub mfa: MfaVerificationSchema,
}
impl Display for MfaRequiredSchema {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MfaRequired")
.field("message", &self.message)
.field("code", &self.code)
.field("mfa", &self.mfa)
.finish()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub struct MfaVerificationSchema {
pub ticket: String,
pub methods: Vec<MfaMethod>
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub struct MfaMethod {
#[serde(rename = "type")]
pub kind: MfaType,
#[serde(skip_serializing_if = "Option::is_none")]
pub challenge: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub backup_codes_allowed: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum MfaType {
TOTP,
SMS,
Backup,
WebAuthn,
Password,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct MfaVerifySchema {
pub ticket: String,
pub mfa_type: MfaType,
pub data: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaTokenSchema {
pub token: String,
}

View File

@ -5,6 +5,7 @@
pub use apierror::*; pub use apierror::*;
pub use audit_log::*; pub use audit_log::*;
pub use auth::*; pub use auth::*;
pub use mfa::*;
pub use channel::*; pub use channel::*;
pub use guild::*; pub use guild::*;
pub use message::*; pub use message::*;
@ -17,6 +18,7 @@ pub use voice_state::*;
mod apierror; mod apierror;
mod audit_log; mod audit_log;
mod auth; mod auth;
mod mfa;
mod channel; mod channel;
mod guild; mod guild;
mod message; mod message;

View File

@ -2,9 +2,10 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use pubserve::Publisher;
use crate::{ use crate::{
errors::VoiceGatewayError, errors::VoiceGatewayError,
gateway::GatewayEvent,
types::{ types::{
SessionDescription, SessionUpdate, Speaking, SsrcDefinition, VoiceBackendVersion, SessionDescription, SessionUpdate, Speaking, SsrcDefinition, VoiceBackendVersion,
VoiceClientConnectFlags, VoiceClientConnectPlatform, VoiceClientDisconnection, VoiceClientConnectFlags, VoiceClientConnectPlatform, VoiceClientDisconnection,
@ -14,15 +15,15 @@ use crate::{
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct VoiceEvents { pub struct VoiceEvents {
pub voice_ready: GatewayEvent<VoiceReady>, pub voice_ready: Publisher<VoiceReady>,
pub backend_version: GatewayEvent<VoiceBackendVersion>, pub backend_version: Publisher<VoiceBackendVersion>,
pub session_description: GatewayEvent<SessionDescription>, pub session_description: Publisher<SessionDescription>,
pub session_update: GatewayEvent<SessionUpdate>, pub session_update: Publisher<SessionUpdate>,
pub speaking: GatewayEvent<Speaking>, pub speaking: Publisher<Speaking>,
pub ssrc_definition: GatewayEvent<SsrcDefinition>, pub ssrc_definition: Publisher<SsrcDefinition>,
pub client_disconnect: GatewayEvent<VoiceClientDisconnection>, pub client_disconnect: Publisher<VoiceClientDisconnection>,
pub client_connect_flags: GatewayEvent<VoiceClientConnectFlags>, pub client_connect_flags: Publisher<VoiceClientConnectFlags>,
pub client_connect_platform: GatewayEvent<VoiceClientConnectPlatform>, pub client_connect_platform: Publisher<VoiceClientConnectPlatform>,
pub media_sink_wants: GatewayEvent<VoiceMediaSinkWants>, pub media_sink_wants: Publisher<VoiceMediaSinkWants>,
pub error: GatewayEvent<VoiceGatewayError>, pub error: Publisher<VoiceGatewayError>,
} }

View File

@ -6,6 +6,7 @@ use std::{sync::Arc, time::Duration};
use log::*; use log::*;
use pubserve::Publisher;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use futures_util::SinkExt; use futures_util::SinkExt;
@ -16,7 +17,6 @@ use crate::gateway::Stream;
use crate::gateway::WebSocketBackend; use crate::gateway::WebSocketBackend;
use crate::{ use crate::{
errors::VoiceGatewayError, errors::VoiceGatewayError,
gateway::GatewayEvent,
types::{ types::{
VoiceGatewayReceivePayload, VoiceHelloData, WebSocketEvent, VOICE_BACKEND_VERSION, VoiceGatewayReceivePayload, VoiceHelloData, WebSocketEvent, VOICE_BACKEND_VERSION,
VOICE_CLIENT_CONNECT_FLAGS, VOICE_CLIENT_CONNECT_PLATFORM, VOICE_CLIENT_DISCONNECT, VOICE_CLIENT_CONNECT_FLAGS, VOICE_CLIENT_CONNECT_PLATFORM, VOICE_CLIENT_DISCONNECT,
@ -160,7 +160,7 @@ impl VoiceGateway {
/// (Called for every event in handle_message) /// (Called for every event in handle_message)
async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>(
data: &'a str, data: &'a str,
event: &mut GatewayEvent<T>, event: &mut Publisher<T>,
) -> Result<(), serde_json::Error> { ) -> Result<(), serde_json::Error> {
let data_deserialize_result: Result<T, serde_json::Error> = serde_json::from_str(data); let data_deserialize_result: Result<T, serde_json::Error> = serde_json::from_str(data);
@ -168,7 +168,7 @@ impl VoiceGateway {
return Err(data_deserialize_result.err().unwrap()); return Err(data_deserialize_result.err().unwrap());
} }
event.notify(data_deserialize_result.unwrap()).await; event.publish(data_deserialize_result.unwrap()).await;
Ok(()) Ok(())
} }
@ -182,7 +182,7 @@ impl VoiceGateway {
if let Some(error) = msg.error() { if let Some(error) = msg.error() {
warn!("GW: Received error {:?}, connection will close..", error); warn!("GW: Received error {:?}, connection will close..", error);
self.close().await; self.close().await;
self.events.lock().await.error.notify(error).await; self.events.lock().await.error.publish(error).await;
} else { } else {
warn!( warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github", "Message unrecognised: {:?}, please open an issue on the chorus github",

View File

@ -3,23 +3,24 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use discortp::{rtcp::Rtcp, rtp::Rtp}; use discortp::{rtcp::Rtcp, rtp::Rtp};
use pubserve::Publisher;
use crate::{gateway::GatewayEvent, types::WebSocketEvent}; use crate::types::WebSocketEvent;
impl WebSocketEvent for Rtp {} impl WebSocketEvent for Rtp {}
impl WebSocketEvent for Rtcp {} impl WebSocketEvent for Rtcp {}
#[derive(Debug)] #[derive(Debug)]
pub struct VoiceUDPEvents { pub struct VoiceUDPEvents {
pub rtp: GatewayEvent<Rtp>, pub rtp: Publisher<Rtp>,
pub rtcp: GatewayEvent<Rtcp>, pub rtcp: Publisher<Rtcp>,
} }
impl Default for VoiceUDPEvents { impl Default for VoiceUDPEvents {
fn default() -> Self { fn default() -> Self {
Self { Self {
rtp: GatewayEvent::new(), rtp: Publisher::new(),
rtcp: GatewayEvent::new(), rtcp: Publisher::new(),
} }
} }
} }

View File

@ -214,7 +214,7 @@ impl UdpHandler {
.lock() .lock()
.await .await
.rtp .rtp
.notify(rtp_with_decrypted_data) .publish(rtp_with_decrypted_data)
.await; .await;
} }
Demuxed::Rtcp(rtcp) => { Demuxed::Rtcp(rtcp) => {
@ -251,7 +251,7 @@ impl UdpHandler {
} }
}; };
self.events.lock().await.rtcp.notify(rtcp_data).await; self.events.lock().await.rtcp.publish(rtcp_data).await;
} }
Demuxed::FailedParse(e) => { Demuxed::FailedParse(e) => {
trace!("VUDP: Failed to parse packet: {:?}", e); trace!("VUDP: Failed to parse packet: {:?}", e);

View File

@ -85,8 +85,10 @@ async fn test_login_with_token() {
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
bundle.user.object.read().unwrap().id, bundle.user.object.as_ref().unwrap()
other_user.object.read().unwrap().id .read().unwrap()
.id,
other_user.object.unwrap().read().unwrap().id
); );
assert_eq!(bundle.user.token, other_user.token); assert_eq!(bundle.user.token, other_user.token);

View File

@ -2,7 +2,11 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use chorus::types::{self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags, PermissionOverwrite, PermissionOverwriteType, PrivateChannelCreateSchema, RelationshipType, Snowflake}; use chorus::types::{
self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags,
PermissionOverwrite, PermissionOverwriteType, PrivateChannelCreateSchema, RelationshipType,
Snowflake,
};
mod common; mod common;
@ -67,7 +71,7 @@ async fn modify_channel() {
assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string()));
let permission_override = PermissionFlags::MANAGE_CHANNELS | PermissionFlags::MANAGE_MESSAGES; let permission_override = PermissionFlags::MANAGE_CHANNELS | PermissionFlags::MANAGE_MESSAGES;
let user_id: types::Snowflake = bundle.user.object.read().unwrap().id; let user_id: types::Snowflake = bundle.user.object.as_ref().unwrap().read().unwrap().id;
let permission_override = PermissionOverwrite { let permission_override = PermissionOverwrite {
id: user_id, id: user_id,
overwrite_type: PermissionOverwriteType::Member, overwrite_type: PermissionOverwriteType::Member,
@ -155,7 +159,13 @@ async fn create_dm() {
let other_user = bundle.create_user("integrationtestuser2").await; let other_user = bundle.create_user("integrationtestuser2").await;
let user = &mut bundle.user; let user = &mut bundle.user;
let private_channel_create_schema = PrivateChannelCreateSchema { let private_channel_create_schema = PrivateChannelCreateSchema {
recipients: Some(Vec::from([other_user.object.read().unwrap().id])), recipients: Some(Vec::from([other_user
.object
.as_ref()
.unwrap()
.read()
.unwrap()
.id])),
access_tokens: None, access_tokens: None,
nicks: None, nicks: None,
}; };
@ -175,7 +185,7 @@ async fn create_dm() {
.unwrap() .unwrap()
.id .id
.clone(), .clone(),
other_user.object.read().unwrap().id other_user.object.unwrap().read().unwrap().id
); );
assert_eq!( assert_eq!(
dm_channel dm_channel
@ -188,7 +198,7 @@ async fn create_dm() {
.unwrap() .unwrap()
.id .id
.clone(), .clone(),
user.object.read().unwrap().id.clone() user.object.as_ref().unwrap().read().unwrap().id.clone()
); );
common::teardown(bundle).await; common::teardown(bundle).await;
} }
@ -200,9 +210,9 @@ async fn remove_add_person_from_to_dm() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let mut other_user = bundle.create_user("integrationtestuser2").await; let mut other_user = bundle.create_user("integrationtestuser2").await;
let mut third_user = bundle.create_user("integrationtestuser3").await; let mut third_user = bundle.create_user("integrationtestuser3").await;
let third_user_id = third_user.object.read().unwrap().id; let third_user_id = third_user.object.as_ref().unwrap().read().unwrap().id;
let other_user_id = other_user.object.read().unwrap().id; let other_user_id = other_user.object.as_ref().unwrap().read().unwrap().id;
let user_id = bundle.user.object.read().unwrap().id; let user_id = bundle.user.object.as_ref().unwrap().read().unwrap().id;
let user = &mut bundle.user; let user = &mut bundle.user;
let private_channel_create_schema = PrivateChannelCreateSchema { let private_channel_create_schema = PrivateChannelCreateSchema {
recipients: Some(Vec::from([other_user_id, third_user_id])), recipients: Some(Vec::from([other_user_id, third_user_id])),

View File

@ -47,6 +47,7 @@ impl TestBundle {
ChorusUser { ChorusUser {
belongs_to: self.user.belongs_to.clone(), belongs_to: self.user.belongs_to.clone(),
token: self.user.token.clone(), token: self.user.token.clone(),
mfa_token: None,
limits: self.user.limits.clone(), limits: self.user.limits.clone(),
settings: self.user.settings.clone(), settings: self.user.settings.clone(),
object: self.user.object.clone(), object: self.user.object.clone(),

View File

@ -14,6 +14,7 @@ use chorus::types::{
self, Channel, ChannelCreateSchema, ChannelModifySchema, GatewayReady, IntoShared, self, Channel, ChannelCreateSchema, ChannelModifySchema, GatewayReady, IntoShared,
RoleCreateModifySchema, RoleObject, RoleCreateModifySchema, RoleObject,
}; };
use pubserve::Subscriber;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*; use wasm_bindgen_test::*;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
@ -30,7 +31,9 @@ use wasmtimer::tokio::sleep;
async fn test_gateway_establish() { async fn test_gateway_establish() {
let bundle = common::setup().await; let bundle = common::setup().await;
let _: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default()).await.unwrap(); let _: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default())
.await
.unwrap();
common::teardown(bundle).await common::teardown(bundle).await
} }
@ -40,7 +43,7 @@ struct GatewayReadyObserver {
} }
#[async_trait] #[async_trait]
impl Observer<GatewayReady> for GatewayReadyObserver { impl Subscriber<GatewayReady> for GatewayReadyObserver {
async fn update(&self, _data: &GatewayReady) { async fn update(&self, _data: &GatewayReady) {
self.channel.send(()).await.unwrap(); self.channel.send(()).await.unwrap();
} }
@ -52,7 +55,9 @@ impl Observer<GatewayReady> for GatewayReadyObserver {
async fn test_gateway_authenticate() { async fn test_gateway_authenticate() {
let bundle = common::setup().await; let bundle = common::setup().await;
let gateway: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default()).await.unwrap(); let gateway: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default())
.await
.unwrap();
let (ready_send, mut ready_receive) = tokio::sync::mpsc::channel(1); let (ready_send, mut ready_receive) = tokio::sync::mpsc::channel(1);

View File

@ -60,7 +60,8 @@ async fn guild_create_ban() {
.await .await
.unwrap(); .unwrap();
other_user.accept_invite(&invite.code, None).await.unwrap(); other_user.accept_invite(&invite.code, None).await.unwrap();
let other_user_id = other_user.object.read().unwrap().id; let other_user_id = other_user.object.as_ref().unwrap()
.read().unwrap().id;
Guild::create_ban( Guild::create_ban(
guild.id, guild.id,
other_user_id, other_user_id,
@ -112,7 +113,9 @@ async fn guild_remove_member() {
.await .await
.unwrap(); .unwrap();
other_user.accept_invite(&invite.code, None).await.unwrap(); other_user.accept_invite(&invite.code, None).await.unwrap();
let other_user_id = other_user.object.read().unwrap().id; let other_user_id = other_user.object
.as_ref().unwrap()
.read().unwrap().id;
Guild::remove_member(guild.id, other_user_id, None, &mut bundle.user) Guild::remove_member(guild.id, other_user_id, None, &mut bundle.user)
.await .await
.unwrap(); .unwrap();

View File

@ -16,7 +16,8 @@ async fn add_remove_role() -> ChorusResult<()> {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let guild = bundle.guild.read().unwrap().id; let guild = bundle.guild.read().unwrap().id;
let role = bundle.role.read().unwrap().id; let role = bundle.role.read().unwrap().id;
let member_id = bundle.user.object.read().unwrap().id; let member_id = bundle.user.object.as_ref().unwrap()
.read().unwrap().id;
GuildMember::add_role(&mut bundle.user, guild, member_id, role).await?; GuildMember::add_role(&mut bundle.user, guild, member_id, role).await?;
let member = GuildMember::get(&mut bundle.user, guild, member_id) let member = GuildMember::get(&mut bundle.user, guild, member_id)
.await .await

View File

@ -106,7 +106,7 @@ async fn search_messages() {
let _arg = Some(&vec_attach); let _arg = Some(&vec_attach);
let message = bundle.user.send_message(message, channel.id).await.unwrap(); let message = bundle.user.send_message(message, channel.id).await.unwrap();
let query = MessageSearchQuery { let query = MessageSearchQuery {
author_id: Some(Vec::from([bundle.user.object.read().unwrap().id])), author_id: Some(Vec::from([bundle.user.object.as_ref().unwrap().read().unwrap().id])),
..Default::default() ..Default::default()
}; };
let guild_id = bundle.guild.read().unwrap().id; let guild_id = bundle.guild.read().unwrap().id;

View File

@ -16,9 +16,10 @@ async fn test_get_mutual_relationships() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let mut other_user = bundle.create_user("integrationtestuser2").await; let mut other_user = bundle.create_user("integrationtestuser2").await;
let user = &mut bundle.user; let user = &mut bundle.user;
let username = user.object.read().unwrap().username.clone();
let discriminator = user.object.read().unwrap().discriminator.clone(); let username = user.object.as_ref().unwrap().read().unwrap().username.clone();
let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; let discriminator = user.object.as_ref().unwrap().read().unwrap().discriminator.clone();
let other_user_id: types::Snowflake = other_user.object.as_ref().unwrap().read().unwrap().id;
let friend_request_schema = types::FriendRequestSendSchema { let friend_request_schema = types::FriendRequestSendSchema {
username, username,
discriminator: Some(discriminator), discriminator: Some(discriminator),
@ -38,8 +39,8 @@ async fn test_get_relationships() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let mut other_user = bundle.create_user("integrationtestuser2").await; let mut other_user = bundle.create_user("integrationtestuser2").await;
let user = &mut bundle.user; let user = &mut bundle.user;
let username = user.object.read().unwrap().username.clone(); let username = user.object.as_ref().unwrap().read().unwrap().username.clone();
let discriminator = user.object.read().unwrap().discriminator.clone(); let discriminator = user.object.as_ref().unwrap().read().unwrap().discriminator.clone();
let friend_request_schema = types::FriendRequestSendSchema { let friend_request_schema = types::FriendRequestSendSchema {
username, username,
discriminator: Some(discriminator), discriminator: Some(discriminator),
@ -51,7 +52,7 @@ async fn test_get_relationships() {
let relationships = user.get_relationships().await.unwrap(); let relationships = user.get_relationships().await.unwrap();
assert_eq!( assert_eq!(
relationships.get(0).unwrap().id, relationships.get(0).unwrap().id,
other_user.object.read().unwrap().id other_user.object.unwrap().read().unwrap().id
); );
common::teardown(bundle).await common::teardown(bundle).await
} }
@ -62,8 +63,8 @@ async fn test_modify_relationship_friends() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let mut other_user = bundle.create_user("integrationtestuser2").await; let mut other_user = bundle.create_user("integrationtestuser2").await;
let user = &mut bundle.user; let user = &mut bundle.user;
let user_id: types::Snowflake = user.object.read().unwrap().id; let user_id: types::Snowflake = user.object.as_ref().unwrap().read().unwrap().id;
let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; let other_user_id: types::Snowflake = other_user.object.as_ref().unwrap().read().unwrap().id;
other_user other_user
.modify_user_relationship(user_id, types::RelationshipType::Friends) .modify_user_relationship(user_id, types::RelationshipType::Friends)
@ -72,7 +73,7 @@ async fn test_modify_relationship_friends() {
let relationships = user.get_relationships().await.unwrap(); let relationships = user.get_relationships().await.unwrap();
assert_eq!( assert_eq!(
relationships.get(0).unwrap().id, relationships.get(0).unwrap().id,
other_user.object.read().unwrap().id other_user.object.as_ref().unwrap().read().unwrap().id
); );
assert_eq!( assert_eq!(
relationships.get(0).unwrap().relationship_type, relationships.get(0).unwrap().relationship_type,
@ -81,7 +82,7 @@ async fn test_modify_relationship_friends() {
let relationships = other_user.get_relationships().await.unwrap(); let relationships = other_user.get_relationships().await.unwrap();
assert_eq!( assert_eq!(
relationships.get(0).unwrap().id, relationships.get(0).unwrap().id,
user.object.read().unwrap().id user.object.as_ref().unwrap().read().unwrap().id
); );
assert_eq!( assert_eq!(
relationships.get(0).unwrap().relationship_type, relationships.get(0).unwrap().relationship_type,
@ -114,7 +115,7 @@ async fn test_modify_relationship_block() {
let mut bundle = common::setup().await; let mut bundle = common::setup().await;
let mut other_user = bundle.create_user("integrationtestuser2").await; let mut other_user = bundle.create_user("integrationtestuser2").await;
let user = &mut bundle.user; let user = &mut bundle.user;
let user_id: types::Snowflake = user.object.read().unwrap().id; let user_id: types::Snowflake = user.object.as_ref().unwrap().read().unwrap().id;
other_user other_user
.modify_user_relationship(user_id, types::RelationshipType::Blocked) .modify_user_relationship(user_id, types::RelationshipType::Blocked)
@ -125,7 +126,7 @@ async fn test_modify_relationship_block() {
let relationships = other_user.get_relationships().await.unwrap(); let relationships = other_user.get_relationships().await.unwrap();
assert_eq!( assert_eq!(
relationships.get(0).unwrap().id, relationships.get(0).unwrap().id,
user.object.read().unwrap().id user.object.as_ref().unwrap().read().unwrap().id
); );
assert_eq!( assert_eq!(
relationships.get(0).unwrap().relationship_type, relationships.get(0).unwrap().relationship_type,