diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 63f5a71..f7c4fe6 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -3,7 +3,8 @@ pub mod login { use serde_json::{from_str, json}; use crate::api::limits::LimitType; - use crate::api::schemas::schemas::{ErrorResponse, LoginResult, LoginSchema}; + use crate::api::schemas::LoginSchema; + use crate::api::types::{ErrorResponse, LoginResult}; use crate::errors::InstanceServerError; use crate::instance::Instance; @@ -17,10 +18,19 @@ pub mod login { let client = Client::new(); let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + // We do not have a user yet, and the UserRateLimits will not be affected by a login + // request (since login is an instance wide limit), which is why we are just cloning the + // instances' limits to pass them on as user_rate_limits later. + let mut cloned_limits = self.limits.clone(); let response = requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request( + request_builder, + LimitType::AuthRegister, + &mut self.limits, + &mut cloned_limits, + ) .await; - if !response.is_ok() { + if response.is_err() { return Err(InstanceServerError::NoResponse); } @@ -41,7 +51,7 @@ pub mod login { let login_result: LoginResult = from_str(&response_text_string).unwrap(); - return Ok(login_result); + Ok(login_result) } } } diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index d25f7c5..b932b9f 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -3,10 +3,7 @@ pub mod register { use serde_json::json; use crate::{ - api::{ - limits::LimitType, - schemas::schemas::{ErrorResponse, RegisterSchema}, - }, + api::{limits::LimitType, schemas::RegisterSchema, types::ErrorResponse}, errors::InstanceServerError, instance::{Instance, Token}, }; @@ -28,10 +25,19 @@ pub mod register { let client = Client::new(); let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + // We do not have a user yet, and the UserRateLimits will not be affected by a login + // request (since register is an instance wide limit), which is why we are just cloning + // the instances' limits to pass them on as user_rate_limits later. + let mut cloned_limits = self.limits.clone(); let response = limited_requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request( + request_builder, + LimitType::AuthRegister, + &mut self.limits, + &mut cloned_limits, + ) .await; - if !response.is_ok() { + if response.is_err() { return Err(InstanceServerError::NoResponse); } @@ -49,16 +55,16 @@ pub mod register { } return Err(InstanceServerError::InvalidFormBodyError { error_type, error }); } - return Ok(Token { + Ok(Token { token: response_text_string, - }); + }) } } } #[cfg(test)] mod test { - use crate::api::schemas::schemas::{AuthEmail, AuthPassword, AuthUsername, RegisterSchema}; + use crate::api::schemas::{AuthEmail, AuthPassword, AuthUsername, RegisterSchema}; use crate::errors::InstanceServerError; use crate::instance::Instance; use crate::limit::LimitedRequester; @@ -70,7 +76,7 @@ mod test { "http://localhost:3001".to_string(), "http://localhost:3001".to_string(), ); - let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; + let limited_requester = LimitedRequester::new().await; let mut test_instance = Instance::new(urls.clone(), limited_requester) .await .unwrap(); @@ -103,7 +109,7 @@ mod test { "http://localhost:3001".to_string(), "http://localhost:3001".to_string(), ); - let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; + let limited_requester = LimitedRequester::new().await; let mut test_instance = Instance::new(urls.clone(), limited_requester) .await .unwrap(); @@ -111,7 +117,7 @@ mod test { AuthUsername::new("Hiiii".to_string()).unwrap(), Some(AuthPassword::new("mysupersecurepass123!".to_string()).unwrap()), true, - Some(AuthEmail::new("flori@aaaa.xyz".to_string()).unwrap()), + Some(AuthEmail::new("random978234@aaaa.xyz".to_string()).unwrap()), None, None, Some("2000-01-01".to_string()), diff --git a/src/api/mod.rs b/src/api/mod.rs index f619259..a9902d6 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,7 +1,9 @@ pub mod auth; pub mod policies; pub mod schemas; +pub mod types; pub use policies::instance::instance::*; pub use policies::instance::limits::*; pub use schemas::*; +pub use types::*; diff --git a/src/api/policies/instance/instance.rs b/src/api/policies/instance/instance.rs index 2eaac5b..f96325f 100644 --- a/src/api/policies/instance/instance.rs +++ b/src/api/policies/instance/instance.rs @@ -3,7 +3,7 @@ pub mod instance { use serde_json::from_str; use crate::errors::InstanceServerError; - use crate::{api::schemas::schemas::InstancePoliciesSchema, instance::Instance}; + use crate::{api::types::InstancePolicies, instance::Instance}; impl Instance { /** @@ -13,7 +13,7 @@ pub mod instance { */ pub async fn instance_policies_schema( &self, - ) -> Result { + ) -> Result { let client = Client::new(); let endpoint_url = self.urls.get_api().to_string() + "/policies/instance/"; let request = match client.get(&endpoint_url).send().await { @@ -26,14 +26,14 @@ pub mod instance { } }; - if request.status().as_str().chars().next().unwrap() != '2' { + if !request.status().as_str().starts_with('2') { return Err(InstanceServerError::ReceivedErrorCodeError { error_code: request.status().to_string(), }); } let body = request.text().await.unwrap(); - let instance_policies_schema: InstancePoliciesSchema = from_str(&body).unwrap(); + let instance_policies_schema: InstancePolicies = from_str(&body).unwrap(); Ok(instance_policies_schema) } } @@ -50,7 +50,7 @@ mod instance_policies_schema_test { "http://localhost:3001".to_string(), "http://localhost:3001".to_string(), ); - let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; + let limited_requester = LimitedRequester::new().await; let test_instance = Instance::new(urls.clone(), limited_requester) .await .unwrap(); diff --git a/src/api/policies/instance/limits.rs b/src/api/policies/instance/limits.rs index a46a29f..7c4b624 100644 --- a/src/api/policies/instance/limits.rs +++ b/src/api/policies/instance/limits.rs @@ -155,17 +155,81 @@ pub mod limits { impl Limit { pub fn add_remaining(&mut self, remaining: i64) { if remaining < 0 { - if ((self.remaining as i64 + remaining) as i64) <= 0 { + if (self.remaining as i64 + remaining) <= 0 { self.remaining = 0; return; } - self.remaining -= remaining.abs() as u64; + self.remaining -= remaining.unsigned_abs(); return; } - self.remaining += remaining.abs() as u64; + self.remaining += remaining.unsigned_abs(); } } + pub struct LimitsMutRef<'a> { + pub limit_absolute_messages: &'a mut Limit, + pub limit_absolute_register: &'a mut Limit, + pub limit_auth_login: &'a mut Limit, + pub limit_auth_register: &'a mut Limit, + pub limit_ip: &'a mut Limit, + pub limit_global: &'a mut Limit, + pub limit_error: &'a mut Limit, + pub limit_guild: &'a mut Limit, + pub limit_webhook: &'a mut Limit, + pub limit_channel: &'a mut Limit, + } + + impl LimitsMutRef<'_> { + pub fn combine_mut_ref<'a>( + instance_rate_limits: &'a mut Limits, + user_rate_limits: &'a mut Limits, + ) -> LimitsMutRef<'a> { + LimitsMutRef { + limit_absolute_messages: &mut instance_rate_limits.limit_absolute_messages, + limit_absolute_register: &mut instance_rate_limits.limit_absolute_register, + limit_auth_login: &mut instance_rate_limits.limit_auth_login, + limit_auth_register: &mut instance_rate_limits.limit_auth_register, + limit_channel: &mut user_rate_limits.limit_channel, + limit_error: &mut user_rate_limits.limit_error, + limit_global: &mut instance_rate_limits.limit_global, + limit_guild: &mut user_rate_limits.limit_guild, + limit_ip: &mut instance_rate_limits.limit_ip, + limit_webhook: &mut user_rate_limits.limit_webhook, + } + } + + pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { + match limit_type { + &LimitType::AbsoluteMessage => self.limit_absolute_messages, + &LimitType::AbsoluteRegister => self.limit_absolute_register, + &LimitType::AuthLogin => self.limit_auth_login, + &LimitType::AuthRegister => self.limit_auth_register, + &LimitType::Channel => self.limit_channel, + &LimitType::Error => self.limit_error, + &LimitType::Global => self.limit_global, + &LimitType::Guild => self.limit_guild, + &LimitType::Ip => self.limit_ip, + &LimitType::Webhook => self.limit_webhook, + } + } + + pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { + match limit_type { + &LimitType::AbsoluteMessage => self.limit_absolute_messages, + &LimitType::AbsoluteRegister => self.limit_absolute_register, + &LimitType::AuthLogin => self.limit_auth_login, + &LimitType::AuthRegister => self.limit_auth_register, + &LimitType::Channel => self.limit_channel, + &LimitType::Error => self.limit_error, + &LimitType::Global => self.limit_global, + &LimitType::Guild => self.limit_guild, + &LimitType::Ip => self.limit_ip, + &LimitType::Webhook => self.limit_webhook, + } + } + } + + #[derive(Debug, Clone)] pub struct Limits { pub limit_absolute_messages: Limit, pub limit_absolute_register: Limit, @@ -180,19 +244,64 @@ pub mod limits { } impl Limits { - pub fn iter(&self) -> std::vec::IntoIter { - let mut limits: Vec = Vec::new(); - limits.push(self.limit_absolute_messages.clone()); - limits.push(self.limit_absolute_register.clone()); - limits.push(self.limit_auth_login.clone()); - limits.push(self.limit_auth_register.clone()); - limits.push(self.limit_ip.clone()); - limits.push(self.limit_global.clone()); - limits.push(self.limit_error.clone()); - limits.push(self.limit_guild.clone()); - limits.push(self.limit_webhook.clone()); - limits.push(self.limit_channel.clone()); - limits.into_iter() + pub fn combine(instance_rate_limits: &Limits, user_rate_limits: &Limits) -> Limits { + Limits { + limit_absolute_messages: instance_rate_limits.limit_absolute_messages, + limit_absolute_register: instance_rate_limits.limit_absolute_register, + limit_auth_login: instance_rate_limits.limit_auth_login, + limit_auth_register: instance_rate_limits.limit_auth_register, + limit_channel: user_rate_limits.limit_channel, + limit_error: user_rate_limits.limit_error, + limit_global: instance_rate_limits.limit_global, + limit_guild: user_rate_limits.limit_guild, + limit_ip: instance_rate_limits.limit_ip, + limit_webhook: user_rate_limits.limit_webhook, + } + } + + pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { + match limit_type { + &LimitType::AbsoluteMessage => &self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &self.limit_absolute_register, + &LimitType::AuthLogin => &self.limit_auth_login, + &LimitType::AuthRegister => &self.limit_auth_register, + &LimitType::Channel => &self.limit_channel, + &LimitType::Error => &self.limit_error, + &LimitType::Global => &self.limit_global, + &LimitType::Guild => &self.limit_guild, + &LimitType::Ip => &self.limit_ip, + &LimitType::Webhook => &self.limit_webhook, + } + } + + pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { + match limit_type { + &LimitType::AbsoluteMessage => &mut self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &mut self.limit_absolute_register, + &LimitType::AuthLogin => &mut self.limit_auth_login, + &LimitType::AuthRegister => &mut self.limit_auth_register, + &LimitType::Channel => &mut self.limit_channel, + &LimitType::Error => &mut self.limit_error, + &LimitType::Global => &mut self.limit_global, + &LimitType::Guild => &mut self.limit_guild, + &LimitType::Ip => &mut self.limit_ip, + &LimitType::Webhook => &mut self.limit_webhook, + } + } + + pub fn to_hash_map(&self) -> HashMap { + let mut map: HashMap = HashMap::new(); + map.insert(LimitType::AbsoluteMessage, self.limit_absolute_messages); + map.insert(LimitType::AbsoluteRegister, self.limit_absolute_register); + map.insert(LimitType::AuthLogin, self.limit_auth_login); + map.insert(LimitType::AuthRegister, self.limit_auth_register); + map.insert(LimitType::Ip, self.limit_ip); + map.insert(LimitType::Global, self.limit_global); + map.insert(LimitType::Error, self.limit_error); + map.insert(LimitType::Guild, self.limit_guild); + map.insert(LimitType::Webhook, self.limit_webhook); + map.insert(LimitType::Channel, self.limit_channel); + map } /// check_limits uses the API to get the current request limits of the instance. @@ -201,7 +310,7 @@ pub mod limits { /// # Errors /// This function will panic if the request fails or if the response body cannot be parsed. /// TODO: Change this to return a Result and handle the errors properly. - pub async fn check_limits(api_url: String) -> HashMap { + pub async fn check_limits(api_url: String) -> Limits { let client = Client::new(); let url_parsed = crate::URLBundle::parse_url(api_url) + "/policies/instance/limits"; let result = client @@ -219,216 +328,154 @@ pub mod limits { }); let config: Config = from_str(&result).unwrap(); // If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX - let mut limits: HashMap = HashMap::new(); - if config.rate.enabled == false { - limits.insert( - LimitType::AbsoluteMessage, - Limit { + let mut limits: Limits; + if !config.rate.enabled { + limits = Limits { + limit_absolute_messages: Limit { bucket: LimitType::AbsoluteMessage, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AbsoluteRegister, - Limit { + limit_absolute_register: Limit { bucket: LimitType::AbsoluteRegister, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AuthLogin, - Limit { + limit_auth_login: Limit { bucket: LimitType::AuthLogin, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AuthRegister, - Limit { + limit_auth_register: Limit { bucket: LimitType::AuthRegister, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Ip, - Limit { + limit_ip: Limit { bucket: LimitType::Ip, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Global, - Limit { + limit_global: Limit { bucket: LimitType::Global, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Error, - Limit { + limit_error: Limit { bucket: LimitType::Error, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Guild, - Limit { + limit_guild: Limit { bucket: LimitType::Guild, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Webhook, - Limit { + limit_webhook: Limit { bucket: LimitType::Webhook, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Channel, - Limit { + limit_channel: Limit { bucket: LimitType::Channel, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); + }; } else { - limits.insert( - LimitType::AbsoluteMessage, - Limit { + limits = Limits { + limit_absolute_messages: Limit { bucket: LimitType::AbsoluteMessage, limit: config.absoluteRate.sendMessage.limit, remaining: config.absoluteRate.sendMessage.limit, reset: config.absoluteRate.sendMessage.window, }, - ); - limits.insert( - LimitType::AbsoluteRegister, - Limit { + limit_absolute_register: Limit { bucket: LimitType::AbsoluteRegister, limit: config.absoluteRate.register.limit, remaining: config.absoluteRate.register.limit, reset: config.absoluteRate.register.window, }, - ); - limits.insert( - LimitType::AuthLogin, - Limit { + limit_auth_login: Limit { bucket: LimitType::AuthLogin, limit: config.rate.routes.auth.login.count, remaining: config.rate.routes.auth.login.count, reset: config.rate.routes.auth.login.window, }, - ); - limits.insert( - LimitType::AuthRegister, - Limit { + limit_auth_register: Limit { bucket: LimitType::AuthRegister, limit: config.rate.routes.auth.register.count, remaining: config.rate.routes.auth.register.count, reset: config.rate.routes.auth.register.window, }, - ); - limits.insert( - LimitType::Guild, - Limit { - bucket: LimitType::Guild, - limit: config.rate.routes.guild.count, - remaining: config.rate.routes.guild.count, - reset: config.rate.routes.guild.window, - }, - ); - limits.insert( - LimitType::Webhook, - Limit { - bucket: LimitType::Webhook, - limit: config.rate.routes.webhook.count, - remaining: config.rate.routes.webhook.count, - reset: config.rate.routes.webhook.window, - }, - ); - limits.insert( - LimitType::Channel, - Limit { - bucket: LimitType::Channel, - limit: config.rate.routes.channel.count, - remaining: config.rate.routes.channel.count, - reset: config.rate.routes.channel.window, - }, - ); - limits.insert( - LimitType::Ip, - Limit { + limit_ip: Limit { bucket: LimitType::Ip, limit: config.rate.ip.count, remaining: config.rate.ip.count, reset: config.rate.ip.window, }, - ); - limits.insert( - LimitType::Global, - Limit { + limit_global: Limit { bucket: LimitType::Global, limit: config.rate.global.count, remaining: config.rate.global.count, reset: config.rate.global.window, }, - ); - limits.insert( - LimitType::Error, - Limit { + limit_error: Limit { bucket: LimitType::Error, limit: config.rate.error.count, remaining: config.rate.error.count, reset: config.rate.error.window, }, - ); + limit_guild: Limit { + bucket: LimitType::Guild, + limit: config.rate.routes.guild.count, + remaining: config.rate.routes.guild.count, + reset: config.rate.routes.guild.window, + }, + limit_webhook: Limit { + bucket: LimitType::Webhook, + limit: config.rate.routes.webhook.count, + remaining: config.rate.routes.webhook.count, + reset: config.rate.routes.webhook.window, + }, + limit_channel: Limit { + bucket: LimitType::Channel, + limit: config.rate.routes.channel.count, + remaining: config.rate.routes.channel.count, + reset: config.rate.routes.channel.window, + }, + }; } if !config.absoluteRate.register.enabled { - limits.insert( - LimitType::AbsoluteRegister, - Limit { - bucket: LimitType::AbsoluteRegister, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - ); + limits.limit_absolute_register = Limit { + bucket: LimitType::AbsoluteRegister, + limit: u64::MAX, + remaining: u64::MAX, + reset: u64::MAX, + }; } if !config.absoluteRate.sendMessage.enabled { - limits.insert( - LimitType::AbsoluteMessage, - Limit { - bucket: LimitType::AbsoluteMessage, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - ); + limits.limit_absolute_messages = Limit { + bucket: LimitType::AbsoluteMessage, + limit: u64::MAX, + remaining: u64::MAX, + reset: u64::MAX, + }; } - return limits; + limits } } } @@ -446,8 +493,8 @@ mod instance_limits { reset: 0, }; limit.add_remaining(-2); - assert_eq!(0 as u64, limit.remaining); + assert_eq!(0_u64, limit.remaining); limit.add_remaining(-2123123); - assert_eq!(0 as u64, limit.remaining); + assert_eq!(0_u64, limit.remaining); } } diff --git a/src/api/schemas.rs b/src/api/schemas.rs index ad7d297..08425a3 100644 --- a/src/api/schemas.rs +++ b/src/api/schemas.rs @@ -1,392 +1,250 @@ -pub mod schemas { - use regex::Regex; - use serde::{Deserialize, Serialize}; - use std::{collections::HashMap, fmt}; +use regex::Regex; +use serde::{Deserialize, Serialize}; - use crate::errors::FieldFormatError; +use crate::errors::FieldFormatError; +/** +A struct that represents a well-formed email address. + */ +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct AuthEmail { + pub email: String, +} + +impl AuthEmail { /** - A struct that represents a well-formed email address. - */ - #[derive(Clone, PartialEq, Eq, Debug)] - pub struct AuthEmail { - pub email: String, - } - - impl AuthEmail { - /** - Returns a new [`Result`]. - ## Arguments - The email address you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The email address is not in a valid format. - - */ - pub fn new(email: String) -> Result { - let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap(); - if !regex.is_match(email.clone().as_str()) { - return Err(FieldFormatError::EmailError); - } - return Ok(AuthEmail { email }); - } - } - - /** - A struct that represents a well-formed username. + Returns a new [`Result`]. ## Arguments - Please use new() to create a new instance of this struct. + The email address you want to validate. + ## Errors + You will receive a [`FieldFormatError`], if: + - The email address is not in a valid format. + + */ + pub fn new(email: String) -> Result { + let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap(); + if !regex.is_match(email.as_str()) { + return Err(FieldFormatError::EmailError); + } + Ok(AuthEmail { email }) + } +} + +/** +A struct that represents a well-formed username. +## Arguments +Please use new() to create a new instance of this struct. +## Errors +You will receive a [`FieldFormatError`], if: +- The username is not between 2 and 32 characters. + */ +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct AuthUsername { + pub username: String, +} + +impl AuthUsername { + /** + Returns a new [`Result`]. + ## Arguments + The username you want to validate. ## Errors You will receive a [`FieldFormatError`], if: - The username is not between 2 and 32 characters. */ - #[derive(Clone, PartialEq, Eq, Debug)] - pub struct AuthUsername { - pub username: String, - } - - impl AuthUsername { - /** - Returns a new [`Result`]. - ## Arguments - The username you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is not between 2 and 32 characters. - */ - pub fn new(username: String) -> Result { - if username.len() < 2 || username.len() > 32 { - return Err(FieldFormatError::UsernameError); - } else { - return Ok(AuthUsername { username }); - } + pub fn new(username: String) -> Result { + if username.len() < 2 || username.len() > 32 { + Err(FieldFormatError::UsernameError) + } else { + Ok(AuthUsername { username }) } } +} +/** +A struct that represents a well-formed password. +## Arguments +Please use new() to create a new instance of this struct. +## Errors +You will receive a [`FieldFormatError`], if: +- The password is not between 1 and 72 characters. + */ +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct AuthPassword { + pub password: String, +} + +impl AuthPassword { /** - A struct that represents a well-formed password. + Returns a new [`Result`]. ## Arguments - Please use new() to create a new instance of this struct. + The password you want to validate. ## Errors You will receive a [`FieldFormatError`], if: - The password is not between 1 and 72 characters. */ - #[derive(Clone, PartialEq, Eq, Debug)] - pub struct AuthPassword { - pub password: String, - } - - impl AuthPassword { - /** - Returns a new [`Result`]. - ## Arguments - The password you want to validate. - ## Errors - You will receive a [`FieldFormatError`], if: - - The password is not between 1 and 72 characters. - */ - pub fn new(password: String) -> Result { - if password.len() < 1 || password.len() > 72 { - return Err(FieldFormatError::PasswordError); - } else { - return Ok(AuthPassword { password }); - } + pub fn new(password: String) -> Result { + if password.is_empty() || password.len() > 72 { + Err(FieldFormatError::PasswordError) + } else { + Ok(AuthPassword { password }) } } +} +/** +A struct that represents a well-formed register request. +## Arguments +Please use new() to create a new instance of this struct. +## Errors +You will receive a [`FieldFormatError`], if: +- The username is not between 2 and 32 characters. +- The password is not between 1 and 72 characters. + */ + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub struct RegisterSchema { + username: String, + password: Option, + consent: bool, + email: Option, + fingerprint: Option, + invite: Option, + date_of_birth: Option, + gift_code_sku_id: Option, + captcha_key: Option, + promotional_email_opt_in: Option, +} + +impl RegisterSchema { /** - A struct that represents a well-formed register request. + Returns a new [`Result`]. ## Arguments - Please use new() to create a new instance of this struct. + All but "String::username" and "bool::consent" are optional. + ## Errors You will receive a [`FieldFormatError`], if: - - The username is not between 2 and 32 characters. - - The password is not between 1 and 72 characters. - */ + - The username is less than 2 or more than 32 characters in length + - You supply a `password` which is less than 1 or more than 72 characters in length. - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - #[serde(rename_all = "snake_case")] - pub struct RegisterSchema { - username: String, - password: Option, + These constraints have been defined [in the Spacebar-API](https://docs.spacebar.chat/routes/) + */ + pub fn new( + username: AuthUsername, + password: Option, consent: bool, - email: Option, + email: Option, fingerprint: Option, invite: Option, date_of_birth: Option, gift_code_sku_id: Option, captcha_key: Option, promotional_email_opt_in: Option, - } + ) -> Result { + let username = username.username; - impl RegisterSchema { - /** - Returns a new [`Result`]. - ## Arguments - All but "String::username" and "bool::consent" are optional. - - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is less than 2 or more than 32 characters in length - - You supply a `password` which is less than 1 or more than 72 characters in length. - - These constraints have been defined [in the Spacebar-API](https://docs.spacebar.chat/routes/) - */ - pub fn new( - username: AuthUsername, - password: Option, - consent: bool, - email: Option, - fingerprint: Option, - invite: Option, - date_of_birth: Option, - gift_code_sku_id: Option, - captcha_key: Option, - promotional_email_opt_in: Option, - ) -> Result { - let username = username.username; - - let email_addr; - if email.is_some() { - email_addr = Some(email.unwrap().email); - } else { - email_addr = None; - } - - let has_password; - if password.is_some() { - has_password = Some(password.unwrap().password); - } else { - has_password = None; - } - - if !consent { - return Err(FieldFormatError::ConsentError); - } - - return Ok(RegisterSchema { - username, - password: has_password, - consent, - email: email_addr, - fingerprint, - invite, - date_of_birth, - gift_code_sku_id, - captcha_key, - promotional_email_opt_in, - }); + let email_addr; + if email.is_some() { + email_addr = Some(email.unwrap().email); + } else { + email_addr = None; } - } + let has_password; + if password.is_some() { + has_password = Some(password.unwrap().password); + } else { + has_password = None; + } + + if !consent { + return Err(FieldFormatError::ConsentError); + } + + Ok(RegisterSchema { + username, + password: has_password, + consent, + email: email_addr, + fingerprint, + invite, + date_of_birth, + gift_code_sku_id, + captcha_key, + promotional_email_opt_in, + }) + } +} + +/** +A struct that represents a well-formed login request. +## Arguments +Please use new() to create a new instance of this struct. +## Errors +You will receive a [`FieldFormatError`], if: +- The username is not between 2 and 32 characters. +- The password is not between 1 and 72 characters. + */ +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub struct LoginSchema { + login: String, + password: String, + undelete: Option, + captcha_key: Option, + login_source: Option, + gift_code_sku_id: Option, +} + +impl LoginSchema { /** - A struct that represents a well-formed login request. + Returns a new [`Result`]. ## Arguments - Please use new() to create a new instance of this struct. + login: The username you want to login with. + password: The password you want to login with. + undelete: Honestly no idea what this is for. + captcha_key: The captcha key you want to login with. + login_source: The login source. + gift_code_sku_id: The gift code sku id. ## Errors You will receive a [`FieldFormatError`], if: - - The username is not between 2 and 32 characters. - - The password is not between 1 and 72 characters. - */ - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - #[serde(rename_all = "snake_case")] - pub struct LoginSchema { - login: String, + - The username is less than 2 or more than 32 characters in length + */ + pub fn new( + login: AuthUsername, password: String, undelete: Option, captcha_key: Option, login_source: Option, gift_code_sku_id: Option, + ) -> Result { + let login = login.username; + Ok(LoginSchema { + login, + password, + undelete, + captcha_key, + login_source, + gift_code_sku_id, + }) } +} - impl LoginSchema { - /** - Returns a new [`Result`]. - ## Arguments - login: The username you want to login with. - password: The password you want to login with. - undelete: Honestly no idea what this is for. - captcha_key: The captcha key you want to login with. - login_source: The login source. - gift_code_sku_id: The gift code sku id. - ## Errors - You will receive a [`FieldFormatError`], if: - - The username is less than 2 or more than 32 characters in length - */ - pub fn new( - login: AuthUsername, - password: String, - undelete: Option, - captcha_key: Option, - login_source: Option, - gift_code_sku_id: Option, - ) -> Result { - let login = login.username; - return Ok(LoginSchema { - login, - password, - undelete, - captcha_key, - login_source, - gift_code_sku_id, - }); - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct LoginResult { - token: String, - settings: UserSettings, - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct UserSettings { - afk_timeout: i32, - allow_accessibility_detection: bool, - animate_emoji: bool, - animate_stickers: i32, - contact_sync_enabled: bool, - convert_emoticons: bool, - custom_status: Option, - default_guilds_restricted: bool, - detect_platform_accounts: bool, - developer_mode: bool, - disable_games_tab: bool, - enable_tts_command: bool, - explicit_content_filter: i32, - friend_source_flags: FriendSourceFlags, - friend_discovery_flags: Option, - gateway_connected: bool, - gif_auto_play: bool, - guild_folders: Vec, - guild_positions: Vec, - inline_attachment_media: bool, - inline_embed_media: bool, - locale: String, - message_display_compact: bool, - native_phone_integration_enabled: bool, - render_embeds: bool, - render_reactions: bool, - restricted_guilds: Vec, - show_current_game: bool, - status: String, - stream_notifications_enabled: bool, - theme: String, - timezone_offset: i32, - view_nsfw_guilds: bool, - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct FriendSourceFlags { - all: Option, - mutual_friends: Option, - mutual_guilds: Option, - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct GuildFolder { - id: String, - guild_ids: Vec, - name: String, - } - - #[derive(Debug, Serialize, Deserialize)] - #[serde(rename_all = "snake_case")] - pub struct TotpSchema { - code: String, - ticket: String, - gift_code_sku_id: Option, - login_source: Option, - } - - /** - Represents the result you get from GET: /api/instance/policies/. - */ - #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] - #[serde(rename_all = "camelCase")] - pub struct InstancePoliciesSchema { - instance_name: String, - instance_description: Option, - front_page: Option, - tos_page: Option, - correspondence_email: Option, - correspondence_user_id: Option, - image: Option, - instance_id: Option, - } - - impl InstancePoliciesSchema { - pub fn new( - instance_name: String, - instance_description: Option, - front_page: Option, - tos_page: Option, - correspondence_email: Option, - correspondence_user_id: Option, - image: Option, - instance_id: Option, - ) -> Self { - InstancePoliciesSchema { - instance_name, - instance_description, - front_page, - tos_page, - correspondence_email, - correspondence_user_id, - image, - instance_id, - } - } - } - - impl fmt::Display for InstancePoliciesSchema { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "InstancePoliciesSchema {{ instance_name: {}, instance_description: {}, front_page: {}, tos_page: {}, correspondence_email: {}, correspondence_user_id: {}, image: {}, instance_id: {} }}", - self.instance_name, - self.instance_description.clone().unwrap_or("None".to_string()), - self.front_page.clone().unwrap_or("None".to_string()), - self.tos_page.clone().unwrap_or("None".to_string()), - self.correspondence_email.clone().unwrap_or("None".to_string()), - self.correspondence_user_id.clone().unwrap_or("None".to_string()), - self.image.clone().unwrap_or("None".to_string()), - self.instance_id.clone().unwrap_or("None".to_string()), - ) - } - } - - #[derive(Serialize, Deserialize, Debug)] - pub struct ErrorResponse { - pub code: i32, - pub message: String, - pub errors: IntermittentError, - } - - #[derive(Serialize, Deserialize, Debug)] - pub struct IntermittentError { - #[serde(flatten)] - pub errors: std::collections::HashMap, - } - - #[derive(Serialize, Deserialize, Debug, Default)] - pub struct ErrorField { - #[serde(default)] - pub _errors: Vec, - } - - #[derive(Serialize, Deserialize, Debug)] - pub struct Error { - pub message: String, - pub code: String, - } +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct TotpSchema { + code: String, + ticket: String, + gift_code_sku_id: Option, + login_source: Option, } // I know that some of these tests are... really really basic and unneccessary, but sometimes, I // just feel like writing tests, so there you go :) -@bitfl0wer #[cfg(test)] mod schemas_tests { - use super::schemas::*; + use super::*; use crate::errors::FieldFormatError; #[test] @@ -401,7 +259,7 @@ mod schemas_tests { fn password_too_long() { let mut long_pw = String::new(); for _ in 0..73 { - long_pw = long_pw + "a"; + long_pw += "a"; } assert_eq!( AuthPassword::new(long_pw), @@ -421,7 +279,7 @@ mod schemas_tests { fn username_too_long() { let mut long_un = String::new(); for _ in 0..33 { - long_un = long_un + "a"; + long_un += "a"; } assert_eq!( AuthUsername::new(long_un), diff --git a/src/api/types.rs b/src/api/types.rs new file mode 100644 index 0000000..d4c8309 --- /dev/null +++ b/src/api/types.rs @@ -0,0 +1,213 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::{api::limits::Limits, URLBundle}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResult { + token: String, + settings: UserSettings, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UserSettings { + afk_timeout: i32, + allow_accessibility_detection: bool, + animate_emoji: bool, + animate_stickers: i32, + contact_sync_enabled: bool, + convert_emoticons: bool, + custom_status: Option, + default_guilds_restricted: bool, + detect_platform_accounts: bool, + developer_mode: bool, + disable_games_tab: bool, + enable_tts_command: bool, + explicit_content_filter: i32, + friend_source_flags: FriendSourceFlags, + friend_discovery_flags: Option, + gateway_connected: bool, + gif_auto_play: bool, + guild_folders: Vec, + guild_positions: Vec, + inline_attachment_media: bool, + inline_embed_media: bool, + locale: String, + message_display_compact: bool, + native_phone_integration_enabled: bool, + render_embeds: bool, + render_reactions: bool, + restricted_guilds: Vec, + show_current_game: bool, + status: String, + stream_notifications_enabled: bool, + theme: String, + timezone_offset: i32, + view_nsfw_guilds: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FriendSourceFlags { + all: Option, + mutual_friends: Option, + mutual_guilds: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GuildFolder { + id: String, + guild_ids: Vec, + name: String, +} + +/** +Represents the result you get from GET: /api/instance/policies/. +*/ +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub struct InstancePolicies { + instance_name: String, + instance_description: Option, + front_page: Option, + tos_page: Option, + correspondence_email: Option, + correspondence_user_id: Option, + image: Option, + instance_id: Option, +} + +impl InstancePolicies { + pub fn new( + instance_name: String, + instance_description: Option, + front_page: Option, + tos_page: Option, + correspondence_email: Option, + correspondence_user_id: Option, + image: Option, + instance_id: Option, + ) -> Self { + InstancePolicies { + instance_name, + instance_description, + front_page, + tos_page, + correspondence_email, + correspondence_user_id, + image, + instance_id, + } + } +} + +impl fmt::Display for InstancePolicies { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "InstancePoliciesSchema {{ instance_name: {}, instance_description: {}, front_page: {}, tos_page: {}, correspondence_email: {}, correspondence_user_id: {}, image: {}, instance_id: {} }}", + self.instance_name, + self.instance_description.clone().unwrap_or("None".to_string()), + self.front_page.clone().unwrap_or("None".to_string()), + self.tos_page.clone().unwrap_or("None".to_string()), + self.correspondence_email.clone().unwrap_or("None".to_string()), + self.correspondence_user_id.clone().unwrap_or("None".to_string()), + self.image.clone().unwrap_or("None".to_string()), + self.instance_id.clone().unwrap_or("None".to_string()), + ) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ErrorResponse { + pub code: i32, + pub message: String, + pub errors: IntermittentError, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct IntermittentError { + #[serde(flatten)] + pub errors: std::collections::HashMap, +} + +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct ErrorField { + #[serde(default)] + pub _errors: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Error { + pub message: String, + pub code: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct UserObject { + id: String, + username: String, + discriminator: String, + avatar: Option, + bot: Option, + system: Option, + mfa_enabled: Option, + banner: Option, + accent_color: Option, + locale: String, + verified: Option, + email: Option, + flags: i8, + premium_type: Option, + public_flags: Option, +} + +#[derive(Debug)] +pub struct User { + logged_in: bool, + belongs_to: URLBundle, + token: String, + rate_limits: Limits, + pub settings: UserSettings, + pub object: UserObject, +} + +impl User { + pub fn is_logged_in(&self) -> bool { + self.logged_in + } + + pub fn belongs_to(&self) -> URLBundle { + self.belongs_to.clone() + } + + pub fn token(&self) -> String { + self.token.clone() + } + + pub fn set_logged_in(&mut self, bool: bool) { + self.logged_in = bool; + } + + pub fn set_token(&mut self, token: String) { + self.token = token; + } + + pub fn new( + logged_in: bool, + belongs_to: URLBundle, + token: String, + rate_limits: Limits, + settings: UserSettings, + object: UserObject, + ) -> User { + User { + logged_in, + belongs_to, + token, + rate_limits, + settings, + object, + } + } +} diff --git a/src/instance.rs b/src/instance.rs index a655735..3e52e88 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,4 +1,5 @@ -use crate::api::schemas::schemas::InstancePoliciesSchema; +use crate::api::limits::Limits; +use crate::api::types::{InstancePolicies, User}; use crate::errors::{FieldFormatError, InstanceServerError}; use crate::limit::LimitedRequester; use crate::URLBundle; @@ -12,10 +13,11 @@ The [`Instance`] what you will be using to perform all sorts of actions on the S */ pub struct Instance { pub urls: URLBundle, - pub instance_info: InstancePoliciesSchema, + pub instance_info: InstancePolicies, pub requester: LimitedRequester, + pub limits: Limits, //pub gateway: Gateway, - pub users: HashMap, + pub users: HashMap, } impl Instance { @@ -29,10 +31,10 @@ impl Instance { urls: URLBundle, requester: LimitedRequester, ) -> Result { - let users: HashMap = HashMap::new(); + let users: HashMap = HashMap::new(); let mut instance = Instance { - urls, - instance_info: InstancePoliciesSchema::new( + urls: urls.clone(), + instance_info: InstancePolicies::new( // This is okay, because the instance_info will be overwritten by the instance_policies_schema() function. "".to_string(), None, @@ -43,6 +45,7 @@ impl Instance { None, None, ), + limits: Limits::check_limits(urls.api).await, requester, users, }; @@ -84,6 +87,6 @@ impl Username { if username.len() < 2 || username.len() > 32 { return Err(FieldFormatError::UsernameError); } - return Ok(Username { username }); + Ok(Username { username }) } } diff --git a/src/lib.rs b/src/lib.rs index ab97b74..00170d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,10 +49,10 @@ impl URLBundle { }; // if the last character of the string is a slash, remove it. let mut url_string = url.to_string(); - if url_string.chars().last().unwrap() == '/' { + if url_string.ends_with('/') { url_string.pop(); } - return url_string; + url_string } pub fn get_api(&self) -> &str { diff --git a/src/limit.rs b/src/limit.rs index f5ac4b6..84e6d17 100644 --- a/src/limit.rs +++ b/src/limit.rs @@ -1,10 +1,10 @@ use crate::{ - api::limits::{Limit, LimitType, Limits}, + api::limits::{Limit, LimitType, Limits, LimitsMutRef}, errors::InstanceServerError, }; use reqwest::{Client, RequestBuilder, Response}; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; // Note: There seem to be some overlapping request limiters. We need to make sure that sending a // request checks for all the request limiters that apply, and blocks if any of the limiters are 0 @@ -20,7 +20,6 @@ pub struct TypedRequest { pub struct LimitedRequester { http: Client, requests: VecDeque, - limits_rate: HashMap, } impl LimitedRequester { @@ -29,11 +28,10 @@ impl LimitedRequester { /// be send within the `Limit` of an external API Ratelimiter, and looks at the returned request /// headers to see if it can find Ratelimit info to update itself. #[allow(dead_code)] - pub async fn new(api_url: String) -> Self { + pub async fn new() -> Self { LimitedRequester { http: Client::new(), requests: VecDeque::new(), - limits_rate: Limits::check_limits(api_url).await, } } @@ -70,8 +68,10 @@ impl LimitedRequester { &mut self, request: RequestBuilder, limit_type: LimitType, + instance_rate_limits: &mut Limits, + user_rate_limits: &mut Limits, ) -> Result { - if self.can_send_request(limit_type) { + if self.can_send_request(limit_type, instance_rate_limits, user_rate_limits) { let built_request = request .build() .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); @@ -80,14 +80,19 @@ impl LimitedRequester { Ok(is_response) => is_response, Err(e) => panic!("An error occured while processing the response: {}", e), }; - self.update_limits(&response, limit_type); - return Ok(response); + self.update_limits( + &response, + limit_type, + instance_rate_limits, + user_rate_limits, + ); + Ok(response) } else { self.requests.push_back(TypedRequest { - request: request, - limit_type: limit_type, + request, + limit_type, }); - return Err(InstanceServerError::RateLimited); + Err(InstanceServerError::RateLimited) } } @@ -102,9 +107,16 @@ impl LimitedRequester { } } - fn can_send_request(&mut self, limit_type: LimitType) -> bool { - let limits = &self.limits_rate.clone(); + fn can_send_request( + &mut self, + limit_type: LimitType, + instance_rate_limits: &Limits, + user_rate_limits: &Limits, + ) -> bool { // Check if all of the limits in this vec have at least one remaining request + + let rate_limits = Limits::combine(instance_rate_limits, user_rate_limits); + let constant_limits: Vec<&LimitType> = [ &LimitType::Error, &LimitType::Global, @@ -113,19 +125,29 @@ impl LimitedRequester { ] .to_vec(); for limit in constant_limits.iter() { - match limits.get(&limit) { + match rate_limits.to_hash_map().get(limit) { Some(limit) => { if limit.remaining == 0 { return false; } // AbsoluteRegister and AuthRegister can cancel each other out. if limit.bucket == LimitType::AbsoluteRegister - && limits.get(&LimitType::AuthRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AuthRegister) + .unwrap() + .remaining + == 0 { return false; } if limit.bucket == LimitType::AuthRegister - && limits.get(&LimitType::AbsoluteRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AbsoluteRegister) + .unwrap() + .remaining + == 0 { return false; } @@ -133,100 +155,97 @@ impl LimitedRequester { None => return false, } } - return true; + true } - fn update_limits(&mut self, response: &Response, limit_type: LimitType) { + fn update_limits( + &mut self, + response: &Response, + limit_type: LimitType, + instance_rate_limits: &mut Limits, + user_rate_limits: &mut Limits, + ) { + let mut rate_limits = LimitsMutRef::combine_mut_ref(instance_rate_limits, user_rate_limits); + let remaining = match response.headers().get("X-RateLimit-Remaining") { Some(remaining) => remaining.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().remaining - 1, + None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1, }; let limit = match response.headers().get("X-RateLimit-Limit") { Some(limit) => limit.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().limit, + None => rate_limits.get_limit_mut_ref(&limit_type).limit, }; let reset = match response.headers().get("X-RateLimit-Reset") { Some(reset) => reset.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().reset, + None => rate_limits.get_limit_mut_ref(&limit_type).reset, }; let status = response.status(); let status_str = status.as_str(); - if status_str.chars().next().unwrap() == '4' { - self.limits_rate - .get_mut(&LimitType::Error) - .unwrap() + if status_str.starts_with('4') { + rate_limits + .get_limit_mut_ref(&LimitType::Error) .add_remaining(-1); } - self.limits_rate - .get_mut(&LimitType::Global) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Global) .add_remaining(-1); - self.limits_rate - .get_mut(&LimitType::Ip) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Ip) .add_remaining(-1); - let mut_limits_rate = &mut self.limits_rate; - match limit_type { LimitType::Error => { - let entry = mut_limits_rate.get_mut(&LimitType::Error).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Error); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Global => { - let entry = mut_limits_rate.get_mut(&LimitType::Global).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Global); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Ip => { - let entry = mut_limits_rate.get_mut(&LimitType::Ip).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AuthLogin => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthLogin).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AbsoluteRegister => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AuthRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AuthRegister) .remaining -= 1; } LimitType::AuthRegister => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthRegister).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AbsoluteRegister) .remaining -= 1; } LimitType::AbsoluteMessage => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteMessage) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Channel => { - let entry = mut_limits_rate.get_mut(&LimitType::Channel).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Guild => { - let entry = mut_limits_rate.get_mut(&LimitType::Guild).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Webhook => { - let entry = mut_limits_rate.get_mut(&LimitType::Webhook).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } } @@ -247,16 +266,7 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); - let requester = LimitedRequester::new(urls.api).await; - assert_eq!( - requester.limits_rate.get(&LimitType::Ip).unwrap(), - &Limit { - bucket: LimitType::Ip, - limit: 500, - remaining: 500, - reset: 5 - } - ); + let _requester = LimitedRequester::new().await; } #[tokio::test] @@ -266,16 +276,22 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); - let mut requester = LimitedRequester::new(urls.api.clone()).await; + let mut requester = LimitedRequester::new().await; let mut request: Option> = None; + let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await; + let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; for _ in 0..=50 { let request_path = urls.api.clone() + "/some/random/nonexisting/path"; - let request_builder = requester.http.get(request_path); request = Some( requester - .send_request(request_builder, LimitType::Channel) + .send_request( + request_builder, + LimitType::Channel, + &mut instance_rate_limits, + &mut user_rate_limits, + ) .await, ); } @@ -296,11 +312,18 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); - let mut requester = LimitedRequester::new(urls.api.clone()).await; + let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await; + let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; + let mut requester = LimitedRequester::new().await; let request_path = urls.api.clone() + "/policies/instance/limits"; let request_builder = requester.http.get(request_path); let request = requester - .send_request(request_builder, LimitType::Channel) + .send_request( + request_builder, + LimitType::Channel, + &mut instance_rate_limits, + &mut user_rate_limits, + ) .await; let result = match request { Ok(result) => result,