From 6c9df9508e1b6305873badc8731efa84479c98bf Mon Sep 17 00:00:00 2001 From: Flori Weber Date: Fri, 23 Jun 2023 12:18:22 +0200 Subject: [PATCH 1/8] Fix stupid multi line comments --- src/api/channels/reactions.rs | 170 ++++++++++++++-------------------- 1 file changed, 70 insertions(+), 100 deletions(-) diff --git a/src/api/channels/reactions.rs b/src/api/channels/reactions.rs index c445a8b..c164246 100644 --- a/src/api/channels/reactions.rs +++ b/src/api/channels/reactions.rs @@ -16,20 +16,15 @@ pub struct ReactionMeta { } impl ReactionMeta { - /** - Deletes all reactions for a message. - This endpoint requires the `MANAGE_MESSAGES` permission to be present on the current user. - - # Arguments - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A `Result` [`()`] [`crate::errors::ChorusLibError`] if something went wrong. - Fires a `Message Reaction Remove All` Gateway event. - - # Reference - See [https://discord.com/developers/docs/resources/channel#delete-all-reactions](https://discord.com/developers/docs/resources/channel#delete-all-reactions) - */ + /// Deletes all reactions for a message. + /// This endpoint requires the `MANAGE_MESSAGES` permission to be present on the current user. + /// # Arguments + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A `Result` [`()`] [`crate::errors::ChorusLibError`] if something went wrong. + /// Fires a `Message Reaction Remove All` Gateway event. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#delete-all-reactions](https://discord.com/developers/docs/resources/channel#delete-all-reactions) pub async fn delete_all(&self, user: &mut UserMeta) -> ChorusResult<()> { let url = format!( "{}/channels/{}/messages/{}/reactions/", @@ -41,21 +36,16 @@ impl ReactionMeta { handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await } - /** - Gets a list of users that reacted with a specific emoji to a message. - - # Arguments - * `emoji` - A string slice containing the emoji to search for. The emoji must be URL Encoded or - the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the - format name:id with the emoji name and emoji id. - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong. - - # Reference - See [https://discord.com/developers/docs/resources/channel#get-reactions](https://discord.com/developers/docs/resources/channel#get-reactions) - */ + /// Gets a list of users that reacted with a specific emoji to a message. + /// # Arguments + /// * `emoji` - A string slice containing the emoji to search for. The emoji must be URL Encoded or + /// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the + /// format name:id with the emoji name and emoji id. + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#get-reactions](https://discord.com/developers/docs/resources/channel#get-reactions) pub async fn get(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> { let url = format!( "{}/channels/{}/messages/{}/reactions/{}/", @@ -68,23 +58,18 @@ impl ReactionMeta { handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await } - /** - Deletes all the reactions for a given `emoji` on a message. This endpoint requires the - MANAGE_MESSAGES permission to be present on the current user. - - # Arguments - * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or - the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the - format name:id with the emoji name and emoji id. - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong. - Fires a `Message Reaction Remove Emoji` Gateway event. - - # Reference - See [https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji](https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji) - */ + /// Deletes all the reactions for a given `emoji` on a message. This endpoint requires the + /// MANAGE_MESSAGES permission to be present on the current user. + /// # Arguments + /// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or + /// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the + /// format name:id with the emoji name and emoji id. + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong. + /// Fires a `Message Reaction Remove Emoji` Gateway event. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji](https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji) pub async fn delete_emoji(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> { let url = format!( "{}/channels/{}/messages/{}/reactions/{}/", @@ -97,25 +82,21 @@ impl ReactionMeta { handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await } - /** - Create a reaction for the message. - - This endpoint requires the READ_MESSAGE_HISTORY permission - to be present on the current user. Additionally, if nobody else has reacted to the message using - this emoji, this endpoint requires the ADD_REACTIONS permission to be present on the current - user. - # Arguments - * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or - the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the - format name:id with the emoji name and emoji id. - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. - - # Reference - See [https://discord.com/developers/docs/resources/channel#create-reaction](https://discord.com/developers/docs/resources/channel#create-reaction) - */ + /// Create a reaction for the message. + /// This endpoint requires the READ_MESSAGE_HISTORY permission + /// to be present on the current user. Additionally, if nobody else has reacted to the message using + /// this emoji, this endpoint requires the ADD_REACTIONS permission to be present on the current + /// user. + /// # Arguments + /// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or + /// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the + /// format name:id with the emoji name and emoji id. + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#create-reaction](https://discord.com/developers/docs/resources/channel#create-reaction) + /// pub async fn create(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> { let url = format!( "{}/channels/{}/messages/{}/reactions/{}/@me/", @@ -128,22 +109,17 @@ impl ReactionMeta { handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await } - /** - Delete a reaction the current user has made for the message. - - # Arguments - * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or - the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the - format name:id with the emoji name and emoji id. - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. - Fires a `Message Reaction Remove` Gateway event. - - # Reference - See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction) - */ + /// Delete a reaction the current user has made for the message. + /// # Arguments + /// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or + /// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the + /// format name:id with the emoji name and emoji id. + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. + /// Fires a `Message Reaction Remove` Gateway event. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction) pub async fn remove(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> { let url = format!( "{}/channels/{}/messages/{}/reactions/{}/@me/", @@ -156,25 +132,19 @@ impl ReactionMeta { handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await } - /** - Delete a user's reaction to a message. - - This endpoint requires the MANAGE_MESSAGES permission to be present on the current user. - - # Arguments - * `user_id` - ID of the user whose reaction is to be deleted. - * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or - the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the - format name:id with the emoji name and emoji id. - * `user` - A mutable reference to a [`UserMeta`] instance. - - # Returns - A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. - Fires a Message Reaction Remove Gateway event. - - # Reference - See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction) - */ + /// Delete a user's reaction to a message. + /// This endpoint requires the MANAGE_MESSAGES permission to be present on the current user. + /// # Arguments + /// * `user_id` - ID of the user whose reaction is to be deleted. + /// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or + /// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the + /// format name:id with the emoji name and emoji id. + /// * `user` - A mutable reference to a [`UserMeta`] instance. + /// # Returns + /// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`]. + /// Fires a Message Reaction Remove Gateway event. + /// # Reference + /// See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction) pub async fn delete_user( &self, user_id: Snowflake, From c25795ab4df4e768bd807819a3b73c530a8335f1 Mon Sep 17 00:00:00 2001 From: Flori Weber Date: Fri, 23 Jun 2023 12:54:08 +0200 Subject: [PATCH 2/8] Make ReactionMeta::get() return Vec --- src/api/channels/reactions.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/api/channels/reactions.rs b/src/api/channels/reactions.rs index c164246..96a830a 100644 --- a/src/api/channels/reactions.rs +++ b/src/api/channels/reactions.rs @@ -1,10 +1,10 @@ use reqwest::Client; use crate::{ - api::handle_request_as_result, + api::{deserialize_response, handle_request_as_result}, errors::ChorusResult, instance::UserMeta, - types::{self, Snowflake}, + types::{self, PublicUser, Snowflake}, }; /** @@ -46,7 +46,7 @@ impl ReactionMeta { /// A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong. /// # Reference /// See [https://discord.com/developers/docs/resources/channel#get-reactions](https://discord.com/developers/docs/resources/channel#get-reactions) - pub async fn get(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> { + pub async fn get(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult> { let url = format!( "{}/channels/{}/messages/{}/reactions/{}/", user.belongs_to.borrow().urls.api, @@ -55,7 +55,12 @@ impl ReactionMeta { emoji ); let request = Client::new().get(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + deserialize_response::>( + request, + user, + crate::api::limits::LimitType::Channel, + ) + .await } /// Deletes all the reactions for a given `emoji` on a message. This endpoint requires the From ecf54111a64a522fc9909b2ada771f5e2d28e8d5 Mon Sep 17 00:00:00 2001 From: Flori Weber Date: Fri, 23 Jun 2023 12:54:15 +0200 Subject: [PATCH 3/8] cargo fix --- tests/common/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 900e31e..d4699cb 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,5 +1,4 @@ use chorus::{ - errors::ChorusResult, instance::{Instance, UserMeta}, types::{ Channel, ChannelCreateSchema, Guild, GuildCreateSchema, RegisterSchema, From c0d484efdc605176c60e3be0e0084df47a05541c Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sat, 24 Jun 2023 08:52:45 +0200 Subject: [PATCH 4/8] fixed tokio features --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index cece3e2..3590341 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ backend = ["poem", "sqlx"] client = [] [dependencies] -tokio = {version = "1.28.1", features = ["rt", "macros", "rt-multi-thread", "full"]} +tokio = {version = "1.28.1"} serde = {version = "1.0.163", features = ["derive"]} serde_json = {version= "1.0.96", features = ["raw_value"]} serde-aux = "4.2.0" @@ -36,5 +36,6 @@ thiserror = "1.0.40" jsonwebtoken = "8.3.0" [dev-dependencies] +tokio = {version = "1.28.1", features = ["full"]} lazy_static = "1.4.0" rusty-hook = "0.11.2" \ No newline at end of file From 8b0f41fad3c8cc0dcb77a2fe21ca4b0bcacfc75a Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 25 Jun 2023 11:33:50 +0200 Subject: [PATCH 5/8] remove client side validation --- src/api/auth/login.rs | 19 ++-- src/api/auth/register.rs | 16 +-- src/gateway.rs | 6 +- src/types/schema/auth.rs | 226 +-------------------------------------- src/types/schema/mod.rs | 73 ------------- tests/auth.rs | 12 +-- tests/common/mod.rs | 12 +-- tests/relationships.rs | 42 ++++---- 8 files changed, 52 insertions(+), 354 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 9670d15..b9c4332 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -26,16 +26,12 @@ impl Instance { self, &mut cloned_limits, ) - .await; - if response.is_err() { - return Err(ChorusLibError::NoResponse); - } + .await?; - let response_unwrap = response.unwrap(); - let status = response_unwrap.status(); - let response_text_string = response_unwrap.text().await.unwrap(); + let status = response.status(); + let response_text = response.text().await.unwrap(); if status.is_client_error() { - let json: ErrorResponse = serde_json::from_str(&response_text_string).unwrap(); + let json: ErrorResponse = serde_json::from_str(&response_text).unwrap(); let error_type = json.errors.errors.iter().next().unwrap().0.to_owned(); let mut error = "".to_string(); for (_, value) in json.errors.errors.iter() { @@ -47,11 +43,8 @@ impl Instance { } let cloned_limits = self.limits.clone(); - let login_result: LoginResult = from_str(&response_text_string).unwrap(); - let object = self - .get_user(login_result.token.clone(), None) - .await - .unwrap(); + let login_result: LoginResult = from_str(&response_text).unwrap(); + let object = self.get_user(login_result.token.clone(), None).await?; let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), login_result.token, diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index b22d6ce..6d9d2ea 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -39,15 +39,11 @@ impl Instance { self, &mut cloned_limits, ) - .await; - if response.is_err() { - return Err(ChorusLibError::NoResponse); - } + .await?; - let response_unwrap = response.unwrap(); - let status = response_unwrap.status(); - let response_unwrap_text = response_unwrap.text().await.unwrap(); - let token = from_str::(&response_unwrap_text).unwrap(); + let status = response.status(); + let response_text = response.text().await.unwrap(); + let token = from_str::(&response_text).unwrap(); let token = token.token; if status.is_client_error() { let json: ErrorResponse = serde_json::from_str(&token).unwrap(); @@ -61,9 +57,7 @@ impl Instance { return Err(ChorusLibError::InvalidFormBodyError { error_type, error }); } let user_object = self.get_user(token.clone(), None).await.unwrap(); - let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self) - .await - .unwrap(); + let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self).await?; let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), token.clone(), diff --git a/src/gateway.rs b/src/gateway.rs index 2f90217..ca57544 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1879,7 +1879,7 @@ mod example { #[derive(Debug)] struct Consumer { - name: String, + _name: String, events_received: AtomicI32, } @@ -1900,13 +1900,13 @@ mod example { }; let consumer = Arc::new(Consumer { - name: "first".into(), + _name: "first".into(), events_received: 0.into(), }); event.subscribe(consumer.clone()); let second_consumer = Arc::new(Consumer { - name: "second".into(), + _name: "second".into(), events_received: 0.into(), }); event.subscribe(second_consumer.clone()); diff --git a/src/types/schema/auth.rs b/src/types/schema/auth.rs index 8b8a601..3fe0604 100644 --- a/src/types/schema/auth.rs +++ b/src/types/schema/auth.rs @@ -1,122 +1,8 @@ -use regex::Regex; use serde::{Deserialize, Serialize}; -use crate::errors::FieldFormatError; - -/** -A struct that represents a well-formed email address. - */ -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct AuthEmail { - pub email: String, -} - -impl AuthEmail { - /** - Returns a new [`Result`]. - ## Arguments - The email address you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The email address is not in a valid format. - - */ - pub fn new(email: String) -> Result { - let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap(); - if !regex.is_match(email.as_str()) { - return Err(FieldFormatError::EmailError); - } - Ok(AuthEmail { email }) - } -} - -/** -A struct that represents a well-formed username. -## Arguments -Please use new() to create a new instance of this struct. -## Errors -You will receive a [`FieldFormatError`], if: -- The username is not between 2 and 32 characters. - */ -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct AuthUsername { - pub username: String, -} - -impl AuthUsername { - /** - Returns a new [`Result`]. - ## Arguments - The username you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is not between 2 and 32 characters. - */ - pub fn new(username: String) -> Result { - if username.len() < 2 || username.len() > 32 { - Err(FieldFormatError::UsernameError) - } else { - Ok(AuthUsername { username }) - } - } -} - -/** -A struct that represents a well-formed password. -## Arguments -Please use new() to create a new instance of this struct. -## Errors -You will receive a [`FieldFormatError`], if: -- The password is not between 1 and 72 characters. - */ -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct AuthPassword { - pub password: String, -} - -impl AuthPassword { - /** - Returns a new [`Result`]. - ## Arguments - The password you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The password is not between 1 and 72 characters. - */ - pub fn new(password: String) -> Result { - if password.is_empty() || password.len() > 72 { - Err(FieldFormatError::PasswordError) - } else { - Ok(AuthPassword { password }) - } - } -} - -/** -A struct that represents a well-formed register request. -## Arguments -Please use new() to create a new instance of this struct. -## Errors -You will receive a [`FieldFormatError`], if: -- The username is not between 2 and 32 characters. -- The password is not between 1 and 72 characters. - */ -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Default, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub struct RegisterSchema { - username: String, - password: Option, - consent: bool, - email: Option, - fingerprint: Option, - invite: Option, - date_of_birth: Option, - gift_code_sku_id: Option, - captcha_key: Option, - promotional_email_opt_in: Option, -} - -pub struct RegisterSchemaOptions { pub username: String, pub password: Option, pub consent: bool, @@ -129,83 +15,14 @@ pub struct RegisterSchemaOptions { pub promotional_email_opt_in: Option, } -impl RegisterSchema { - pub fn builder(username: impl Into, consent: bool) -> RegisterSchemaOptions { - RegisterSchemaOptions { - username: username.into(), - password: None, - consent, - email: None, - fingerprint: None, - invite: None, - date_of_birth: None, - gift_code_sku_id: None, - captcha_key: None, - promotional_email_opt_in: None, - } - } -} - -impl RegisterSchemaOptions { - /** - Create a new [`RegisterSchema`]. - ## Arguments - All but "String::username" and "bool::consent" are optional. - - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is less than 2 or more than 32 characters in length - - You supply a `password` which is less than 1 or more than 72 characters in length. - - These constraints have been defined [in the Spacebar-API](https://docs.spacebar.chat/routes/) - */ - pub fn build(self) -> Result { - let username = AuthUsername::new(self.username)?.username; - - let email = if let Some(email) = self.email { - Some(AuthEmail::new(email)?.email) - } else { - None - }; - - let password = if let Some(password) = self.password { - Some(AuthPassword::new(password)?.password) - } else { - None - }; - - if !self.consent { - return Err(FieldFormatError::ConsentError); - } - - Ok(RegisterSchema { - username, - password, - consent: self.consent, - email, - fingerprint: self.fingerprint, - invite: self.invite, - date_of_birth: self.date_of_birth, - gift_code_sku_id: self.gift_code_sku_id, - captcha_key: self.captcha_key, - promotional_email_opt_in: self.promotional_email_opt_in, - }) - } -} - -/** -A struct that represents a well-formed login request. -## Arguments -Please use new() to create a new instance of this struct. -## Errors -You will receive a [`FieldFormatError`], if: -- The username is not between 2 and 32 characters. -- The password is not between 1 and 72 characters. - */ #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub struct LoginSchema { + /// For Discord, usernames must be between 2 and 32 characters, + /// but other servers may have different limits. pub login: String, + /// For Discord, must be between 1 and 72 characters, + /// but other servers may have different limits. pub password: Option, pub undelete: Option, pub captcha_key: Option, @@ -213,39 +30,6 @@ pub struct LoginSchema { pub gift_code_sku_id: Option, } -impl LoginSchema { - /** - Returns a new [`Result`]. - ## Arguments - login: The username you want to login with. - password: The password you want to login with. - undelete: Honestly no idea what this is for. - captcha_key: The captcha key you want to login with. - login_source: The login source. - gift_code_sku_id: The gift code sku id. - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is less than 2 or more than 32 characters in length - */ - pub fn new( - login: String, - password: Option, - undelete: Option, - captcha_key: Option, - login_source: Option, - gift_code_sku_id: Option, - ) -> Result { - Ok(LoginSchema { - login, - password, - undelete, - captcha_key, - login_source, - gift_code_sku_id, - }) - } -} - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct TotpSchema { diff --git a/src/types/schema/mod.rs b/src/types/schema/mod.rs index 1069428..08dae05 100644 --- a/src/types/schema/mod.rs +++ b/src/types/schema/mod.rs @@ -15,76 +15,3 @@ mod message; mod relationship; mod role; mod user; - -#[cfg(test)] -mod schemas_tests { - use crate::errors::FieldFormatError; - - use super::*; - - #[test] - fn password_too_short() { - assert_eq!( - AuthPassword::new("".to_string()), - Err(FieldFormatError::PasswordError) - ); - } - - #[test] - fn password_too_long() { - let mut long_pw = String::new(); - for _ in 0..73 { - long_pw += "a"; - } - assert_eq!( - AuthPassword::new(long_pw), - Err(FieldFormatError::PasswordError) - ); - } - - #[test] - fn username_too_short() { - assert_eq!( - AuthUsername::new("T".to_string()), - Err(FieldFormatError::UsernameError) - ); - } - - #[test] - fn username_too_long() { - let mut long_un = String::new(); - for _ in 0..33 { - long_un += "a"; - } - assert_eq!( - AuthUsername::new(long_un), - Err(FieldFormatError::UsernameError) - ); - } - - #[test] - fn consent_false() { - assert_eq!( - RegisterSchema::builder("Test", false).build(), - Err(FieldFormatError::ConsentError) - ); - } - - #[test] - fn invalid_email() { - assert_eq!( - AuthEmail::new("p@p.p".to_string()), - Err(FieldFormatError::EmailError) - ) - } - - #[test] - fn valid_email() { - let reg = RegisterSchemaOptions { - email: Some("me@mail.de".to_string()), - ..RegisterSchema::builder("Testy", true) - } - .build(); - assert_ne!(reg, Err(FieldFormatError::EmailError)); - } -} diff --git a/tests/auth.rs b/tests/auth.rs index 6972ace..c26552f 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -1,16 +1,16 @@ -use chorus::types::{RegisterSchema, RegisterSchemaOptions}; +use chorus::types::RegisterSchema; mod common; #[tokio::test] async fn test_registration() { let mut bundle = common::setup().await; - let reg = RegisterSchemaOptions { + let reg = RegisterSchema { + username: "Hiiii".into(), date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("Hiiii", true) - } - .build() - .unwrap(); + consent: true, + ..Default::default() + }; bundle.instance.register_account(®).await.unwrap(); common::teardown(bundle).await; } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d4699cb..9a62585 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,7 +2,7 @@ use chorus::{ instance::{Instance, UserMeta}, types::{ Channel, ChannelCreateSchema, Guild, GuildCreateSchema, RegisterSchema, - RegisterSchemaOptions, RoleCreateModifySchema, RoleObject, + RoleCreateModifySchema, RoleObject, }, UrlBundle, }; @@ -26,12 +26,12 @@ pub async fn setup() -> TestBundle { ); let mut instance = Instance::new(urls.clone()).await.unwrap(); // Requires the existance of the below user. - let reg = RegisterSchemaOptions { + let reg = RegisterSchema { + username: "integrationtestuser".into(), + consent: true, date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("integrationtestuser", true) - } - .build() - .unwrap(); + ..Default::default() + }; let guild_create_schema = GuildCreateSchema { name: Some("Test-Guild!".to_string()), region: None, diff --git a/tests/relationships.rs b/tests/relationships.rs index 81f3230..e23f0f3 100644 --- a/tests/relationships.rs +++ b/tests/relationships.rs @@ -1,15 +1,15 @@ -use chorus::types::{self, RegisterSchema, RegisterSchemaOptions, Relationship, RelationshipType}; +use chorus::types::{self, RegisterSchema, Relationship, RelationshipType}; mod common; #[tokio::test] async fn test_get_mutual_relationships() { - let register_schema = RegisterSchemaOptions { + let register_schema = RegisterSchema { + username: "integrationtestuser2".to_string(), + consent: true, date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("integrationtestuser2", true) - } - .build() - .unwrap(); + ..Default::default() + }; let mut bundle = common::setup().await; let belongs_to = &mut bundle.instance; @@ -30,12 +30,12 @@ async fn test_get_mutual_relationships() { #[tokio::test] async fn test_get_relationships() { - let register_schema = RegisterSchemaOptions { + let register_schema = RegisterSchema { + username: "integrationtestuser2".to_string(), + consent: true, date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("integrationtestuser2", true) - } - .build() - .unwrap(); + ..Default::default() + }; let mut bundle = common::setup().await; let belongs_to = &mut bundle.instance; @@ -53,12 +53,12 @@ async fn test_get_relationships() { #[tokio::test] async fn test_modify_relationship_friends() { - let register_schema = RegisterSchemaOptions { + let register_schema = RegisterSchema { + username: "integrationtestuser2".to_string(), + consent: true, date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("integrationtestuser2", true) - } - .build() - .unwrap(); + ..Default::default() + }; let mut bundle = common::setup().await; let belongs_to = &mut bundle.instance; @@ -101,12 +101,12 @@ async fn test_modify_relationship_friends() { #[tokio::test] async fn test_modify_relationship_block() { - let register_schema = RegisterSchemaOptions { + let register_schema = RegisterSchema { + username: "integrationtestuser2".to_string(), + consent: true, date_of_birth: Some("2000-01-01".to_string()), - ..RegisterSchema::builder("integrationtestuser2", true) - } - .build() - .unwrap(); + ..Default::default() + }; let mut bundle = common::setup().await; let belongs_to = &mut bundle.instance; From 94a051631f04c3d63dbffb635b86e46eb3b596bf Mon Sep 17 00:00:00 2001 From: Vincent Junge Date: Sun, 25 Jun 2023 22:34:05 +0200 Subject: [PATCH 6/8] require password to log in --- src/types/schema/auth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/schema/auth.rs b/src/types/schema/auth.rs index 3fe0604..9a3b9d6 100644 --- a/src/types/schema/auth.rs +++ b/src/types/schema/auth.rs @@ -23,7 +23,7 @@ pub struct LoginSchema { pub login: String, /// For Discord, must be between 1 and 72 characters, /// but other servers may have different limits. - pub password: Option, + pub password: String, pub undelete: Option, pub captcha_key: Option, pub login_source: Option, From 3f6d0e2d9dc720fd81b05bcf4d302717a2424a28 Mon Sep 17 00:00:00 2001 From: kozabrada123 <59031733+kozabrada123@users.noreply.github.com> Date: Sat, 1 Jul 2023 19:29:50 +0200 Subject: [PATCH 7/8] Use log instead of prints --- Cargo.toml | 1 + src/gateway.rs | 184 ++++++++++++++++++++++++------------------------- 2 files changed, 93 insertions(+), 92 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3590341..d7fdf0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ poem = { version = "1.3.55", optional = true } sqlx = { git = "https://github.com/zert3x/sqlx", branch="feature/skip", features = ["mysql", "sqlite", "json", "chrono", "ipnetwork", "runtime-tokio-native-tls", "any"], optional = true } thiserror = "1.0.40" jsonwebtoken = "8.3.0" +log = "0.4.19" [dev-dependencies] tokio = {version = "1.28.1", features = ["full"]} diff --git a/src/gateway.rs b/src/gateway.rs index ca57544..87e1589 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -8,6 +8,7 @@ use futures_util::stream::SplitSink; use futures_util::stream::SplitStream; use futures_util::SinkExt; use futures_util::StreamExt; +use log::{debug, info, trace, warn}; use native_tls::TlsConnector; use tokio::net::TcpStream; use tokio::sync::mpsc::error::TryRecvError; @@ -191,7 +192,7 @@ impl GatewayHandle { pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Identify.."); + trace!("GW: Sending Identify.."); self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; } @@ -200,7 +201,7 @@ impl GatewayHandle { pub async fn send_resume(&self, to_send: types::GatewayResume) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Resume.."); + trace!("GW: Sending Resume.."); self.send_json_event(GATEWAY_RESUME, to_send_value).await; } @@ -209,7 +210,7 @@ impl GatewayHandle { pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Update Presence.."); + trace!("GW: Sending Update Presence.."); self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) .await; @@ -219,7 +220,7 @@ impl GatewayHandle { pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Request Guild Members.."); + trace!("GW: Sending Request Guild Members.."); self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) .await; @@ -229,7 +230,7 @@ impl GatewayHandle { pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Update Voice State.."); + trace!("GW: Sending Update Voice State.."); self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) .await; @@ -239,7 +240,7 @@ impl GatewayHandle { pub async fn send_call_sync(&self, to_send: types::CallSync) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Call Sync.."); + trace!("GW: Sending Call Sync.."); self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; } @@ -248,7 +249,7 @@ impl GatewayHandle { pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { let to_send_value = serde_json::to_value(&to_send).unwrap(); - println!("GW: Sending Lazy Request.."); + trace!("GW: Sending Lazy Request.."); self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) .await; @@ -318,7 +319,7 @@ impl Gateway { }); } - println!("GW: Received Hello"); + info!("GW: Received Hello"); let gateway_hello: types::HelloData = serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); @@ -367,7 +368,7 @@ impl Gateway { } // We couldn't receive the next message or it was an error, something is wrong with the websocket, close - println!("GW: Websocket is broken, stopping gateway"); + warn!("GW: Websocket is broken, stopping gateway"); break; } } @@ -401,7 +402,7 @@ impl Gateway { } if !msg.is_error() && !msg.is_payload() { - println!( + warn!( "Message unrecognised: {:?}, please open an issue on the chorus github", msg.message.to_string() ); @@ -410,12 +411,10 @@ impl Gateway { // To:do: handle errors in a good way, maybe observers like events? if msg.is_error() { - println!("GW: Received error, connection will close.."); + warn!("GW: Received error, connection will close.."); let _error = msg.error(); - {} - self.close().await; return; } @@ -428,7 +427,7 @@ impl Gateway { GATEWAY_DISPATCH => { let gateway_payload_t = gateway_payload.clone().event_name.unwrap(); - println!("GW: Received {}..", gateway_payload_t); + trace!("GW: Received {}..", gateway_payload_t); //println!("Event data dump: {}", gateway_payload.d.clone().unwrap().get()); @@ -443,7 +442,7 @@ impl Gateway { .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -459,7 +458,7 @@ impl Gateway { .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -481,7 +480,7 @@ impl Gateway { .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -495,7 +494,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -509,7 +508,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -523,7 +522,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -537,7 +536,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -551,7 +550,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -565,7 +564,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -579,7 +578,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -593,7 +592,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -607,7 +606,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -621,7 +620,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -635,7 +634,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -649,7 +648,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -663,7 +662,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -677,7 +676,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -691,7 +690,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -705,7 +704,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -719,7 +718,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -733,7 +732,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -747,7 +746,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -761,7 +760,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -775,7 +774,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -789,7 +788,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -803,7 +802,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -817,7 +816,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -831,7 +830,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -845,7 +844,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -859,7 +858,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -873,7 +872,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -887,7 +886,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -901,7 +900,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -915,7 +914,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -929,7 +928,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -943,7 +942,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -957,7 +956,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -971,7 +970,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -985,7 +984,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -999,7 +998,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1014,7 +1013,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1033,7 +1032,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1047,7 +1046,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1061,7 +1060,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1075,7 +1074,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1089,7 +1088,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1103,7 +1102,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1117,7 +1116,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1131,7 +1130,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1145,7 +1144,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1159,7 +1158,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1173,7 +1172,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1187,7 +1186,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1201,7 +1200,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1215,7 +1214,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1229,7 +1228,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1243,7 +1242,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1257,7 +1256,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1271,7 +1270,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1285,7 +1284,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1299,7 +1298,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1313,7 +1312,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1327,7 +1326,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1341,7 +1340,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1353,7 +1352,7 @@ impl Gateway { let result: Result, serde_json::Error> = serde_json::from_str(gateway_payload.event_data.unwrap().get()); if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1373,7 +1372,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1387,7 +1386,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1401,7 +1400,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1415,7 +1414,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1429,7 +1428,7 @@ impl Gateway { Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) .await; if result.is_err() { - println!( + warn!( "Failed to parse gateway event {} ({})", gateway_payload_t, result.err().unwrap() @@ -1438,14 +1437,14 @@ impl Gateway { } } _ => { - println!("Received unrecognized gateway event ({})! Please open an issue on the chorus github so we can implement it", &gateway_payload_t); + warn!("Received unrecognized gateway event ({})! Please open an issue on the chorus github so we can implement it", &gateway_payload_t); } } } // We received a heartbeat from the server // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." GATEWAY_HEARTBEAT => { - println!("GW: Received Heartbeat // Heartbeat Request"); + trace!("GW: Received Heartbeat // Heartbeat Request"); // Tell the heartbeat handler it should send a heartbeat right away @@ -1469,10 +1468,10 @@ impl Gateway { // Starts our heartbeat // We should have already handled this in gateway init GATEWAY_HELLO => { - panic!("Received hello when it was unexpected"); + warn!("Received hello when it was unexpected"); } GATEWAY_HEARTBEAT_ACK => { - println!("GW: Received Heartbeat ACK"); + trace!("GW: Received Heartbeat ACK"); // Tell the heartbeat handler we received an ack @@ -1500,7 +1499,7 @@ impl Gateway { Err::<(), GatewayError>(error).unwrap(); } _ => { - println!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); + warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); } } @@ -1588,6 +1587,7 @@ impl HeartbeatHandler { loop { let should_shutdown = kill_receive.try_recv().is_ok(); if should_shutdown { + trace!("GW: Closing heartbeat task"); break; } @@ -1627,11 +1627,11 @@ impl HeartbeatHandler { && last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT { should_send = true; - println!("GW: Timed out waiting for a heartbeat ack, resending"); + info!("GW: Timed out waiting for a heartbeat ack, resending"); } if should_send { - println!("GW: Sending Heartbeat.."); + trace!("GW: Sending Heartbeat.."); let heartbeat = types::GatewayHeartbeat { op: GATEWAY_HEARTBEAT, @@ -1645,7 +1645,7 @@ impl HeartbeatHandler { let send_result = websocket_tx.lock().await.send(msg).await; if send_result.is_err() { // We couldn't send, the websocket is broken - println!("GW: Couldnt send heartbeat, websocket seems broken"); + warn!("GW: Couldnt send heartbeat, websocket seems broken"); break; } From 69b7c2445cb96a6b2830032cf79c0c1090d05a65 Mon Sep 17 00:00:00 2001 From: Flori <39242991+bitfl0wer@users.noreply.github.com> Date: Sun, 9 Jul 2023 18:38:02 +0200 Subject: [PATCH 8/8] Ratelimiter overhaul (#144) * Rename limits and limit to have better names * Remove empty lines * Remove handle_request (moved to requestlimiter) * Start working on new ratelimiter * Make limits Option, add "limited?" to constructor * Add missing logic to send_request * Rename Limits * Create Ratelimits and Limit Struct * Define Limit * Import Ratelimits * Define get_rate_limits * Remove unused import * + check_rate_limits & limits_config_to_ratelimits * Remove Absolute Limits These limits are not meant to be tracked anyways. * add ratelimits is_exhausted * Add error handling and send request checking * change limits to option ratelimits * Add strum * Change Ratelimits to Hashmap * Remove ratelimits in favor of hashmap * Change code from struct to hashmap * start working on update rate limits * Remove wrong import * Rename ChorusLibError to ChorusError * Documented the chorus errors * Made error documentation docstring * Make ReceivedErrorCodeError have error string * Remove unneeded import * Match changes in errors.rs * Improve update_rate_limits and can_send_request * add ratelimits.to_hash_map() * use instances' client instead of new client * add LimitsConfiguration to instance * improve update_limits, change a method name * Fix un-updated errors * Get LimitConfiguration in a sane way * Move common.rs into ratelimiter::ChorusRequest * Delete common.rs * Make instance.rs use overhauled errors * Refactor to use new Rate limiting implementation * Refactor to use new Rate limiting implementation * Refactor to use new Rate limiting implementation * Refactor to use new Rate limiting implementation * Refactor to use new Rate limiting implementation * Refactor to use new Rate limiting implementation * update ratelimiter implementation across all files * Fix remaining errors post-refactor * Changed Enum case to be correct * Use result * Re-add missing body to request * Remove unneeded late initalization * Change visibility from pub to pub(crate) I feel like these core methods don't need to be exposed as public API. * Remove unnecessary import * Fix clippy warnings * Add docstring * Change Error names across all files * Update Cargo.toml Strum is not needed * Update ratelimits.rs * Update ratelimits.rs * Bug/discord instance info unavailable (#146) * Change text to be more ambigous * Use default Configuration instead of erroring out * Emit warning log if instance config cant be gotten * Remove import * Update src/instance.rs Co-authored-by: SpecificProtagonist * Add missing closing bracket * Put limits and limits_configuration as one struct * Derive Hash * remove import * rename limits and limits_configuration * Save clone call * Change LimitsConfiguration to RateLimits `LimitsConfiguration` is in no way related to whether the instance has API rate limits enabled or not. Therefore, it has been replaced with what it should have been all along. * Add ensure_limit_in_map(), add `window` to `Limit` * Remove unneeded var * Remove import * Clean up unneeded things Dead code warnings have been supressed, but flagged as FIXME so they don't get forgotten. Anyone using tools like TODO Tree in VSCode can still see that they are there, however, they will not be shown as warnings anymore * Remove nested submodule `limit` * Add doc comments * Add more doc comments * Add some log messages to some methods --------- Co-authored-by: SpecificProtagonist --- Cargo.toml | 2 +- src/api/auth/login.rs | 54 +- src/api/auth/register.rs | 51 +- src/api/channels/channels.rs | 102 ++-- src/api/channels/messages.rs | 36 +- src/api/channels/permissions.rs | 35 +- src/api/channels/reactions.rs | 52 +- src/api/common.rs | 72 --- src/api/guilds/guilds.rs | 161 ++---- src/api/guilds/member.rs | 33 +- src/api/guilds/roles.rs | 78 +-- src/api/mod.rs | 4 +- src/api/policies/instance/instance.rs | 15 +- src/api/policies/instance/limits.rs | 499 ------------------ src/api/policies/instance/mod.rs | 4 +- src/api/policies/instance/ratelimits.rs | 35 ++ src/api/policies/mod.rs | 2 +- src/api/users/relationships.rs | 82 +-- src/api/users/users.rs | 85 ++- src/errors.rs | 81 +-- src/gateway.rs | 41 +- src/instance.rs | 88 ++- src/lib.rs | 2 +- src/limit.rs | 304 ----------- src/ratelimiter.rs | 462 ++++++++++++++++ .../config/types/general_configuration.rs | 6 +- .../config/types/subconfigs/limits/rates.rs | 24 +- src/types/entities/user.rs | 5 +- src/types/utils/rights.rs | 1 + src/types/utils/snowflake.rs | 2 +- tests/common/mod.rs | 2 +- tests/relationships.rs | 18 +- 32 files changed, 1049 insertions(+), 1389 deletions(-) delete mode 100644 src/api/common.rs delete mode 100644 src/api/policies/instance/limits.rs create mode 100644 src/api/policies/instance/ratelimits.rs delete mode 100644 src/limit.rs create mode 100644 src/ratelimiter.rs diff --git a/Cargo.toml b/Cargo.toml index d7fdf0c..5cffce1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,4 +39,4 @@ log = "0.4.19" [dev-dependencies] tokio = {version = "1.28.1", features = ["full"]} lazy_static = "1.4.0" -rusty-hook = "0.11.2" \ No newline at end of file +rusty-hook = "0.11.2" diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index b9c4332..baae4ae 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -2,57 +2,41 @@ use std::cell::RefCell; use std::rc::Rc; use reqwest::Client; -use serde_json::{from_str, json}; +use serde_json::to_string; -use crate::api::limits::LimitType; -use crate::errors::{ChorusLibError, ChorusResult}; +use crate::api::LimitType; +use crate::errors::ChorusResult; use crate::instance::{Instance, UserMeta}; -use crate::limit::LimitedRequester; -use crate::types::{ErrorResponse, LoginResult, LoginSchema}; +use crate::ratelimiter::ChorusRequest; +use crate::types::{LoginResult, LoginSchema}; impl Instance { pub async fn login_account(&mut self, login_schema: &LoginSchema) -> ChorusResult { - let json_schema = json!(login_schema); - let client = Client::new(); let endpoint_url = self.urls.api.clone() + "/auth/login"; - let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + let chorus_request = ChorusRequest { + request: Client::new() + .post(endpoint_url) + .body(to_string(login_schema).unwrap()), + limit_type: LimitType::AuthLogin, + }; // We do not have a user yet, and the UserRateLimits will not be affected by a login // request (since login is an instance wide limit), which is why we are just cloning the // instances' limits to pass them on as user_rate_limits later. - let mut cloned_limits = self.limits.clone(); - let response = LimitedRequester::send_request( - request_builder, - LimitType::AuthRegister, - self, - &mut cloned_limits, - ) - .await?; - - let status = response.status(); - let response_text = response.text().await.unwrap(); - if status.is_client_error() { - let json: ErrorResponse = serde_json::from_str(&response_text).unwrap(); - let error_type = json.errors.errors.iter().next().unwrap().0.to_owned(); - let mut error = "".to_string(); - for (_, value) in json.errors.errors.iter() { - for error_item in value._errors.iter() { - error += &(error_item.message.to_string() + " (" + &error_item.code + ")"); - } - } - return Err(ChorusLibError::InvalidFormBodyError { error_type, error }); - } - - let cloned_limits = self.limits.clone(); - let login_result: LoginResult = from_str(&response_text).unwrap(); + let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()); + let login_result = chorus_request + .deserialize_response::(&mut shell) + .await?; let object = self.get_user(login_result.token.clone(), None).await?; + if self.limits_information.is_some() { + self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap(); + } let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), login_result.token, - cloned_limits, + self.clone_limits_if_some(), login_result.settings, object, ); - Ok(user) } } diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index 6d9d2ea..a6849cc 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -1,14 +1,14 @@ use std::{cell::RefCell, rc::Rc}; use reqwest::Client; -use serde_json::{from_str, json}; +use serde_json::to_string; use crate::{ - api::limits::LimitType, - errors::{ChorusLibError, ChorusResult}, + api::policies::instance::LimitType, + errors::ChorusResult, instance::{Instance, Token, UserMeta}, - limit::LimitedRequester, - types::{ErrorResponse, RegisterSchema}, + ratelimiter::ChorusRequest, + types::RegisterSchema, }; impl Instance { @@ -25,43 +25,30 @@ impl Instance { &mut self, register_schema: &RegisterSchema, ) -> ChorusResult { - let json_schema = json!(register_schema); - let client = Client::new(); let endpoint_url = self.urls.api.clone() + "/auth/register"; - let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + let chorus_request = ChorusRequest { + request: Client::new() + .post(endpoint_url) + .body(to_string(register_schema).unwrap()), + limit_type: LimitType::AuthRegister, + }; // We do not have a user yet, and the UserRateLimits will not be affected by a login // request (since register is an instance wide limit), which is why we are just cloning // the instances' limits to pass them on as user_rate_limits later. - let mut cloned_limits = self.limits.clone(); - let response = LimitedRequester::send_request( - request_builder, - LimitType::AuthRegister, - self, - &mut cloned_limits, - ) - .await?; - - let status = response.status(); - let response_text = response.text().await.unwrap(); - let token = from_str::(&response_text).unwrap(); - let token = token.token; - if status.is_client_error() { - let json: ErrorResponse = serde_json::from_str(&token).unwrap(); - let error_type = json.errors.errors.iter().next().unwrap().0.to_owned(); - let mut error = "".to_string(); - for (_, value) in json.errors.errors.iter() { - for error_item in value._errors.iter() { - error += &(error_item.message.to_string() + " (" + &error_item.code + ")"); - } - } - return Err(ChorusLibError::InvalidFormBodyError { error_type, error }); + let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string()); + let token = chorus_request + .deserialize_response::(&mut shell) + .await? + .token; + if self.limits_information.is_some() { + self.limits_information.as_mut().unwrap().ratelimits = shell.limits.unwrap(); } let user_object = self.get_user(token.clone(), None).await.unwrap(); let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self).await?; let user = UserMeta::new( Rc::new(RefCell::new(self.clone())), token.clone(), - cloned_limits, + self.clone_limits_if_some(), settings, user_object, ); diff --git a/src/api/channels/channels.rs b/src/api/channels/channels.rs index cbe481c..82662e3 100644 --- a/src/api/channels/channels.rs +++ b/src/api/channels/channels.rs @@ -2,32 +2,23 @@ use reqwest::Client; use serde_json::to_string; use crate::{ - api::common, - errors::{ChorusLibError, ChorusResult}, + api::LimitType, + errors::{ChorusError, ChorusResult}, instance::UserMeta, + ratelimiter::ChorusRequest, types::{Channel, ChannelModifySchema, GetChannelMessagesSchema, Message, Snowflake}, }; impl Channel { pub async fn get(user: &mut UserMeta, channel_id: Snowflake) -> ChorusResult { - let url = user.belongs_to.borrow_mut().urls.api.clone(); - let request = Client::new() - .get(format!("{}/channels/{}/", url, channel_id)) - .bearer_auth(user.token()); - - let result = common::deserialize_response::( - request, - user, - crate::api::limits::LimitType::Channel, - ) - .await; - if result.is_err() { - return Err(ChorusLibError::RequestErrorError { - url: format!("{}/channels/{}/", url, channel_id), - error: result.err().unwrap().to_string(), - }); - } - Ok(result.unwrap()) + let url = user.belongs_to.borrow().urls.api.clone(); + let chorus_request = ChorusRequest { + request: Client::new() + .get(format!("{}/channels/{}/", url, channel_id)) + .bearer_auth(user.token()), + limit_type: LimitType::Channel(channel_id), + }; + chorus_request.deserialize_response::(user).await } /// Deletes a channel. @@ -44,15 +35,17 @@ impl Channel { /// /// A `Result` that contains a `ChorusLibError` if an error occurred during the request, or `()` if the request was successful. pub async fn delete(self, user: &mut UserMeta) -> ChorusResult<()> { - let request = Client::new() - .delete(format!( - "{}/channels/{}/", - user.belongs_to.borrow_mut().urls.api, - self.id - )) - .bearer_auth(user.token()); - common::handle_request_as_result(request, user, crate::api::limits::LimitType::Channel) - .await + let chorus_request = ChorusRequest { + request: Client::new() + .delete(format!( + "{}/channels/{}/", + user.belongs_to.borrow().urls.api, + self.id + )) + .bearer_auth(user.token()), + limit_type: LimitType::Channel(self.id), + }; + chorus_request.handle_request_as_result(user).await } /// Modifies a channel. @@ -75,20 +68,18 @@ impl Channel { channel_id: Snowflake, user: &mut UserMeta, ) -> ChorusResult<()> { - let request = Client::new() - .patch(format!( - "{}/channels/{}/", - user.belongs_to.borrow().urls.api, - channel_id - )) - .bearer_auth(user.token()) - .body(to_string(&modify_data).unwrap()); - let new_channel = common::deserialize_response::( - request, - user, - crate::api::limits::LimitType::Channel, - ) - .await?; + let chorus_request = ChorusRequest { + request: Client::new() + .patch(format!( + "{}/channels/{}/", + user.belongs_to.borrow().urls.api, + channel_id + )) + .bearer_auth(user.token()) + .body(to_string(&modify_data).unwrap()), + limit_type: LimitType::Channel(channel_id), + }; + let new_channel = chorus_request.deserialize_response::(user).await?; let _ = std::mem::replace(self, new_channel); Ok(()) } @@ -97,16 +88,21 @@ impl Channel { range: GetChannelMessagesSchema, channel_id: Snowflake, user: &mut UserMeta, - ) -> Result, ChorusLibError> { - let request = Client::new() - .get(format!( - "{}/channels/{}/messages", - user.belongs_to.borrow().urls.api, - channel_id - )) - .bearer_auth(user.token()) - .query(&range); + ) -> Result, ChorusError> { + let chorus_request = ChorusRequest { + request: Client::new() + .get(format!( + "{}/channels/{}/messages", + user.belongs_to.borrow().urls.api, + channel_id + )) + .bearer_auth(user.token()) + .query(&range), + limit_type: Default::default(), + }; - common::deserialize_response::>(request, user, Default::default()).await + chorus_request + .deserialize_response::>(user) + .await } } diff --git a/src/api/channels/messages.rs b/src/api/channels/messages.rs index c61756a..dbffaa8 100644 --- a/src/api/channels/messages.rs +++ b/src/api/channels/messages.rs @@ -3,8 +3,9 @@ use http::HeaderMap; use reqwest::{multipart, Client}; use serde_json::to_string; -use crate::api::deserialize_response; +use crate::api::LimitType; use crate::instance::UserMeta; +use crate::ratelimiter::ChorusRequest; use crate::types::{Message, MessageSendSchema, PartialDiscordFileAttachment, Snowflake}; impl Message { @@ -24,16 +25,18 @@ impl Message { channel_id: Snowflake, message: &mut MessageSendSchema, files: Option>, - ) -> Result { + ) -> Result { let url_api = user.belongs_to.borrow().urls.api.clone(); if files.is_none() { - let request = Client::new() - .post(format!("{}/channels/{}/messages/", url_api, channel_id)) - .bearer_auth(user.token()) - .body(to_string(message).unwrap()); - deserialize_response::(request, user, crate::api::limits::LimitType::Channel) - .await + let chorus_request = ChorusRequest { + request: Client::new() + .post(format!("{}/channels/{}/messages/", url_api, channel_id)) + .bearer_auth(user.token()) + .body(to_string(message).unwrap()), + limit_type: LimitType::Channel(channel_id), + }; + chorus_request.deserialize_response::(user).await } else { for (index, attachment) in message.attachments.iter_mut().enumerate() { attachment.get_mut(index).unwrap().set_id(index as i16); @@ -62,13 +65,14 @@ impl Message { form = form.part(part_name, part); } - let request = Client::new() - .post(format!("{}/channels/{}/messages/", url_api, channel_id)) - .bearer_auth(user.token()) - .multipart(form); - - deserialize_response::(request, user, crate::api::limits::LimitType::Channel) - .await + let chorus_request = ChorusRequest { + request: Client::new() + .post(format!("{}/channels/{}/messages/", url_api, channel_id)) + .bearer_auth(user.token()) + .multipart(form), + limit_type: LimitType::Channel(channel_id), + }; + chorus_request.deserialize_response::(user).await } } } @@ -91,7 +95,7 @@ impl UserMeta { message: &mut MessageSendSchema, channel_id: Snowflake, files: Option>, - ) -> Result { + ) -> Result { Message::send(self, channel_id, message, files).await } } diff --git a/src/api/channels/permissions.rs b/src/api/channels/permissions.rs index 0958821..bc666ff 100644 --- a/src/api/channels/permissions.rs +++ b/src/api/channels/permissions.rs @@ -2,9 +2,10 @@ use reqwest::Client; use serde_json::to_string; use crate::{ - api::handle_request_as_result, - errors::{ChorusLibError, ChorusResult}, + api::LimitType, + errors::{ChorusError, ChorusResult}, instance::UserMeta, + ratelimiter::ChorusRequest, types::{self, PermissionOverwrite, Snowflake}, }; @@ -25,24 +26,25 @@ impl types::Channel { channel_id: Snowflake, overwrite: PermissionOverwrite, ) -> ChorusResult<()> { - let url = { - format!( - "{}/channels/{}/permissions/{}", - user.belongs_to.borrow_mut().urls.api, - channel_id, - overwrite.id - ) - }; + let url = format!( + "{}/channels/{}/permissions/{}", + user.belongs_to.borrow_mut().urls.api, + channel_id, + overwrite.id + ); let body = match to_string(&overwrite) { Ok(string) => string, Err(e) => { - return Err(ChorusLibError::FormCreationError { + return Err(ChorusError::FormCreation { error: e.to_string(), }); } }; - let request = Client::new().put(url).bearer_auth(user.token()).body(body); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().put(url).bearer_auth(user.token()).body(body), + limit_type: LimitType::Channel(channel_id), + }; + chorus_request.handle_request_as_result(user).await } /// Deletes a permission overwrite for a channel. @@ -67,7 +69,10 @@ impl types::Channel { channel_id, overwrite_id ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(channel_id), + }; + chorus_request.handle_request_as_result(user).await } } diff --git a/src/api/channels/reactions.rs b/src/api/channels/reactions.rs index 96a830a..35dbb94 100644 --- a/src/api/channels/reactions.rs +++ b/src/api/channels/reactions.rs @@ -1,9 +1,10 @@ use reqwest::Client; use crate::{ - api::{deserialize_response, handle_request_as_result}, + api::LimitType, errors::ChorusResult, instance::UserMeta, + ratelimiter::ChorusRequest, types::{self, PublicUser, Snowflake}, }; @@ -32,8 +33,11 @@ impl ReactionMeta { self.channel_id, self.message_id ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request.handle_request_as_result(user).await } /// Gets a list of users that reacted with a specific emoji to a message. @@ -54,13 +58,13 @@ impl ReactionMeta { self.message_id, emoji ); - let request = Client::new().get(url).bearer_auth(user.token()); - deserialize_response::>( - request, - user, - crate::api::limits::LimitType::Channel, - ) - .await + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request + .deserialize_response::>(user) + .await } /// Deletes all the reactions for a given `emoji` on a message. This endpoint requires the @@ -83,8 +87,11 @@ impl ReactionMeta { self.message_id, emoji ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request.handle_request_as_result(user).await } /// Create a reaction for the message. @@ -110,8 +117,11 @@ impl ReactionMeta { self.message_id, emoji ); - let request = Client::new().put(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().put(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request.handle_request_as_result(user).await } /// Delete a reaction the current user has made for the message. @@ -133,8 +143,11 @@ impl ReactionMeta { self.message_id, emoji ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request.handle_request_as_result(user).await } /// Delete a user's reaction to a message. @@ -164,7 +177,10 @@ impl ReactionMeta { emoji, user_id ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Channel(self.channel_id), + }; + chorus_request.handle_request_as_result(user).await } } diff --git a/src/api/common.rs b/src/api/common.rs deleted file mode 100644 index 86a972e..0000000 --- a/src/api/common.rs +++ /dev/null @@ -1,72 +0,0 @@ -use reqwest::RequestBuilder; -use serde::Deserialize; -use serde_json::from_str; - -use crate::{ - errors::{ChorusLibError, ChorusResult}, - instance::UserMeta, - limit::LimitedRequester, -}; - -use super::limits::LimitType; - -/// Sends a request to wherever it needs to go and performs some basic error handling. -pub async fn handle_request( - request: RequestBuilder, - user: &mut UserMeta, - limit_type: LimitType, -) -> Result { - LimitedRequester::send_request( - request, - limit_type, - &mut user.belongs_to.borrow_mut(), - &mut user.limits, - ) - .await -} - -/// Sends a request to wherever it needs to go. Returns [`Ok(())`] on success and -/// [`Err(ChorusLibError)`] on failure. -pub async fn handle_request_as_result( - request: RequestBuilder, - user: &mut UserMeta, - limit_type: LimitType, -) -> ChorusResult<()> { - match handle_request(request, user, limit_type).await { - Ok(_) => Ok(()), - Err(e) => Err(ChorusLibError::InvalidResponseError { - error: e.to_string(), - }), - } -} - -pub async fn deserialize_response Deserialize<'a>>( - request: RequestBuilder, - user: &mut UserMeta, - limit_type: LimitType, -) -> ChorusResult { - let response = handle_request(request, user, limit_type).await.unwrap(); - let response_text = match response.text().await { - Ok(string) => string, - Err(e) => { - return Err(ChorusLibError::InvalidResponseError { - error: format!( - "Error while trying to process the HTTP response into a String: {}", - e - ), - }); - } - }; - let object = match from_str::(&response_text) { - Ok(object) => object, - Err(e) => { - return Err(ChorusLibError::InvalidResponseError { - error: format!( - "Error while trying to deserialize the JSON response into T: {}", - e - ), - }) - } - }; - Ok(object) -} diff --git a/src/api/guilds/guilds.rs b/src/api/guilds/guilds.rs index ac00df1..3654698 100644 --- a/src/api/guilds/guilds.rs +++ b/src/api/guilds/guilds.rs @@ -2,15 +2,11 @@ use reqwest::Client; use serde_json::from_str; use serde_json::to_string; -use crate::api::deserialize_response; -use crate::api::handle_request; -use crate::api::handle_request_as_result; -use crate::api::limits::Limits; -use crate::errors::ChorusLibError; +use crate::api::LimitType; +use crate::errors::ChorusError; use crate::errors::ChorusResult; -use crate::instance::Instance; use crate::instance::UserMeta; -use crate::limit::LimitedRequester; +use crate::ratelimiter::ChorusRequest; use crate::types::Snowflake; use crate::types::{Channel, ChannelCreateSchema, Guild, GuildCreateSchema}; @@ -36,11 +32,14 @@ impl Guild { guild_create_schema: GuildCreateSchema, ) -> ChorusResult { let url = format!("{}/guilds/", user.belongs_to.borrow().urls.api); - let request = reqwest::Client::new() - .post(url.clone()) - .bearer_auth(user.token.clone()) - .body(to_string(&guild_create_schema).unwrap()); - deserialize_response::(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new() + .post(url.clone()) + .bearer_auth(user.token.clone()) + .body(to_string(&guild_create_schema).unwrap()), + limit_type: LimitType::Global, + }; + chorus_request.deserialize_response::(user).await } /// Deletes a guild. @@ -73,10 +72,13 @@ impl Guild { user.belongs_to.borrow().urls.api, guild_id ); - let request = reqwest::Client::new() - .post(url.clone()) - .bearer_auth(user.token.clone()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new() + .post(url.clone()) + .bearer_auth(user.token.clone()), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(user).await } /// Sends a request to create a new channel in the guild. @@ -97,14 +99,7 @@ impl Guild { user: &mut UserMeta, schema: ChannelCreateSchema, ) -> ChorusResult { - Channel::_create( - &user.token, - self.id, - schema, - &mut user.limits, - &mut user.belongs_to.borrow_mut(), - ) - .await + Channel::create(user, self.id, schema).await } /// Returns a `Result` containing a vector of `Channel` structs if the request was successful, or an `ChorusLibError` if there was an error. @@ -117,20 +112,21 @@ impl Guild { /// * `limits_instance` - A mutable reference to a `Limits` struct containing the instance's rate limits. /// pub async fn channels(&self, user: &mut UserMeta) -> ChorusResult> { - let request = Client::new() - .get(format!( - "{}/guilds/{}/channels/", - user.belongs_to.borrow().urls.api, - self.id - )) - .bearer_auth(user.token()); - let result = handle_request(request, user, crate::api::limits::LimitType::Channel) - .await - .unwrap(); + let chorus_request = ChorusRequest { + request: Client::new() + .get(format!( + "{}/guilds/{}/channels/", + user.belongs_to.borrow().urls.api, + self.id + )) + .bearer_auth(user.token()), + limit_type: LimitType::Channel(self.id), + }; + let result = chorus_request.send_request(user).await?; let stringed_response = match result.text().await { Ok(value) => value, Err(e) => { - return Err(ChorusLibError::InvalidResponseError { + return Err(ChorusError::InvalidResponse { error: e.to_string(), }); } @@ -138,7 +134,7 @@ impl Guild { let _: Vec = match from_str(&stringed_response) { Ok(result) => return Ok(result), Err(e) => { - return Err(ChorusLibError::InvalidResponseError { + return Err(ChorusError::InvalidResponse { error: e.to_string(), }); } @@ -155,35 +151,19 @@ impl Guild { /// * `limits_user` - A mutable reference to a `Limits` struct containing the user's rate limits. /// * `limits_instance` - A mutable reference to a `Limits` struct containing the instance's rate limits. /// - pub async fn get(user: &mut UserMeta, guild_id: Snowflake) -> ChorusResult { - let mut belongs_to = user.belongs_to.borrow_mut(); - Guild::_get(guild_id, &user.token, &mut user.limits, &mut belongs_to).await - } - - /// For internal use. Does the same as the public get method, but does not require a second, mutable - /// borrow of `UserMeta::belongs_to`, when used in conjunction with other methods, which borrow `UserMeta::belongs_to`. - async fn _get( - guild_id: Snowflake, - token: &str, - limits_user: &mut Limits, - instance: &mut Instance, - ) -> ChorusResult { - let request = Client::new() - .get(format!("{}/guilds/{}/", instance.urls.api, guild_id)) - .bearer_auth(token); - let response = match LimitedRequester::send_request( - request, - crate::api::limits::LimitType::Guild, - instance, - limits_user, - ) - .await - { - Ok(response) => response, - Err(e) => return Err(e), + pub async fn get(guild_id: Snowflake, user: &mut UserMeta) -> ChorusResult { + let chorus_request = ChorusRequest { + request: Client::new() + .get(format!( + "{}/guilds/{}/", + user.belongs_to.borrow().urls.api, + guild_id + )) + .bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), }; - let guild: Guild = from_str(&response.text().await.unwrap()).unwrap(); - Ok(guild) + let response = chorus_request.deserialize_response::(user).await?; + Ok(response) } } @@ -207,48 +187,17 @@ impl Channel { guild_id: Snowflake, schema: ChannelCreateSchema, ) -> ChorusResult { - let mut belongs_to = user.belongs_to.borrow_mut(); - Channel::_create( - &user.token, - guild_id, - schema, - &mut user.limits, - &mut belongs_to, - ) - .await - } - - async fn _create( - token: &str, - guild_id: Snowflake, - schema: ChannelCreateSchema, - limits_user: &mut Limits, - instance: &mut Instance, - ) -> ChorusResult { - let request = Client::new() - .post(format!( - "{}/guilds/{}/channels/", - instance.urls.api, guild_id - )) - .bearer_auth(token) - .body(to_string(&schema).unwrap()); - let result = match LimitedRequester::send_request( - request, - crate::api::limits::LimitType::Guild, - instance, - limits_user, - ) - .await - { - Ok(result) => result, - Err(e) => return Err(e), + let chorus_request = ChorusRequest { + request: Client::new() + .post(format!( + "{}/guilds/{}/channels/", + user.belongs_to.borrow().urls.api, + guild_id + )) + .bearer_auth(user.token()) + .body(to_string(&schema).unwrap()), + limit_type: LimitType::Guild(guild_id), }; - match from_str::(&result.text().await.unwrap()) { - Ok(object) => Ok(object), - Err(e) => Err(ChorusLibError::RequestErrorError { - url: format!("{}/guilds/{}/channels/", instance.urls.api, guild_id), - error: e.to_string(), - }), - } + chorus_request.deserialize_response::(user).await } } diff --git a/src/api/guilds/member.rs b/src/api/guilds/member.rs index e110cfa..5fa99cd 100644 --- a/src/api/guilds/member.rs +++ b/src/api/guilds/member.rs @@ -1,9 +1,10 @@ use reqwest::Client; use crate::{ - api::{deserialize_response, handle_request_as_result}, + api::LimitType, errors::ChorusResult, instance::UserMeta, + ratelimiter::ChorusRequest, types::{self, Snowflake}, }; @@ -30,13 +31,13 @@ impl types::GuildMember { guild_id, member_id ); - let request = Client::new().get(url).bearer_auth(user.token()); - deserialize_response::( - request, - user, - crate::api::limits::LimitType::Guild, - ) - .await + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request + .deserialize_response::(user) + .await } /// Adds a role to a guild member. @@ -64,8 +65,11 @@ impl types::GuildMember { member_id, role_id ); - let request = Client::new().put(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new().put(url).bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request.handle_request_as_result(user).await } /// Removes a role from a guild member. @@ -85,7 +89,7 @@ impl types::GuildMember { guild_id: Snowflake, member_id: Snowflake, role_id: Snowflake, - ) -> Result<(), crate::errors::ChorusLibError> { + ) -> Result<(), crate::errors::ChorusError> { let url = format!( "{}/guilds/{}/members/{}/roles/{}/", user.belongs_to.borrow().urls.api, @@ -93,7 +97,10 @@ impl types::GuildMember { member_id, role_id ); - let request = Client::new().delete(url).bearer_auth(user.token()); - handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request.handle_request_as_result(user).await } } diff --git a/src/api/guilds/roles.rs b/src/api/guilds/roles.rs index 80dddad..1d1bc68 100644 --- a/src/api/guilds/roles.rs +++ b/src/api/guilds/roles.rs @@ -2,9 +2,10 @@ use reqwest::Client; use serde_json::to_string; use crate::{ - api::deserialize_response, - errors::{ChorusLibError, ChorusResult}, + api::LimitType, + errors::{ChorusError, ChorusResult}, instance::UserMeta, + ratelimiter::ChorusRequest, types::{self, RoleCreateModifySchema, RoleObject, Snowflake}, }; @@ -32,14 +33,14 @@ impl types::RoleObject { user.belongs_to.borrow().urls.api, guild_id ); - let request = Client::new().get(url).bearer_auth(user.token()); - let roles = deserialize_response::>( - request, - user, - crate::api::limits::LimitType::Guild, - ) - .await - .unwrap(); + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), + }; + let roles = chorus_request + .deserialize_response::>(user) + .await + .unwrap(); if roles.is_empty() { return Ok(None); } @@ -72,8 +73,13 @@ impl types::RoleObject { guild_id, role_id ); - let request = Client::new().get(url).bearer_auth(user.token()); - deserialize_response(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(user.token()), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request + .deserialize_response::(user) + .await } /// Creates a new role for a given guild. @@ -102,12 +108,17 @@ impl types::RoleObject { guild_id ); let body = to_string::(&role_create_schema).map_err(|e| { - ChorusLibError::FormCreationError { + ChorusError::FormCreation { error: e.to_string(), } })?; - let request = Client::new().post(url).bearer_auth(user.token()).body(body); - deserialize_response(request, user, crate::api::limits::LimitType::Guild).await + let chorus_request = ChorusRequest { + request: Client::new().post(url).bearer_auth(user.token()).body(body), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request + .deserialize_response::(user) + .await } /// Updates the position of a role in the guild's hierarchy. @@ -135,16 +146,19 @@ impl types::RoleObject { user.belongs_to.borrow().urls.api, guild_id ); - let body = to_string(&role_position_update_schema).map_err(|e| { - ChorusLibError::FormCreationError { + let body = + to_string(&role_position_update_schema).map_err(|e| ChorusError::FormCreation { error: e.to_string(), - } - })?; - let request = Client::new() - .patch(url) - .bearer_auth(user.token()) - .body(body); - deserialize_response::(request, user, crate::api::limits::LimitType::Guild) + })?; + let chorus_request = ChorusRequest { + request: Client::new() + .patch(url) + .bearer_auth(user.token()) + .body(body), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request + .deserialize_response::(user) .await } @@ -177,15 +191,19 @@ impl types::RoleObject { role_id ); let body = to_string::(&role_create_schema).map_err(|e| { - ChorusLibError::FormCreationError { + ChorusError::FormCreation { error: e.to_string(), } })?; - let request = Client::new() - .patch(url) - .bearer_auth(user.token()) - .body(body); - deserialize_response::(request, user, crate::api::limits::LimitType::Guild) + let chorus_request = ChorusRequest { + request: Client::new() + .patch(url) + .bearer_auth(user.token()) + .body(body), + limit_type: LimitType::Guild(guild_id), + }; + chorus_request + .deserialize_response::(user) .await } } diff --git a/src/api/mod.rs b/src/api/mod.rs index 56abb5f..0c02b98 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,12 +1,10 @@ pub use channels::messages::*; -pub use common::*; pub use guilds::*; pub use policies::instance::instance::*; -pub use policies::instance::limits::*; +pub use policies::instance::ratelimits::*; pub mod auth; pub mod channels; -pub mod common; pub mod guilds; pub mod policies; pub mod users; diff --git a/src/api/policies/instance/instance.rs b/src/api/policies/instance/instance.rs index 0ea3699..75f832c 100644 --- a/src/api/policies/instance/instance.rs +++ b/src/api/policies/instance/instance.rs @@ -1,7 +1,6 @@ -use reqwest::Client; use serde_json::from_str; -use crate::errors::{ChorusLibError, ChorusResult}; +use crate::errors::{ChorusError, ChorusResult}; use crate::instance::Instance; use crate::types::GeneralConfiguration; @@ -10,21 +9,21 @@ impl Instance { /// # Errors /// [`ChorusLibError`] - If the request fails. pub async fn general_configuration_schema(&self) -> ChorusResult { - let client = Client::new(); let endpoint_url = self.urls.api.clone() + "/policies/instance/"; - let request = match client.get(&endpoint_url).send().await { + let request = match self.client.get(&endpoint_url).send().await { Ok(result) => result, Err(e) => { - return Err(ChorusLibError::RequestErrorError { + return Err(ChorusError::RequestFailed { url: endpoint_url, - error: e.to_string(), + error: e, }); } }; if !request.status().as_str().starts_with('2') { - return Err(ChorusLibError::ReceivedErrorCodeError { - error_code: request.status().to_string(), + return Err(ChorusError::ReceivedErrorCode { + error_code: request.status().as_u16(), + error: request.text().await.unwrap(), }); } diff --git a/src/api/policies/instance/limits.rs b/src/api/policies/instance/limits.rs deleted file mode 100644 index 3c06d29..0000000 --- a/src/api/policies/instance/limits.rs +++ /dev/null @@ -1,499 +0,0 @@ -pub mod limits { - use std::collections::HashMap; - - use reqwest::Client; - use serde::{Deserialize, Serialize}; - use serde_json::from_str; - - #[derive(Clone, Copy, Eq, Hash, PartialEq, Debug, Default)] - pub enum LimitType { - AuthRegister, - AuthLogin, - AbsoluteMessage, - AbsoluteRegister, - #[default] - Global, - Ip, - Channel, - Error, - Guild, - Webhook, - } - - impl ToString for LimitType { - fn to_string(&self) -> String { - match self { - LimitType::AuthRegister => "AuthRegister".to_string(), - LimitType::AuthLogin => "AuthLogin".to_string(), - LimitType::AbsoluteMessage => "AbsoluteMessage".to_string(), - LimitType::AbsoluteRegister => "AbsoluteRegister".to_string(), - LimitType::Global => "Global".to_string(), - LimitType::Ip => "Ip".to_string(), - LimitType::Channel => "Channel".to_string(), - LimitType::Error => "Error".to_string(), - LimitType::Guild => "Guild".to_string(), - LimitType::Webhook => "Webhook".to_string(), - } - } - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct User { - pub maxGuilds: u64, - pub maxUsername: u64, - pub maxFriends: u64, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct Guild { - pub maxRoles: u64, - pub maxEmojis: u64, - pub maxMembers: u64, - pub maxChannels: u64, - pub maxChannelsInCategory: u64, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct Message { - pub maxCharacters: u64, - pub maxTTSCharacters: u64, - pub maxReactions: u64, - pub maxAttachmentSize: u64, - pub maxBulkDelete: u64, - pub maxEmbedDownloadSize: u64, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct Channel { - pub maxPins: u64, - pub maxTopic: u64, - pub maxWebhooks: u64, - } - - #[derive(Debug, Deserialize, Serialize)] - pub struct Rate { - pub enabled: bool, - pub ip: Window, - pub global: Window, - pub error: Window, - pub routes: Routes, - } - - #[derive(Debug, Deserialize, Serialize)] - pub struct Window { - pub count: u64, - pub window: u64, - } - - #[derive(Debug, Deserialize, Serialize)] - pub struct Routes { - pub guild: Window, - pub webhook: Window, - pub channel: Window, - pub auth: AuthRoutes, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct AuthRoutes { - pub login: Window, - pub register: Window, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct AbsoluteRate { - pub register: AbsoluteWindow, - pub sendMessage: AbsoluteWindow, - } - - #[derive(Debug, Deserialize, Serialize)] - pub struct AbsoluteWindow { - pub limit: u64, - pub window: u64, - pub enabled: bool, - } - - #[derive(Debug, Deserialize, Serialize)] - #[allow(non_snake_case)] - pub struct Config { - pub user: User, - pub guild: Guild, - pub message: Message, - pub channel: Channel, - pub rate: Rate, - pub absoluteRate: AbsoluteRate, - } - - #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] - pub struct Limit { - pub bucket: LimitType, - pub limit: u64, - pub remaining: u64, - pub reset: u64, - } - - impl std::fmt::Display for Limit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Bucket: {:?}, Limit: {}, Remaining: {}, Reset: {}", - self.bucket, self.limit, self.remaining, self.reset - ) - } - } - - impl Limit { - pub fn add_remaining(&mut self, remaining: i64) { - if remaining < 0 { - if (self.remaining as i64 + remaining) <= 0 { - self.remaining = 0; - return; - } - self.remaining -= remaining.unsigned_abs(); - return; - } - self.remaining += remaining.unsigned_abs(); - } - } - - pub struct LimitsMutRef<'a> { - pub limit_absolute_messages: &'a mut Limit, - pub limit_absolute_register: &'a mut Limit, - pub limit_auth_login: &'a mut Limit, - pub limit_auth_register: &'a mut Limit, - pub limit_ip: &'a mut Limit, - pub limit_global: &'a mut Limit, - pub limit_error: &'a mut Limit, - pub limit_guild: &'a mut Limit, - pub limit_webhook: &'a mut Limit, - pub limit_channel: &'a mut Limit, - } - - impl LimitsMutRef<'_> { - pub fn combine_mut_ref<'a>( - instance_rate_limits: &'a mut Limits, - user_rate_limits: &'a mut Limits, - ) -> LimitsMutRef<'a> { - LimitsMutRef { - limit_absolute_messages: &mut instance_rate_limits.limit_absolute_messages, - limit_absolute_register: &mut instance_rate_limits.limit_absolute_register, - limit_auth_login: &mut instance_rate_limits.limit_auth_login, - limit_auth_register: &mut instance_rate_limits.limit_auth_register, - limit_channel: &mut user_rate_limits.limit_channel, - limit_error: &mut user_rate_limits.limit_error, - limit_global: &mut instance_rate_limits.limit_global, - limit_guild: &mut user_rate_limits.limit_guild, - limit_ip: &mut instance_rate_limits.limit_ip, - limit_webhook: &mut user_rate_limits.limit_webhook, - } - } - - pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { - match limit_type { - LimitType::AbsoluteMessage => self.limit_absolute_messages, - LimitType::AbsoluteRegister => self.limit_absolute_register, - LimitType::AuthLogin => self.limit_auth_login, - LimitType::AuthRegister => self.limit_auth_register, - LimitType::Channel => self.limit_channel, - LimitType::Error => self.limit_error, - LimitType::Global => self.limit_global, - LimitType::Guild => self.limit_guild, - LimitType::Ip => self.limit_ip, - LimitType::Webhook => self.limit_webhook, - } - } - - pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { - match limit_type { - LimitType::AbsoluteMessage => self.limit_absolute_messages, - LimitType::AbsoluteRegister => self.limit_absolute_register, - LimitType::AuthLogin => self.limit_auth_login, - LimitType::AuthRegister => self.limit_auth_register, - LimitType::Channel => self.limit_channel, - LimitType::Error => self.limit_error, - LimitType::Global => self.limit_global, - LimitType::Guild => self.limit_guild, - LimitType::Ip => self.limit_ip, - LimitType::Webhook => self.limit_webhook, - } - } - } - - #[derive(Debug, Clone, Default)] - pub struct Limits { - pub limit_absolute_messages: Limit, - pub limit_absolute_register: Limit, - pub limit_auth_login: Limit, - pub limit_auth_register: Limit, - pub limit_ip: Limit, - pub limit_global: Limit, - pub limit_error: Limit, - pub limit_guild: Limit, - pub limit_webhook: Limit, - pub limit_channel: Limit, - } - - impl Limits { - pub fn combine(instance_rate_limits: &Limits, user_rate_limits: &Limits) -> Limits { - Limits { - limit_absolute_messages: instance_rate_limits.limit_absolute_messages, - limit_absolute_register: instance_rate_limits.limit_absolute_register, - limit_auth_login: instance_rate_limits.limit_auth_login, - limit_auth_register: instance_rate_limits.limit_auth_register, - limit_channel: user_rate_limits.limit_channel, - limit_error: user_rate_limits.limit_error, - limit_global: instance_rate_limits.limit_global, - limit_guild: user_rate_limits.limit_guild, - limit_ip: instance_rate_limits.limit_ip, - limit_webhook: user_rate_limits.limit_webhook, - } - } - - pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { - match limit_type { - LimitType::AbsoluteMessage => &self.limit_absolute_messages, - LimitType::AbsoluteRegister => &self.limit_absolute_register, - LimitType::AuthLogin => &self.limit_auth_login, - LimitType::AuthRegister => &self.limit_auth_register, - LimitType::Channel => &self.limit_channel, - LimitType::Error => &self.limit_error, - LimitType::Global => &self.limit_global, - LimitType::Guild => &self.limit_guild, - LimitType::Ip => &self.limit_ip, - LimitType::Webhook => &self.limit_webhook, - } - } - - pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { - match limit_type { - LimitType::AbsoluteMessage => &mut self.limit_absolute_messages, - LimitType::AbsoluteRegister => &mut self.limit_absolute_register, - LimitType::AuthLogin => &mut self.limit_auth_login, - LimitType::AuthRegister => &mut self.limit_auth_register, - LimitType::Channel => &mut self.limit_channel, - LimitType::Error => &mut self.limit_error, - LimitType::Global => &mut self.limit_global, - LimitType::Guild => &mut self.limit_guild, - LimitType::Ip => &mut self.limit_ip, - LimitType::Webhook => &mut self.limit_webhook, - } - } - - pub fn to_hash_map(&self) -> HashMap { - let mut map: HashMap = HashMap::new(); - map.insert(LimitType::AbsoluteMessage, self.limit_absolute_messages); - map.insert(LimitType::AbsoluteRegister, self.limit_absolute_register); - map.insert(LimitType::AuthLogin, self.limit_auth_login); - map.insert(LimitType::AuthRegister, self.limit_auth_register); - map.insert(LimitType::Ip, self.limit_ip); - map.insert(LimitType::Global, self.limit_global); - map.insert(LimitType::Error, self.limit_error); - map.insert(LimitType::Guild, self.limit_guild); - map.insert(LimitType::Webhook, self.limit_webhook); - map.insert(LimitType::Channel, self.limit_channel); - map - } - - pub fn get_as_mut(&mut self) -> &mut Limits { - self - } - - /// check_limits uses the API to get the current request limits of the instance. - /// It returns a `Limits` struct containing all the limits. - /// If the rate limit is disabled, then the limit is set to `u64::MAX`. - /// # Errors - /// This function will panic if the request fails or if the response body cannot be parsed. - /// TODO: Change this to return a Result and handle the errors properly. - pub async fn check_limits(api_url: String) -> Limits { - let client = Client::new(); - let url_parsed = crate::UrlBundle::parse_url(api_url) + "/policies/instance/limits"; - let result = client - .get(url_parsed) - .send() - .await - .unwrap_or_else(|e| panic!("An error occured while performing the request: {}", e)) - .text() - .await - .unwrap_or_else(|e| { - panic!( - "An error occured while parsing the request body string: {}", - e - ) - }); - let config: Config = from_str(&result).unwrap(); - // If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX - let mut limits: Limits; - if !config.rate.enabled { - limits = Limits { - limit_absolute_messages: Limit { - bucket: LimitType::AbsoluteMessage, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_absolute_register: Limit { - bucket: LimitType::AbsoluteRegister, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_auth_login: Limit { - bucket: LimitType::AuthLogin, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_auth_register: Limit { - bucket: LimitType::AuthRegister, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_ip: Limit { - bucket: LimitType::Ip, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_global: Limit { - bucket: LimitType::Global, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_error: Limit { - bucket: LimitType::Error, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_guild: Limit { - bucket: LimitType::Guild, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_webhook: Limit { - bucket: LimitType::Webhook, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - limit_channel: Limit { - bucket: LimitType::Channel, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - }; - } else { - limits = Limits { - limit_absolute_messages: Limit { - bucket: LimitType::AbsoluteMessage, - limit: config.absoluteRate.sendMessage.limit, - remaining: config.absoluteRate.sendMessage.limit, - reset: config.absoluteRate.sendMessage.window, - }, - limit_absolute_register: Limit { - bucket: LimitType::AbsoluteRegister, - limit: config.absoluteRate.register.limit, - remaining: config.absoluteRate.register.limit, - reset: config.absoluteRate.register.window, - }, - limit_auth_login: Limit { - bucket: LimitType::AuthLogin, - limit: config.rate.routes.auth.login.count, - remaining: config.rate.routes.auth.login.count, - reset: config.rate.routes.auth.login.window, - }, - limit_auth_register: Limit { - bucket: LimitType::AuthRegister, - limit: config.rate.routes.auth.register.count, - remaining: config.rate.routes.auth.register.count, - reset: config.rate.routes.auth.register.window, - }, - limit_ip: Limit { - bucket: LimitType::Ip, - limit: config.rate.ip.count, - remaining: config.rate.ip.count, - reset: config.rate.ip.window, - }, - limit_global: Limit { - bucket: LimitType::Global, - limit: config.rate.global.count, - remaining: config.rate.global.count, - reset: config.rate.global.window, - }, - limit_error: Limit { - bucket: LimitType::Error, - limit: config.rate.error.count, - remaining: config.rate.error.count, - reset: config.rate.error.window, - }, - limit_guild: Limit { - bucket: LimitType::Guild, - limit: config.rate.routes.guild.count, - remaining: config.rate.routes.guild.count, - reset: config.rate.routes.guild.window, - }, - limit_webhook: Limit { - bucket: LimitType::Webhook, - limit: config.rate.routes.webhook.count, - remaining: config.rate.routes.webhook.count, - reset: config.rate.routes.webhook.window, - }, - limit_channel: Limit { - bucket: LimitType::Channel, - limit: config.rate.routes.channel.count, - remaining: config.rate.routes.channel.count, - reset: config.rate.routes.channel.window, - }, - }; - } - - if !config.absoluteRate.register.enabled { - limits.limit_absolute_register = Limit { - bucket: LimitType::AbsoluteRegister, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }; - } - - if !config.absoluteRate.sendMessage.enabled { - limits.limit_absolute_messages = Limit { - bucket: LimitType::AbsoluteMessage, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }; - } - - limits - } - } -} - -#[cfg(test)] -mod instance_limits { - use crate::api::limits::{Limit, LimitType}; - - #[test] - fn limit_below_zero() { - let mut limit = Limit { - bucket: LimitType::AbsoluteMessage, - limit: 0, - remaining: 1, - reset: 0, - }; - limit.add_remaining(-2); - assert_eq!(0_u64, limit.remaining); - limit.add_remaining(-2123123); - assert_eq!(0_u64, limit.remaining); - } -} diff --git a/src/api/policies/instance/mod.rs b/src/api/policies/instance/mod.rs index 7be9605..0a1f245 100644 --- a/src/api/policies/instance/mod.rs +++ b/src/api/policies/instance/mod.rs @@ -1,5 +1,5 @@ pub use instance::*; -pub use limits::*; +pub use ratelimits::*; pub mod instance; -pub mod limits; +pub mod ratelimits; diff --git a/src/api/policies/instance/ratelimits.rs b/src/api/policies/instance/ratelimits.rs new file mode 100644 index 0000000..125af32 --- /dev/null +++ b/src/api/policies/instance/ratelimits.rs @@ -0,0 +1,35 @@ +use std::hash::Hash; + +use crate::types::Snowflake; + +/// The different types of ratelimits that can be applied to a request. Includes "Baseline"-variants +/// for when the Snowflake is not yet known. +/// See for more information. +#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, Hash)] +pub enum LimitType { + AuthRegister, + AuthLogin, + #[default] + Global, + Ip, + Channel(Snowflake), + ChannelBaseline, + Error, + Guild(Snowflake), + GuildBaseline, + Webhook(Snowflake), + WebhookBaseline, +} + +/// A struct that represents the current ratelimits, either instance-wide or user-wide. +/// Unlike [`RateLimits`], this struct shows the current ratelimits, not the rate limit +/// configuration for the instance. +/// See for more information. +#[derive(Debug, Clone)] +pub struct Limit { + pub bucket: LimitType, + pub limit: u64, + pub remaining: u64, + pub reset: u64, + pub window: u64, +} diff --git a/src/api/policies/mod.rs b/src/api/policies/mod.rs index 3e25d8c..d0c29f1 100644 --- a/src/api/policies/mod.rs +++ b/src/api/policies/mod.rs @@ -1,3 +1,3 @@ -pub use instance::limits::*; +pub use instance::ratelimits::*; pub mod instance; diff --git a/src/api/users/relationships.rs b/src/api/users/relationships.rs index 36cabd2..39c75d8 100644 --- a/src/api/users/relationships.rs +++ b/src/api/users/relationships.rs @@ -2,9 +2,10 @@ use reqwest::Client; use serde_json::to_string; use crate::{ - api::{deserialize_response, handle_request_as_result}, + api::LimitType, errors::ChorusResult, instance::UserMeta, + ratelimiter::ChorusRequest, types::{self, CreateUserRelationshipSchema, RelationshipType, Snowflake}, }; @@ -26,13 +27,13 @@ impl UserMeta { self.belongs_to.borrow().urls.api, user_id ); - let request = Client::new().get(url).bearer_auth(self.token()); - deserialize_response::>( - request, - self, - crate::api::limits::LimitType::Global, - ) - .await + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(self.token()), + limit_type: LimitType::Global, + }; + chorus_request + .deserialize_response::>(self) + .await } /// Retrieves the authenticated user's relationships. @@ -44,13 +45,13 @@ impl UserMeta { "{}/users/@me/relationships/", self.belongs_to.borrow().urls.api ); - let request = Client::new().get(url).bearer_auth(self.token()); - deserialize_response::>( - request, - self, - crate::api::limits::LimitType::Global, - ) - .await + let chorus_request = ChorusRequest { + request: Client::new().get(url).bearer_auth(self.token()), + limit_type: LimitType::Global, + }; + chorus_request + .deserialize_response::>(self) + .await } /// Sends a friend request to a user. @@ -70,8 +71,11 @@ impl UserMeta { self.belongs_to.borrow().urls.api ); let body = to_string(&schema).unwrap(); - let request = Client::new().post(url).bearer_auth(self.token()).body(body); - handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await + let chorus_request = ChorusRequest { + request: Client::new().post(url).bearer_auth(self.token()).body(body), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(self).await } /// Modifies the relationship between the authenticated user and the specified user. @@ -96,10 +100,13 @@ impl UserMeta { let api_url = self.belongs_to.borrow().urls.api.clone(); match relationship_type { RelationshipType::None => { - let request = Client::new() - .delete(format!("{}/users/@me/relationships/{}/", api_url, user_id)) - .bearer_auth(self.token()); - handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await + let chorus_request = ChorusRequest { + request: Client::new() + .delete(format!("{}/users/@me/relationships/{}/", api_url, user_id)) + .bearer_auth(self.token()), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(self).await } RelationshipType::Friends | RelationshipType::Incoming | RelationshipType::Outgoing => { let body = CreateUserRelationshipSchema { @@ -107,11 +114,14 @@ impl UserMeta { from_friend_suggestion: None, friend_token: None, }; - let request = Client::new() - .put(format!("{}/users/@me/relationships/{}/", api_url, user_id)) - .bearer_auth(self.token()) - .body(to_string(&body).unwrap()); - handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await + let chorus_request = ChorusRequest { + request: Client::new() + .put(format!("{}/users/@me/relationships/{}/", api_url, user_id)) + .bearer_auth(self.token()) + .body(to_string(&body).unwrap()), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(self).await } RelationshipType::Blocked => { let body = CreateUserRelationshipSchema { @@ -119,11 +129,14 @@ impl UserMeta { from_friend_suggestion: None, friend_token: None, }; - let request = Client::new() - .put(format!("{}/users/@me/relationships/{}/", api_url, user_id)) - .bearer_auth(self.token()) - .body(to_string(&body).unwrap()); - handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await + let chorus_request = ChorusRequest { + request: Client::new() + .put(format!("{}/users/@me/relationships/{}/", api_url, user_id)) + .bearer_auth(self.token()) + .body(to_string(&body).unwrap()), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(self).await } RelationshipType::Suggestion | RelationshipType::Implicit => Ok(()), } @@ -143,7 +156,10 @@ impl UserMeta { self.belongs_to.borrow().urls.api, user_id ); - let request = Client::new().delete(url).bearer_auth(self.token()); - handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await + let chorus_request = ChorusRequest { + request: Client::new().delete(url).bearer_auth(self.token()), + limit_type: LimitType::Global, + }; + chorus_request.handle_request_as_result(self).await } } diff --git a/src/api/users/users.rs b/src/api/users/users.rs index cd777dc..bd46a36 100644 --- a/src/api/users/users.rs +++ b/src/api/users/users.rs @@ -1,11 +1,13 @@ +use std::{cell::RefCell, rc::Rc}; + use reqwest::Client; use serde_json::to_string; use crate::{ - api::{deserialize_response, handle_request_as_result}, - errors::{ChorusLibError, ChorusResult}, + api::LimitType, + errors::{ChorusError, ChorusResult}, instance::{Instance, UserMeta}, - limit::LimitedRequester, + ratelimiter::ChorusRequest, types::{User, UserModifySchema, UserSettings}, }; @@ -48,16 +50,20 @@ impl UserMeta { || modify_schema.email.is_some() || modify_schema.code.is_some() { - return Err(ChorusLibError::PasswordRequiredError); + return Err(ChorusError::PasswordRequired); } let request = Client::new() .patch(format!("{}/users/@me/", self.belongs_to.borrow().urls.api)) .body(to_string(&modify_schema).unwrap()) .bearer_auth(self.token()); - let user_updated = - deserialize_response::(request, self, crate::api::limits::LimitType::Ip) - .await - .unwrap(); + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + let user_updated = chorus_request + .deserialize_response::(self) + .await + .unwrap(); let _ = std::mem::replace(&mut self.object, user_updated.clone()); Ok(user_updated) } @@ -78,43 +84,28 @@ impl UserMeta { self.belongs_to.borrow().urls.api )) .bearer_auth(self.token()); - handle_request_as_result(request, &mut self, crate::api::limits::LimitType::Ip).await + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + chorus_request.handle_request_as_result(&mut self).await } } impl User { pub async fn get(user: &mut UserMeta, id: Option<&String>) -> ChorusResult { - let mut belongs_to = user.belongs_to.borrow_mut(); - User::_get( - &user.token(), - &format!("{}", belongs_to.urls.api), - &mut belongs_to, - id, - ) - .await - } - - async fn _get( - token: &str, - url_api: &str, - instance: &mut Instance, - id: Option<&String>, - ) -> ChorusResult { + let url_api = user.belongs_to.borrow().urls.api.clone(); let url = if id.is_none() { format!("{}/users/@me/", url_api) } else { format!("{}/users/{}", url_api, id.unwrap()) }; - let request = reqwest::Client::new().get(url).bearer_auth(token); - let mut cloned_limits = instance.limits.clone(); - match LimitedRequester::send_request( + let request = reqwest::Client::new().get(url).bearer_auth(user.token()); + let chorus_request = ChorusRequest { request, - crate::api::limits::LimitType::Ip, - instance, - &mut cloned_limits, - ) - .await - { + limit_type: LimitType::Global, + }; + match chorus_request.send_request(user).await { Ok(result) => { let result_text = result.text().await.unwrap(); Ok(serde_json::from_str::(&result_text).unwrap()) @@ -131,18 +122,20 @@ impl User { let request: reqwest::RequestBuilder = Client::new() .get(format!("{}/users/@me/settings/", url_api)) .bearer_auth(token); - let mut cloned_limits = instance.limits.clone(); - match LimitedRequester::send_request( + let mut user = UserMeta::shell(Rc::new(RefCell::new(instance.clone())), token.clone()); + let chorus_request = ChorusRequest { request, - crate::api::limits::LimitType::Ip, - instance, - &mut cloned_limits, - ) - .await - { + limit_type: LimitType::Global, + }; + let result = match chorus_request.send_request(&mut user).await { Ok(result) => Ok(serde_json::from_str(&result.text().await.unwrap()).unwrap()), Err(e) => Err(e), + }; + if instance.limits_information.is_some() { + instance.limits_information.as_mut().unwrap().ratelimits = + user.belongs_to.borrow().clone_limits_if_some().unwrap(); } + result } } @@ -158,6 +151,12 @@ impl Instance { This function is a wrapper around [`User::get`]. */ pub async fn get_user(&mut self, token: String, id: Option<&String>) -> ChorusResult { - User::_get(&token, &self.urls.api.clone(), self, id).await + let mut user = UserMeta::shell(Rc::new(RefCell::new(self.clone())), token); + let result = User::get(&mut user, id).await; + if self.limits_information.is_some() { + self.limits_information.as_mut().unwrap().ratelimits = + user.belongs_to.borrow().clone_limits_if_some().unwrap(); + } + result } } diff --git a/src/errors.rs b/src/errors.rs index d8bda9c..bf08772 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,39 +1,50 @@ use custom_error::custom_error; +use reqwest::Error; custom_error! { #[derive(PartialEq, Eq)] - pub FieldFormatError - PasswordError = "Password must be between 1 and 72 characters.", - UsernameError = "Username must be between 2 and 32 characters.", - ConsentError = "Consent must be 'true' to register.", - EmailError = "The provided email address is in an invalid format.", + pub RegistrationError + Consent = "Consent must be 'true' to register.", } -pub type ChorusResult = std::result::Result; +pub type ChorusResult = std::result::Result; custom_error! { - #[derive(PartialEq, Eq)] - pub ChorusLibError + pub ChorusError + /// Server did not respond. NoResponse = "Did not receive a response from the Server.", - RequestErrorError{url:String, error:String} = "An error occured while trying to GET from {url}: {error}", - ReceivedErrorCodeError{error_code:String} = "Received the following error code while requesting from the route: {error_code}", - CantGetInfoError{error:String} = "Something seems to be wrong with the instance. Cannot get information about the instance: {error}", - InvalidFormBodyError{error_type: String, error:String} = "The server responded with: {error_type}: {error}", + /// Reqwest returned an Error instead of a Response object. + RequestFailed{url:String, error: Error} = "An error occured while trying to GET from {url}: {error}", + /// Response received, however, it was not of the successful responses type. Used when no other, special case applies. + ReceivedErrorCode{error_code: u16, error: String} = "Received the following error code while requesting from the route: {error_code}", + /// Used when there is likely something wrong with the instance, the request was directed to. + CantGetInformation{error:String} = "Something seems to be wrong with the instance. Cannot get information about the instance: {error}", + /// The requests form body was malformed/invalid. + InvalidFormBody{error_type: String, error:String} = "The server responded with: {error_type}: {error}", + /// The request has not been processed by the server due to a relevant rate limit bucket being exhausted. RateLimited{bucket:String} = "Ratelimited on Bucket {bucket}", - MultipartCreationError{error: String} = "Got an error whilst creating the form: {error}", - FormCreationError{error: String} = "Got an error whilst creating the form: {error}", + /// The multipart form could not be created. + MultipartCreation{error: String} = "Got an error whilst creating the form: {error}", + /// The regular form could not be created. + FormCreation{error: String} = "Got an error whilst creating the form: {error}", + /// The token is invalid. TokenExpired = "Token expired, invalid or not found.", + /// No permission NoPermission = "You do not have the permissions needed to perform this action.", + /// Resource not found NotFound{error: String} = "The provided resource hasn't been found: {error}", - PasswordRequiredError = "You need to provide your current password to authenticate for this action.", - InvalidResponseError{error: String} = "The response is malformed and cannot be processed. Error: {error}", - InvalidArgumentsError{error: String} = "Invalid arguments were provided. Error: {error}" + /// Used when you, for example, try to change your spacebar account password without providing your old password for verification. + PasswordRequired = "You need to provide your current password to authenticate for this action.", + /// Malformed or unexpected response. + InvalidResponse{error: String} = "The response is malformed and cannot be processed. Error: {error}", + /// Invalid, insufficient or too many arguments provided. + InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}" } custom_error! { #[derive(PartialEq, Eq)] pub ObserverError - AlreadySubscribedError = "Each event can only be subscribed to once." + AlreadySubscribed = "Each event can only be subscribed to once." } custom_error! { @@ -45,25 +56,25 @@ custom_error! { #[derive(PartialEq, Eq)] pub GatewayError // Errors we have received from the gateway - UnknownError = "We're not sure what went wrong. Try reconnecting?", - UnknownOpcodeError = "You sent an invalid Gateway opcode or an invalid payload for an opcode", - DecodeError = "Gateway server couldn't decode payload", - NotAuthenticatedError = "You sent a payload prior to identifying", - AuthenticationFailedError = "The account token sent with your identify payload is invalid", - AlreadyAuthenticatedError = "You've already identified, no need to reauthenticate", - InvalidSequenceNumberError = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session", - RateLimitedError = "You are being rate limited!", - SessionTimedOutError = "Your session timed out. Reconnect and start a new one", - InvalidShardError = "You sent us an invalid shard when identifying", - ShardingRequiredError = "The session would have handled too many guilds - you are required to shard your connection in order to connect", - InvalidAPIVersionError = "You sent an invalid Gateway version", - InvalidIntentsError = "You sent an invalid intent", - DisallowedIntentsError = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for", + Unknown = "We're not sure what went wrong. Try reconnecting?", + UnknownOpcode = "You sent an invalid Gateway opcode or an invalid payload for an opcode", + Decode = "Gateway server couldn't decode payload", + NotAuthenticated = "You sent a payload prior to identifying", + AuthenticationFailed = "The account token sent with your identify payload is invalid", + AlreadyAuthenticated = "You've already identified, no need to reauthenticate", + InvalidSequenceNumber = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session", + RateLimited = "You are being rate limited!", + SessionTimedOut = "Your session timed out. Reconnect and start a new one", + InvalidShard = "You sent us an invalid shard when identifying", + ShardingRequired = "The session would have handled too many guilds - you are required to shard your connection in order to connect", + InvalidAPIVersion = "You sent an invalid Gateway version", + InvalidIntents = "You sent an invalid intent", + DisallowedIntents = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for", // Errors when initiating a gateway connection - CannotConnectError{error: String} = "Cannot connect due to a tungstenite error: {error}", - NonHelloOnInitiateError{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", + CannotConnect{error: String} = "Cannot connect due to a tungstenite error: {error}", + NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", // Other misc errors - UnexpectedOpcodeReceivedError{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}", + UnexpectedOpcodeReceived{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}", } diff --git a/src/gateway.rs b/src/gateway.rs index 87e1589..ed96aac 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -8,7 +8,7 @@ use futures_util::stream::SplitSink; use futures_util::stream::SplitStream; use futures_util::SinkExt; use futures_util::StreamExt; -use log::{debug, info, trace, warn}; +use log::{info, trace, warn}; use native_tls::TlsConnector; use tokio::net::TcpStream; use tokio::sync::mpsc::error::TryRecvError; @@ -95,25 +95,21 @@ impl GatewayMessage { let processed_content = content.to_lowercase().replace('.', ""); match processed_content.as_str() { - "unknown error" | "4000" => Some(GatewayError::UnknownError), - "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcodeError), - "decode error" | "error while decoding payload" | "4002" => { - Some(GatewayError::DecodeError) - } - "not authenticated" | "4003" => Some(GatewayError::NotAuthenticatedError), - "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailedError), - "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticatedError), - "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumberError), - "rate limited" | "4008" => Some(GatewayError::RateLimitedError), - "session timed out" | "4009" => Some(GatewayError::SessionTimedOutError), - "invalid shard" | "4010" => Some(GatewayError::InvalidShardError), - "sharding required" | "4011" => Some(GatewayError::ShardingRequiredError), - "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersionError), - "invalid intent(s)" | "invalid intent" | "4013" => { - Some(GatewayError::InvalidIntentsError) - } + "unknown error" | "4000" => Some(GatewayError::Unknown), + "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), + "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), + "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), + "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), + "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), + "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), + "rate limited" | "4008" => Some(GatewayError::RateLimited), + "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), + "invalid shard" | "4010" => Some(GatewayError::InvalidShard), + "sharding required" | "4011" => Some(GatewayError::ShardingRequired), + "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), + "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), "disallowed intent(s)" | "disallowed intents" | "4014" => { - Some(GatewayError::DisallowedIntentsError) + Some(GatewayError::DisallowedIntents) } _ => None, } @@ -294,7 +290,7 @@ impl Gateway { { Ok(websocket_stream) => websocket_stream, Err(e) => { - return Err(GatewayError::CannotConnectError { + return Err(GatewayError::CannotConnect { error: e.to_string(), }) } @@ -314,7 +310,7 @@ impl Gateway { serde_json::from_str(msg.to_text().unwrap()).unwrap(); if gateway_payload.op_code != GATEWAY_HELLO { - return Err(GatewayError::NonHelloOnInitiateError { + return Err(GatewayError::NonHelloOnInitiate { opcode: gateway_payload.op_code, }); } @@ -1493,7 +1489,7 @@ impl Gateway { | GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_CALL_SYNC | GATEWAY_LAZY_REQUEST => { - let error = GatewayError::UnexpectedOpcodeReceivedError { + let error = GatewayError::UnexpectedOpcodeReceived { opcode: gateway_payload.op_code, }; Err::<(), GatewayError>(error).unwrap(); @@ -1521,6 +1517,7 @@ impl Gateway { } /// Handles sending heartbeats to the gateway in another thread +#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used struct HeartbeatHandler { /// The heartbeat interval in milliseconds pub heartbeat_interval: u128, diff --git a/src/instance.rs b/src/instance.rs index 2477217..15c513c 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,26 +1,36 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::fmt; use std::rc::Rc; use reqwest::Client; use serde::{Deserialize, Serialize}; -use crate::api::limits::Limits; -use crate::errors::{ChorusLibError, ChorusResult, FieldFormatError}; +use crate::api::{Limit, LimitType}; +use crate::errors::ChorusResult; +use crate::ratelimiter::ChorusRequest; +use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::{GeneralConfiguration, User, UserSettings}; use crate::UrlBundle; #[derive(Debug, Clone)] /** The [`Instance`] what you will be using to perform all sorts of actions on the Spacebar server. +If `limits_information` is `None`, then the instance will not be rate limited. */ pub struct Instance { pub urls: UrlBundle, pub instance_info: GeneralConfiguration, - pub limits: Limits, + pub limits_information: Option, pub client: Client, } +#[derive(Debug, Clone)] +pub struct LimitsInformation { + pub ratelimits: HashMap, + pub configuration: RateLimits, +} + impl Instance { /// Creates a new [`Instance`]. /// # Arguments @@ -28,24 +38,43 @@ impl Instance { /// * `requester` - The [`LimitedRequester`] that will be used to make requests to the Spacebar server. /// # Errors /// * [`InstanceError`] - If the instance cannot be created. - pub async fn new(urls: UrlBundle) -> ChorusResult { + pub async fn new(urls: UrlBundle, limited: bool) -> ChorusResult { + let limits_information; + if limited { + let limits_configuration = + Some(ChorusRequest::get_limits_config(&urls.api).await?.rate); + let limits = Some(ChorusRequest::limits_config_to_hashmap( + limits_configuration.as_ref().unwrap(), + )); + limits_information = Some(LimitsInformation { + ratelimits: limits.unwrap(), + configuration: limits_configuration.unwrap(), + }); + } else { + limits_information = None; + } let mut instance = Instance { urls: urls.clone(), // Will be overwritten in the next step instance_info: GeneralConfiguration::default(), - limits: Limits::check_limits(urls.api).await, + limits_information, client: Client::new(), }; instance.instance_info = match instance.general_configuration_schema().await { Ok(schema) => schema, Err(e) => { - return Err(ChorusLibError::CantGetInfoError { - error: e.to_string(), - }); + log::warn!("Could not get instance configuration schema: {}", e); + GeneralConfiguration::default() } }; Ok(instance) } + pub(crate) fn clone_limits_if_some(&self) -> Option> { + if self.limits_information.is_some() { + return Some(self.limits_information.as_ref().unwrap().ratelimits.clone()); + } + None + } } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -59,30 +88,11 @@ impl fmt::Display for Token { } } -#[derive(Debug, PartialEq, Eq)] -pub struct Username { - pub username: String, -} - -impl Username { - /// Creates a new [`Username`]. - /// # Arguments - /// * `username` - The username that will be used to create the [`Username`]. - /// # Errors - /// * [`UsernameFormatError`] - If the username is not between 2 and 32 characters. - pub fn new(username: String) -> Result { - if username.len() < 2 || username.len() > 32 { - return Err(FieldFormatError::UsernameError); - } - Ok(Username { username }) - } -} - #[derive(Debug)] pub struct UserMeta { pub belongs_to: Rc>, pub token: String, - pub limits: Limits, + pub limits: Option>, pub settings: UserSettings, pub object: User, } @@ -99,7 +109,7 @@ impl UserMeta { pub fn new( belongs_to: Rc>, token: String, - limits: Limits, + limits: Option>, settings: UserSettings, object: User, ) -> UserMeta { @@ -111,4 +121,24 @@ impl UserMeta { object, } } + + /// Creates a new 'shell' of a user. The user does not exist as an object, and exists so that you have + /// a UserMeta object to make Rate Limited requests with. This is useful in scenarios like + /// registering or logging in to the Instance, where you do not yet have a User object, but still + /// need to make a RateLimited request. + pub(crate) fn shell(instance: Rc>, token: String) -> UserMeta { + let settings = UserSettings::default(); + let object = User::default(); + UserMeta { + belongs_to: instance.clone(), + token, + limits: instance + .borrow() + .limits_information + .as_ref() + .map(|info| info.ratelimits.clone()), + settings, + object, + } + } } diff --git a/src/lib.rs b/src/lib.rs index bfda613..77425b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ pub mod gateway; #[cfg(feature = "client")] pub mod instance; #[cfg(feature = "client")] -pub mod limit; +pub mod ratelimiter; pub mod types; #[cfg(feature = "client")] pub mod voice; diff --git a/src/limit.rs b/src/limit.rs deleted file mode 100644 index 4415e0c..0000000 --- a/src/limit.rs +++ /dev/null @@ -1,304 +0,0 @@ -use reqwest::{RequestBuilder, Response}; - -use crate::{ - api::limits::{Limit, LimitType, Limits, LimitsMutRef}, - errors::{ChorusLibError, ChorusResult}, - instance::Instance, -}; - -#[derive(Debug)] -pub struct LimitedRequester; - -impl LimitedRequester { - /// Checks if a request can be sent without hitting API rate limits and sends it, if true. - /// Will automatically update the rate limits of the LimitedRequester the request has been - /// sent with. - /// - /// # Arguments - /// - /// * `request`: A `RequestBuilder` that contains a request ready to be sent. Unfinished or - /// invalid requests will result in the method panicing. - /// * `limit_type`: Because this library does not yet implement a way to check for which rate - /// limit will be used when the request gets send, you will have to specify this manually using - /// a `LimitType` enum. - /// - /// # Returns - /// - /// * `Response`: The `Response` gotten from sending the request to the server. This will be - /// returned if the Request was built and send successfully. Is wrapped in an `Option`. - /// * `None`: `None` will be returned if the rate limit has been hit, and the request could - /// therefore not have been sent. - /// - /// # Errors - /// - /// This method will error if: - /// - /// * The request does not return a success status code (200-299) - /// * The supplied `RequestBuilder` contains invalid or incomplete information - /// * There has been an error with processing (unwrapping) the `Response` - /// * The call to `update_limits` yielded errors. Read the methods' Errors section for more - /// information. - pub async fn send_request( - request: RequestBuilder, - limit_type: LimitType, - instance: &mut Instance, - user_rate_limits: &mut Limits, - ) -> ChorusResult { - if LimitedRequester::can_send_request(limit_type, &instance.limits, user_rate_limits) { - let built_request = match request.build() { - Ok(request) => request, - Err(e) => { - return Err(ChorusLibError::RequestErrorError { - url: "".to_string(), - error: e.to_string(), - }); - } - }; - let result = instance.client.execute(built_request).await; - let response = match result { - Ok(is_response) => is_response, - Err(e) => { - return Err(ChorusLibError::ReceivedErrorCodeError { - error_code: e.to_string(), - }); - } - }; - LimitedRequester::update_limits( - &response, - limit_type, - &mut instance.limits, - user_rate_limits, - ); - if !response.status().is_success() { - match response.status().as_u16() { - 401 => Err(ChorusLibError::TokenExpired), - 403 => Err(ChorusLibError::TokenExpired), - _ => Err(ChorusLibError::ReceivedErrorCodeError { - error_code: response.status().as_str().to_string(), - }), - } - } else { - Ok(response) - } - } else { - Err(ChorusLibError::RateLimited { - bucket: limit_type.to_string(), - }) - } - } - - fn update_limit_entry(entry: &mut Limit, reset: u64, remaining: u64, limit: u64) { - if reset != entry.reset { - entry.reset = reset; - entry.remaining = limit; - entry.limit = limit; - } else { - entry.remaining = remaining; - entry.limit = limit; - } - } - - fn can_send_request( - limit_type: LimitType, - instance_rate_limits: &Limits, - user_rate_limits: &Limits, - ) -> bool { - // Check if all of the limits in this vec have at least one remaining request - - let rate_limits = Limits::combine(instance_rate_limits, user_rate_limits); - - let constant_limits: Vec<&LimitType> = [ - &LimitType::Error, - &LimitType::Global, - &LimitType::Ip, - &limit_type, - ] - .to_vec(); - for limit in constant_limits.iter() { - match rate_limits.to_hash_map().get(limit) { - Some(limit) => { - if limit.remaining == 0 { - return false; - } - // AbsoluteRegister and AuthRegister can cancel each other out. - if limit.bucket == LimitType::AbsoluteRegister - && rate_limits - .to_hash_map() - .get(&LimitType::AuthRegister) - .unwrap() - .remaining - == 0 - { - return false; - } - if limit.bucket == LimitType::AuthRegister - && rate_limits - .to_hash_map() - .get(&LimitType::AbsoluteRegister) - .unwrap() - .remaining - == 0 - { - return false; - } - } - None => return false, - } - } - true - } - - fn update_limits( - response: &Response, - limit_type: LimitType, - instance_rate_limits: &mut Limits, - user_rate_limits: &mut Limits, - ) { - let mut rate_limits = LimitsMutRef::combine_mut_ref(instance_rate_limits, user_rate_limits); - - let remaining = match response.headers().get("X-RateLimit-Remaining") { - Some(remaining) => remaining.to_str().unwrap().parse::().unwrap(), - None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1, - }; - let limit = match response.headers().get("X-RateLimit-Limit") { - Some(limit) => limit.to_str().unwrap().parse::().unwrap(), - None => rate_limits.get_limit_mut_ref(&limit_type).limit, - }; - let reset = match response.headers().get("X-RateLimit-Reset") { - Some(reset) => reset.to_str().unwrap().parse::().unwrap(), - None => rate_limits.get_limit_mut_ref(&limit_type).reset, - }; - - let status = response.status(); - let status_str = status.as_str(); - - if status_str.starts_with('4') { - rate_limits - .get_limit_mut_ref(&LimitType::Error) - .add_remaining(-1); - } - - rate_limits - .get_limit_mut_ref(&LimitType::Global) - .add_remaining(-1); - - rate_limits - .get_limit_mut_ref(&LimitType::Ip) - .add_remaining(-1); - - match limit_type { - LimitType::Error => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Error); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::Global => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Global); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::Ip => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::AuthLogin => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::AbsoluteRegister => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - // AbsoluteRegister and AuthRegister both need to be updated, if a Register event - // happens. - rate_limits - .get_limit_mut_ref(&LimitType::AuthRegister) - .remaining -= 1; - } - LimitType::AuthRegister => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - // AbsoluteRegister and AuthRegister both need to be updated, if a Register event - // happens. - rate_limits - .get_limit_mut_ref(&LimitType::AbsoluteRegister) - .remaining -= 1; - } - LimitType::AbsoluteMessage => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::Channel => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::Guild => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - LimitType::Webhook => { - let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook); - LimitedRequester::update_limit_entry(entry, reset, remaining, limit); - } - } - } -} - -#[cfg(test)] -mod rate_limit { - use serde_json::from_str; - - use crate::{api::limits::Config, UrlBundle}; - - use super::*; - - #[tokio::test] - async fn run_into_limit() { - let urls = UrlBundle::new( - String::from("http://localhost:3001/api/"), - String::from("wss://localhost:3001/"), - String::from("http://localhost:3001/cdn"), - ); - let mut request: Option> = None; - let mut instance = Instance::new(urls.clone()).await.unwrap(); - let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; - - for _ in 0..=50 { - let request_path = urls.api.clone() + "/some/random/nonexisting/path"; - let request_builder = instance.client.get(request_path); - request = Some( - LimitedRequester::send_request( - request_builder, - LimitType::Channel, - &mut instance, - &mut user_rate_limits, - ) - .await, - ); - } - assert!(matches!(request, Some(Err(_)))); - } - - #[tokio::test] - async fn test_send_request() { - let urls = UrlBundle::new( - String::from("http://localhost:3001/api/"), - String::from("wss://localhost:3001/"), - String::from("http://localhost:3001/cdn"), - ); - let mut instance = Instance::new(urls.clone()).await.unwrap(); - let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; - let _requester = LimitedRequester; - let request_path = urls.api.clone() + "/policies/instance/limits"; - let request_builder = instance.client.get(request_path); - let request = LimitedRequester::send_request( - request_builder, - LimitType::Channel, - &mut instance, - &mut user_rate_limits, - ) - .await; - let result = match request { - Ok(result) => result, - Err(_) => panic!("Request failed"), - }; - let _config: Config = from_str(result.text().await.unwrap().as_str()).unwrap(); - } -} diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs new file mode 100644 index 0000000..629c3d7 --- /dev/null +++ b/src/ratelimiter.rs @@ -0,0 +1,462 @@ +use std::collections::HashMap; + +use log; +use reqwest::{Client, RequestBuilder, Response}; +use serde::Deserialize; +use serde_json::from_str; + +use crate::{ + api::{Limit, LimitType}, + errors::{ChorusError, ChorusResult}, + instance::UserMeta, + types::{types::subconfigs::limits::rates::RateLimits, LimitsConfiguration}, +}; + +/// Chorus' request struct. This struct is used to send rate-limited requests to the Spacebar server. +/// See for more information. +pub struct ChorusRequest { + pub request: RequestBuilder, + pub limit_type: LimitType, +} + +impl ChorusRequest { + /// Sends a [`ChorusRequest`]. Checks if the user is rate limited, and if not, sends the request. + /// If the user is not rate limited and the instance has rate limits enabled, it will update the + /// rate limits. + #[allow(clippy::await_holding_refcell_ref)] + pub(crate) async fn send_request(self, user: &mut UserMeta) -> ChorusResult { + if !ChorusRequest::can_send_request(user, &self.limit_type) { + log::info!("Rate limit hit. Bucket: {:?}", self.limit_type); + return Err(ChorusError::RateLimited { + bucket: format!("{:?}", self.limit_type), + }); + } + let belongs_to = user.belongs_to.borrow(); + let result = match belongs_to + .client + .execute(self.request.build().unwrap()) + .await + { + Ok(result) => result, + Err(error) => { + log::warn!("Request failed: {:?}", error); + return Err(ChorusError::RequestFailed { + url: error.url().unwrap().to_string(), + error, + }); + } + }; + drop(belongs_to); + if !result.status().is_success() { + if result.status().as_u16() == 429 { + log::warn!("Rate limit hit unexpectedly. Bucket: {:?}. Setting the instances' remaining global limit to 0 to have cooldown.", self.limit_type); + user.belongs_to + .borrow_mut() + .limits_information + .as_mut() + .unwrap() + .ratelimits + .get_mut(&LimitType::Global) + .unwrap() + .remaining = 0; + return Err(ChorusError::RateLimited { + bucket: format!("{:?}", self.limit_type), + }); + } + log::warn!("Request failed: {:?}", result); + return Err(ChorusRequest::interpret_error(result).await); + } + ChorusRequest::update_rate_limits(user, &self.limit_type, !result.status().is_success()); + Ok(result) + } + + fn can_send_request(user: &mut UserMeta, limit_type: &LimitType) -> bool { + log::trace!("Checking if user or instance is rate-limited..."); + let mut belongs_to = user.belongs_to.borrow_mut(); + if belongs_to.limits_information.is_none() { + log::trace!("Instance indicates no rate limits are configured. Continuing."); + return true; + } + let instance_dictated_limits = [ + &LimitType::AuthLogin, + &LimitType::AuthRegister, + &LimitType::Global, + &LimitType::Ip, + ]; + let limits = match instance_dictated_limits.contains(&limit_type) { + true => { + log::trace!( + "Limit type {:?} is dictated by the instance. Continuing.", + limit_type + ); + belongs_to + .limits_information + .as_mut() + .unwrap() + .ratelimits + .clone() + } + false => { + log::trace!( + "Limit type {:?} is dictated by the user. Continuing.", + limit_type + ); + ChorusRequest::ensure_limit_in_map( + &belongs_to + .limits_information + .as_ref() + .unwrap() + .configuration, + user.limits.as_mut().unwrap(), + limit_type, + ); + user.limits.as_mut().unwrap().clone() + } + }; + let global = belongs_to + .limits_information + .as_ref() + .unwrap() + .ratelimits + .get(&LimitType::Global) + .unwrap(); + let ip = belongs_to + .limits_information + .as_ref() + .unwrap() + .ratelimits + .get(&LimitType::Ip) + .unwrap(); + let limit_type_limit = limits.get(limit_type).unwrap(); + global.remaining > 0 && ip.remaining > 0 && limit_type_limit.remaining > 0 + } + + fn ensure_limit_in_map( + rate_limits_config: &RateLimits, + map: &mut HashMap, + limit_type: &LimitType, + ) { + log::trace!("Ensuring limit type {:?} is in the map.", limit_type); + let time: u64 = chrono::Utc::now().timestamp() as u64; + match limit_type { + LimitType::Channel(snowflake) => { + if map.get(&LimitType::Channel(*snowflake)).is_some() { + log::trace!( + "Limit type {:?} is already in the map. Returning.", + limit_type + ); + return; + } + log::trace!("Limit type {:?} is not in the map. Adding it.", limit_type); + let channel_limit = &rate_limits_config.routes.channel; + map.insert( + LimitType::Channel(*snowflake), + Limit { + bucket: LimitType::Channel(*snowflake), + limit: channel_limit.count, + remaining: channel_limit.count, + reset: channel_limit.window + time, + window: channel_limit.window, + }, + ); + } + LimitType::Guild(snowflake) => { + if map.get(&LimitType::Guild(*snowflake)).is_some() { + return; + } + let guild_limit = &rate_limits_config.routes.guild; + map.insert( + LimitType::Guild(*snowflake), + Limit { + bucket: LimitType::Guild(*snowflake), + limit: guild_limit.count, + remaining: guild_limit.count, + reset: guild_limit.window + time, + window: guild_limit.window, + }, + ); + } + LimitType::Webhook(snowflake) => { + if map.get(&LimitType::Webhook(*snowflake)).is_some() { + return; + } + let webhook_limit = &rate_limits_config.routes.webhook; + map.insert( + LimitType::Webhook(*snowflake), + Limit { + bucket: LimitType::Webhook(*snowflake), + limit: webhook_limit.count, + remaining: webhook_limit.count, + reset: webhook_limit.window + time, + window: webhook_limit.window, + }, + ); + } + other_limit => { + if map.get(other_limit).is_some() { + return; + } + let limits_map = ChorusRequest::limits_config_to_hashmap(rate_limits_config); + map.insert( + *other_limit, + Limit { + bucket: *other_limit, + limit: limits_map.get(other_limit).as_ref().unwrap().limit, + remaining: limits_map.get(other_limit).as_ref().unwrap().remaining, + reset: limits_map.get(other_limit).as_ref().unwrap().reset, + window: limits_map.get(other_limit).as_ref().unwrap().window, + }, + ); + } + } + } + + async fn interpret_error(response: reqwest::Response) -> ChorusError { + match response.status().as_u16() { + 401..=403 | 407 => ChorusError::NoPermission, + 404 => ChorusError::NotFound { + error: response.text().await.unwrap(), + }, + 405 | 408 | 409 => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap() }, + 411..=421 | 426 | 428 | 431 => ChorusError::InvalidArguments { + error: response.text().await.unwrap(), + }, + 429 => panic!("Illegal state: Rate limit exception should have been caught before this function call."), + 451 => ChorusError::NoResponse, + 500..=599 => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap() }, + _ => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap()}, + } + } + + /// Updates the rate limits of the user. The following steps are performed: + /// 1. If the current unix timestamp is greater than the reset timestamp, the reset timestamp is + /// set to the current unix timestamp + the rate limit window. The remaining rate limit is + /// reset to the rate limit limit. + /// 2. The remaining rate limit is decreased by 1. + fn update_rate_limits(user: &mut UserMeta, limit_type: &LimitType, response_was_err: bool) { + let instance_dictated_limits = [ + &LimitType::AuthLogin, + &LimitType::AuthRegister, + &LimitType::Global, + &LimitType::Ip, + ]; + // modify this to store something to look up the value with later, instead of storing a reference to the actual data itself. + let mut relevant_limits = Vec::new(); + if instance_dictated_limits.contains(&limit_type) { + relevant_limits.push((LimitOrigin::Instance, *limit_type)); + } else { + relevant_limits.push((LimitOrigin::User, *limit_type)); + } + relevant_limits.push((LimitOrigin::Instance, LimitType::Global)); + relevant_limits.push((LimitOrigin::Instance, LimitType::Ip)); + if response_was_err { + relevant_limits.push((LimitOrigin::User, LimitType::Error)); + } + let time: u64 = chrono::Utc::now().timestamp() as u64; + for relevant_limit in relevant_limits.iter() { + let mut belongs_to = user.belongs_to.borrow_mut(); + let limit = match relevant_limit.0 { + LimitOrigin::Instance => { + log::trace!( + "Updating instance rate limit. Bucket: {:?}", + relevant_limit.1 + ); + belongs_to + .limits_information + .as_mut() + .unwrap() + .ratelimits + .get_mut(&relevant_limit.1) + .unwrap() + } + LimitOrigin::User => { + log::trace!("Updating user rate limit. Bucket: {:?}", relevant_limit.1); + user.limits + .as_mut() + .unwrap() + .get_mut(&relevant_limit.1) + .unwrap() + } + }; + if time > limit.reset { + // Spacebar does not yet return rate limit information in its response headers. We + // therefore have to guess the next rate limit window. This is not ideal. Oh well! + log::trace!("Rate limit replenished. Bucket: {:?}", limit.bucket); + limit.reset += limit.window; + limit.remaining = limit.limit; + } + limit.remaining -= 1; + } + } + + pub(crate) async fn get_limits_config(url_api: &str) -> ChorusResult { + let request = Client::new() + .get(format!("{}/policies/instance/limits/", url_api)) + .send() + .await; + let request = match request { + Ok(request) => request, + Err(e) => { + return Err(ChorusError::RequestFailed { + url: url_api.to_string(), + error: e, + }) + } + }; + let limits_configuration = match request.status().as_u16() { + 200 => from_str::(&request.text().await.unwrap()).unwrap(), + 429 => { + return Err(ChorusError::RateLimited { + bucket: format!("{:?}", LimitType::Ip), + }) + } + 404 => return Err(ChorusError::NotFound { error: "Route \"/policies/instance/limits/\" not found. Are you perhaps trying to request the Limits configuration from an unsupported server?".to_string() }), + 400..=u16::MAX => { + return Err(ChorusError::ReceivedErrorCode { error_code: request.status().as_u16(), error: request.text().await.unwrap() }) + } + _ => { + return Err(ChorusError::InvalidResponse { + error: request.text().await.unwrap(), + }) + } + }; + + Ok(limits_configuration) + } + + pub(crate) fn limits_config_to_hashmap( + limits_configuration: &RateLimits, + ) -> HashMap { + let config = limits_configuration.clone(); + let routes = config.routes; + let mut map: HashMap = HashMap::new(); + let time: u64 = chrono::Utc::now().timestamp() as u64; + map.insert( + LimitType::AuthLogin, + Limit { + bucket: LimitType::AuthLogin, + limit: routes.auth.login.count, + remaining: routes.auth.login.count, + reset: routes.auth.login.window + time, + window: routes.auth.login.window, + }, + ); + map.insert( + LimitType::AuthRegister, + Limit { + bucket: LimitType::AuthRegister, + limit: routes.auth.register.count, + remaining: routes.auth.register.count, + reset: routes.auth.register.window + time, + window: routes.auth.register.window, + }, + ); + map.insert( + LimitType::ChannelBaseline, + Limit { + bucket: LimitType::ChannelBaseline, + limit: routes.channel.count, + remaining: routes.channel.count, + reset: routes.channel.window + time, + window: routes.channel.window, + }, + ); + map.insert( + LimitType::Error, + Limit { + bucket: LimitType::Error, + limit: config.error.count, + remaining: config.error.count, + reset: config.error.window + time, + window: config.error.window, + }, + ); + map.insert( + LimitType::Global, + Limit { + bucket: LimitType::Global, + limit: config.global.count, + remaining: config.global.count, + reset: config.global.window + time, + window: config.global.window, + }, + ); + map.insert( + LimitType::Ip, + Limit { + bucket: LimitType::Ip, + limit: config.ip.count, + remaining: config.ip.count, + reset: config.ip.window + time, + window: config.ip.window, + }, + ); + map.insert( + LimitType::GuildBaseline, + Limit { + bucket: LimitType::GuildBaseline, + limit: routes.guild.count, + remaining: routes.guild.count, + reset: routes.guild.window + time, + window: routes.guild.window, + }, + ); + map.insert( + LimitType::WebhookBaseline, + Limit { + bucket: LimitType::WebhookBaseline, + limit: routes.webhook.count, + remaining: routes.webhook.count, + reset: routes.webhook.window + time, + window: routes.webhook.window, + }, + ); + map + } + + /// Sends a [`ChorusRequest`] and returns a [`ChorusResult`] that contains nothing if the request + /// was successful, or a [`ChorusError`] if the request failed. + pub(crate) async fn handle_request_as_result(self, user: &mut UserMeta) -> ChorusResult<()> { + match self.send_request(user).await { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + /// Sends a [`ChorusRequest`] and returns a [`ChorusResult`] that contains a [`T`] if the request + /// was successful, or a [`ChorusError`] if the request failed. + pub(crate) async fn deserialize_response Deserialize<'a>>( + self, + user: &mut UserMeta, + ) -> ChorusResult { + let response = self.send_request(user).await?; + let response_text = match response.text().await { + Ok(string) => string, + Err(e) => { + return Err(ChorusError::InvalidResponse { + error: format!( + "Error while trying to process the HTTP response into a String: {}", + e + ), + }); + } + }; + let object = match from_str::(&response_text) { + Ok(object) => object, + Err(e) => { + return Err(ChorusError::InvalidResponse { + error: format!( + "Error while trying to deserialize the JSON response into T: {}", + e + ), + }) + } + }; + Ok(object) + } +} + +enum LimitOrigin { + Instance, + User, +} diff --git a/src/types/config/types/general_configuration.rs b/src/types/config/types/general_configuration.rs index 07444b0..13b3aa8 100644 --- a/src/types/config/types/general_configuration.rs +++ b/src/types/config/types/general_configuration.rs @@ -18,10 +18,8 @@ pub struct GeneralConfiguration { impl Default for GeneralConfiguration { fn default() -> Self { Self { - instance_name: String::from("Spacebar Instance"), - instance_description: Some(String::from( - "This is a Spacebar instance made in the pre-release days", - )), + instance_name: String::from("Spacebar-compatible Instance"), + instance_description: Some(String::from("This is a spacebar-compatible instance.")), front_page: None, tos_page: None, correspondence_email: None, diff --git a/src/types/config/types/subconfigs/limits/rates.rs b/src/types/config/types/subconfigs/limits/rates.rs index 9d0cab1..ce1ea60 100644 --- a/src/types/config/types/subconfigs/limits/rates.rs +++ b/src/types/config/types/subconfigs/limits/rates.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; -use crate::types::config::types::subconfigs::limits::ratelimits::{ - route::RouteRateLimit, RateLimitOptions, +use crate::{ + api::LimitType, + types::config::types::subconfigs::limits::ratelimits::{ + route::RouteRateLimit, RateLimitOptions, + }, }; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -39,3 +44,18 @@ impl Default for RateLimits { } } } + +impl RateLimits { + pub fn to_hash_map(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert(LimitType::AuthLogin, self.routes.auth.login.clone()); + map.insert(LimitType::AuthRegister, self.routes.auth.register.clone()); + map.insert(LimitType::ChannelBaseline, self.routes.channel.clone()); + map.insert(LimitType::Error, self.error.clone()); + map.insert(LimitType::Global, self.global.clone()); + map.insert(LimitType::Ip, self.ip.clone()); + map.insert(LimitType::WebhookBaseline, self.routes.webhook.clone()); + map.insert(LimitType::GuildBaseline, self.routes.guild.clone()); + map + } +} diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index d0059fc..b240fa2 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -1,9 +1,8 @@ +use crate::types::utils::Snowflake; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_option_number_from_string; -use crate::types::utils::Snowflake; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] pub struct UserData { @@ -16,7 +15,6 @@ impl User { PublicUser::from(self) } } - #[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Eq)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct User { @@ -89,6 +87,7 @@ impl From for PublicUser { } } +#[allow(dead_code)] // FIXME: Remove this when we actually use this const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32; bitflags::bitflags! { diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index 0198af6..fecf268 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -73,6 +73,7 @@ impl Rights { } } +#[allow(dead_code)] // FIXME: Remove this when we use this fn all_rights() -> Rights { Rights::OPERATOR | Rights::MANAGE_APPLICATIONS diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 8502275..6176ea5 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -12,7 +12,7 @@ const EPOCH: i64 = 1420070400000; /// Unique identifier including a timestamp. /// See https://discord.com/developers/docs/reference#snowflakes -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "sqlx", derive(Type))] #[cfg_attr(feature = "sqlx", sqlx(transparent))] pub struct Snowflake(u64); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 9a62585..0b48f89 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -24,7 +24,7 @@ pub async fn setup() -> TestBundle { "ws://localhost:3001".to_string(), "http://localhost:3001".to_string(), ); - let mut instance = Instance::new(urls.clone()).await.unwrap(); + let mut instance = Instance::new(urls.clone(), true).await.unwrap(); // Requires the existance of the below user. let reg = RegisterSchema { username: "integrationtestuser".into(), diff --git a/tests/relationships.rs b/tests/relationships.rs index e23f0f3..5578efd 100644 --- a/tests/relationships.rs +++ b/tests/relationships.rs @@ -19,7 +19,7 @@ async fn test_get_mutual_relationships() { username: user.object.username.clone(), discriminator: Some(user.object.discriminator.clone()), }; - other_user.send_friend_request(friend_request_schema).await; + let _ = other_user.send_friend_request(friend_request_schema).await; let relationships = user .get_mutual_relationships(other_user.object.id) .await @@ -45,7 +45,10 @@ async fn test_get_relationships() { username: user.object.username.clone(), discriminator: Some(user.object.discriminator.clone()), }; - other_user.send_friend_request(friend_request_schema).await; + other_user + .send_friend_request(friend_request_schema) + .await + .unwrap(); let relationships = user.get_relationships().await.unwrap(); assert_eq!(relationships.get(0).unwrap().id, other_user.object.id); common::teardown(bundle).await @@ -64,7 +67,7 @@ async fn test_modify_relationship_friends() { let belongs_to = &mut bundle.instance; let user = &mut bundle.user; let mut other_user = belongs_to.register_account(®ister_schema).await.unwrap(); - other_user + let _ = other_user .modify_user_relationship(user.object.id, types::RelationshipType::Friends) .await; let relationships = user.get_relationships().await.unwrap(); @@ -79,7 +82,8 @@ async fn test_modify_relationship_friends() { relationships.get(0).unwrap().relationship_type, RelationshipType::Outgoing ); - user.modify_user_relationship(other_user.object.id, RelationshipType::Friends) + let _ = user + .modify_user_relationship(other_user.object.id, RelationshipType::Friends) .await; assert_eq!( other_user @@ -91,7 +95,7 @@ async fn test_modify_relationship_friends() { .relationship_type, RelationshipType::Friends ); - user.remove_relationship(other_user.object.id).await; + let _ = user.remove_relationship(other_user.object.id).await; assert_eq!( other_user.get_relationships().await.unwrap(), Vec::::new() @@ -112,7 +116,7 @@ async fn test_modify_relationship_block() { let belongs_to = &mut bundle.instance; let user = &mut bundle.user; let mut other_user = belongs_to.register_account(®ister_schema).await.unwrap(); - other_user + let _ = other_user .modify_user_relationship(user.object.id, types::RelationshipType::Blocked) .await; let relationships = user.get_relationships().await.unwrap(); @@ -123,7 +127,7 @@ async fn test_modify_relationship_block() { relationships.get(0).unwrap().relationship_type, RelationshipType::Blocked ); - other_user.remove_relationship(user.object.id).await; + let _ = other_user.remove_relationship(user.object.id).await; assert_eq!( other_user.get_relationships().await.unwrap(), Vec::::new()