From 21c7bdcf8b1df635bda7915d3c0e2956cab86165 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 24 Apr 2023 19:38:42 +0200 Subject: [PATCH 1/4] Implement Limits::new() and more logic --- src/api/policies/instance/limits.rs | 225 ++++++++++++---------------- 1 file changed, 97 insertions(+), 128 deletions(-) diff --git a/src/api/policies/instance/limits.rs b/src/api/policies/instance/limits.rs index a46a29f..299164b 100644 --- a/src/api/policies/instance/limits.rs +++ b/src/api/policies/instance/limits.rs @@ -166,6 +166,7 @@ pub mod limits { } } + #[derive(Debug, Clone)] pub struct Limits { pub limit_absolute_messages: Limit, pub limit_absolute_register: Limit, @@ -180,19 +181,49 @@ pub mod limits { } impl Limits { - pub fn iter(&self) -> std::vec::IntoIter { - let mut limits: Vec = Vec::new(); - limits.push(self.limit_absolute_messages.clone()); - limits.push(self.limit_absolute_register.clone()); - limits.push(self.limit_auth_login.clone()); - limits.push(self.limit_auth_register.clone()); - limits.push(self.limit_ip.clone()); - limits.push(self.limit_global.clone()); - limits.push(self.limit_error.clone()); - limits.push(self.limit_guild.clone()); - limits.push(self.limit_webhook.clone()); - limits.push(self.limit_channel.clone()); - limits.into_iter() + pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { + match limit_type { + &LimitType::AbsoluteMessage => &self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &self.limit_absolute_register, + &LimitType::AuthLogin => &self.limit_auth_login, + &LimitType::AuthRegister => &self.limit_auth_register, + &LimitType::Channel => &self.limit_channel, + &LimitType::Error => &self.limit_error, + &LimitType::Global => &self.limit_global, + &LimitType::Guild => &self.limit_guild, + &LimitType::Ip => &self.limit_ip, + &LimitType::Webhook => &self.limit_webhook, + } + } + + pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { + match limit_type { + &LimitType::AbsoluteMessage => &mut self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &mut self.limit_absolute_register, + &LimitType::AuthLogin => &mut self.limit_auth_login, + &LimitType::AuthRegister => &mut self.limit_auth_register, + &LimitType::Channel => &mut self.limit_channel, + &LimitType::Error => &mut self.limit_error, + &LimitType::Global => &mut self.limit_global, + &LimitType::Guild => &mut self.limit_guild, + &LimitType::Ip => &mut self.limit_ip, + &LimitType::Webhook => &mut self.limit_webhook, + } + } + + pub fn to_hash_map(&self) -> HashMap { + let mut map: HashMap = HashMap::new(); + map.insert(LimitType::AbsoluteMessage, self.limit_absolute_messages); + map.insert(LimitType::AbsoluteRegister, self.limit_absolute_register); + map.insert(LimitType::AuthLogin, self.limit_auth_login); + map.insert(LimitType::AuthRegister, self.limit_auth_register); + map.insert(LimitType::Ip, self.limit_ip); + map.insert(LimitType::Global, self.limit_global); + map.insert(LimitType::Error, self.limit_error); + map.insert(LimitType::Guild, self.limit_guild); + map.insert(LimitType::Webhook, self.limit_webhook); + map.insert(LimitType::Channel, self.limit_channel); + map } /// check_limits uses the API to get the current request limits of the instance. @@ -201,7 +232,7 @@ pub mod limits { /// # Errors /// This function will panic if the request fails or if the response body cannot be parsed. /// TODO: Change this to return a Result and handle the errors properly. - pub async fn check_limits(api_url: String) -> HashMap { + pub async fn check_limits(api_url: String) -> Limits { let client = Client::new(); let url_parsed = crate::URLBundle::parse_url(api_url) + "/policies/instance/limits"; let result = client @@ -219,213 +250,151 @@ pub mod limits { }); let config: Config = from_str(&result).unwrap(); // If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX - let mut limits: HashMap = HashMap::new(); + let mut limits: Limits; if config.rate.enabled == false { - limits.insert( - LimitType::AbsoluteMessage, - Limit { + limits = Limits { + limit_absolute_messages: Limit { bucket: LimitType::AbsoluteMessage, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AbsoluteRegister, - Limit { + limit_absolute_register: Limit { bucket: LimitType::AbsoluteRegister, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AuthLogin, - Limit { + limit_auth_login: Limit { bucket: LimitType::AuthLogin, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::AuthRegister, - Limit { + limit_auth_register: Limit { bucket: LimitType::AuthRegister, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Ip, - Limit { + limit_ip: Limit { bucket: LimitType::Ip, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Global, - Limit { + limit_global: Limit { bucket: LimitType::Global, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Error, - Limit { + limit_error: Limit { bucket: LimitType::Error, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Guild, - Limit { + limit_guild: Limit { bucket: LimitType::Guild, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Webhook, - Limit { + limit_webhook: Limit { bucket: LimitType::Webhook, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); - limits.insert( - LimitType::Channel, - Limit { + limit_channel: Limit { bucket: LimitType::Channel, limit: u64::MAX, remaining: u64::MAX, reset: u64::MAX, }, - ); + }; } else { - limits.insert( - LimitType::AbsoluteMessage, - Limit { + limits = Limits { + limit_absolute_messages: Limit { bucket: LimitType::AbsoluteMessage, limit: config.absoluteRate.sendMessage.limit, remaining: config.absoluteRate.sendMessage.limit, reset: config.absoluteRate.sendMessage.window, }, - ); - limits.insert( - LimitType::AbsoluteRegister, - Limit { + limit_absolute_register: Limit { bucket: LimitType::AbsoluteRegister, limit: config.absoluteRate.register.limit, remaining: config.absoluteRate.register.limit, reset: config.absoluteRate.register.window, }, - ); - limits.insert( - LimitType::AuthLogin, - Limit { + limit_auth_login: Limit { bucket: LimitType::AuthLogin, limit: config.rate.routes.auth.login.count, remaining: config.rate.routes.auth.login.count, reset: config.rate.routes.auth.login.window, }, - ); - limits.insert( - LimitType::AuthRegister, - Limit { + limit_auth_register: Limit { bucket: LimitType::AuthRegister, limit: config.rate.routes.auth.register.count, remaining: config.rate.routes.auth.register.count, reset: config.rate.routes.auth.register.window, }, - ); - limits.insert( - LimitType::Guild, - Limit { - bucket: LimitType::Guild, - limit: config.rate.routes.guild.count, - remaining: config.rate.routes.guild.count, - reset: config.rate.routes.guild.window, - }, - ); - limits.insert( - LimitType::Webhook, - Limit { - bucket: LimitType::Webhook, - limit: config.rate.routes.webhook.count, - remaining: config.rate.routes.webhook.count, - reset: config.rate.routes.webhook.window, - }, - ); - limits.insert( - LimitType::Channel, - Limit { - bucket: LimitType::Channel, - limit: config.rate.routes.channel.count, - remaining: config.rate.routes.channel.count, - reset: config.rate.routes.channel.window, - }, - ); - limits.insert( - LimitType::Ip, - Limit { + limit_ip: Limit { bucket: LimitType::Ip, limit: config.rate.ip.count, remaining: config.rate.ip.count, reset: config.rate.ip.window, }, - ); - limits.insert( - LimitType::Global, - Limit { + limit_global: Limit { bucket: LimitType::Global, limit: config.rate.global.count, remaining: config.rate.global.count, reset: config.rate.global.window, }, - ); - limits.insert( - LimitType::Error, - Limit { + limit_error: Limit { bucket: LimitType::Error, limit: config.rate.error.count, remaining: config.rate.error.count, reset: config.rate.error.window, }, - ); + limit_guild: Limit { + bucket: LimitType::Guild, + limit: config.rate.routes.guild.count, + remaining: config.rate.routes.guild.count, + reset: config.rate.routes.guild.window, + }, + limit_webhook: Limit { + bucket: LimitType::Webhook, + limit: config.rate.routes.webhook.count, + remaining: config.rate.routes.webhook.count, + reset: config.rate.routes.webhook.window, + }, + limit_channel: Limit { + bucket: LimitType::Channel, + limit: config.rate.routes.channel.count, + remaining: config.rate.routes.channel.count, + reset: config.rate.routes.channel.window, + }, + }; } if !config.absoluteRate.register.enabled { - limits.insert( - LimitType::AbsoluteRegister, - Limit { - bucket: LimitType::AbsoluteRegister, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - ); + limits.limit_absolute_register = Limit { + bucket: LimitType::AbsoluteRegister, + limit: u64::MAX, + remaining: u64::MAX, + reset: u64::MAX, + }; } if !config.absoluteRate.sendMessage.enabled { - limits.insert( - LimitType::AbsoluteMessage, - Limit { - bucket: LimitType::AbsoluteMessage, - limit: u64::MAX, - remaining: u64::MAX, - reset: u64::MAX, - }, - ); + limits.limit_absolute_messages = Limit { + bucket: LimitType::AbsoluteMessage, + limit: u64::MAX, + remaining: u64::MAX, + reset: u64::MAX, + }; } return limits; From 30742380b99f132afe2e8e0b0009e3dbadd8372a Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 24 Apr 2023 19:49:26 +0200 Subject: [PATCH 2/4] Change HashMap<> to Limits --- src/api/auth/login.rs | 2 +- src/api/auth/register.rs | 2 +- src/instance.rs | 5 +- src/limit.rs | 108 ++++++++++++++++++--------------------- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 63f5a71..2cb61bd 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -18,7 +18,7 @@ pub mod login { let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let response = requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index d25f7c5..ef1062d 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -29,7 +29,7 @@ pub mod register { let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let response = limited_requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); diff --git a/src/instance.rs b/src/instance.rs index a655735..702d25b 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,3 +1,4 @@ +use crate::api::limits::{Limit, LimitType, Limits}; use crate::api::schemas::schemas::InstancePoliciesSchema; use crate::errors::{FieldFormatError, InstanceServerError}; use crate::limit::LimitedRequester; @@ -14,6 +15,7 @@ pub struct Instance { pub urls: URLBundle, pub instance_info: InstancePoliciesSchema, pub requester: LimitedRequester, + pub limits: Limits, //pub gateway: Gateway, pub users: HashMap, } @@ -31,7 +33,7 @@ impl Instance { ) -> Result { let users: HashMap = HashMap::new(); let mut instance = Instance { - urls, + urls: urls.clone(), instance_info: InstancePoliciesSchema::new( // This is okay, because the instance_info will be overwritten by the instance_policies_schema() function. "".to_string(), @@ -43,6 +45,7 @@ impl Instance { None, None, ), + limits: Limits::check_limits(urls.api).await, requester, users, }; diff --git a/src/limit.rs b/src/limit.rs index f5ac4b6..b69ad86 100644 --- a/src/limit.rs +++ b/src/limit.rs @@ -4,7 +4,7 @@ use crate::{ }; use reqwest::{Client, RequestBuilder, Response}; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; // Note: There seem to be some overlapping request limiters. We need to make sure that sending a // request checks for all the request limiters that apply, and blocks if any of the limiters are 0 @@ -20,7 +20,6 @@ pub struct TypedRequest { pub struct LimitedRequester { http: Client, requests: VecDeque, - limits_rate: HashMap, } impl LimitedRequester { @@ -33,7 +32,6 @@ impl LimitedRequester { LimitedRequester { http: Client::new(), requests: VecDeque::new(), - limits_rate: Limits::check_limits(api_url).await, } } @@ -70,8 +68,9 @@ impl LimitedRequester { &mut self, request: RequestBuilder, limit_type: LimitType, + rate_limits: &mut Limits, ) -> Result { - if self.can_send_request(limit_type) { + if self.can_send_request(limit_type, rate_limits) { let built_request = request .build() .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); @@ -80,7 +79,7 @@ impl LimitedRequester { Ok(is_response) => is_response, Err(e) => panic!("An error occured while processing the response: {}", e), }; - self.update_limits(&response, limit_type); + self.update_limits(&response, limit_type, rate_limits); return Ok(response); } else { self.requests.push_back(TypedRequest { @@ -102,8 +101,7 @@ impl LimitedRequester { } } - fn can_send_request(&mut self, limit_type: LimitType) -> bool { - let limits = &self.limits_rate.clone(); + fn can_send_request(&mut self, limit_type: LimitType, rate_limits: &Limits) -> bool { // Check if all of the limits in this vec have at least one remaining request let constant_limits: Vec<&LimitType> = [ &LimitType::Error, @@ -113,19 +111,29 @@ impl LimitedRequester { ] .to_vec(); for limit in constant_limits.iter() { - match limits.get(&limit) { + match rate_limits.to_hash_map().get(&limit) { Some(limit) => { if limit.remaining == 0 { return false; } // AbsoluteRegister and AuthRegister can cancel each other out. if limit.bucket == LimitType::AbsoluteRegister - && limits.get(&LimitType::AuthRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AuthRegister) + .unwrap() + .remaining + == 0 { return false; } if limit.bucket == LimitType::AuthRegister - && limits.get(&LimitType::AbsoluteRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AbsoluteRegister) + .unwrap() + .remaining + == 0 { return false; } @@ -136,97 +144,91 @@ impl LimitedRequester { return true; } - fn update_limits(&mut self, response: &Response, limit_type: LimitType) { + fn update_limits( + &mut self, + response: &Response, + limit_type: LimitType, + rate_limits: &mut Limits, + ) { let remaining = match response.headers().get("X-RateLimit-Remaining") { Some(remaining) => remaining.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().remaining - 1, + None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1, }; let limit = match response.headers().get("X-RateLimit-Limit") { Some(limit) => limit.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().limit, + None => rate_limits.get_limit_mut_ref(&limit_type).limit, }; let reset = match response.headers().get("X-RateLimit-Reset") { Some(reset) => reset.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().reset, + None => rate_limits.get_limit_mut_ref(&limit_type).reset, }; let status = response.status(); let status_str = status.as_str(); if status_str.chars().next().unwrap() == '4' { - self.limits_rate - .get_mut(&LimitType::Error) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Error) .add_remaining(-1); } - self.limits_rate - .get_mut(&LimitType::Global) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Global) .add_remaining(-1); - self.limits_rate - .get_mut(&LimitType::Ip) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Ip) .add_remaining(-1); - let mut_limits_rate = &mut self.limits_rate; - match limit_type { LimitType::Error => { - let entry = mut_limits_rate.get_mut(&LimitType::Error).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Error); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Global => { - let entry = mut_limits_rate.get_mut(&LimitType::Global).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Global); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Ip => { - let entry = mut_limits_rate.get_mut(&LimitType::Ip).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AuthLogin => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthLogin).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AbsoluteRegister => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AuthRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AuthRegister) .remaining -= 1; } LimitType::AuthRegister => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthRegister).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AbsoluteRegister) .remaining -= 1; } LimitType::AbsoluteMessage => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteMessage) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Channel => { - let entry = mut_limits_rate.get_mut(&LimitType::Channel).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Guild => { - let entry = mut_limits_rate.get_mut(&LimitType::Guild).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Webhook => { - let entry = mut_limits_rate.get_mut(&LimitType::Webhook).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } } @@ -248,15 +250,6 @@ mod rate_limit { String::from("http://localhost:3001/cdn"), ); let requester = LimitedRequester::new(urls.api).await; - assert_eq!( - requester.limits_rate.get(&LimitType::Ip).unwrap(), - &Limit { - bucket: LimitType::Ip, - limit: 500, - remaining: 500, - reset: 5 - } - ); } #[tokio::test] @@ -268,14 +261,14 @@ mod rate_limit { ); let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut request: Option> = None; + let mut limits = Limits::check_limits(urls.api.clone()).await; for _ in 0..=50 { let request_path = urls.api.clone() + "/some/random/nonexisting/path"; - let request_builder = requester.http.get(request_path); request = Some( requester - .send_request(request_builder, LimitType::Channel) + .send_request(request_builder, LimitType::Channel, &mut limits) .await, ); } @@ -296,11 +289,12 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); + let mut limits = Limits::check_limits(urls.api.clone()).await; let mut requester = LimitedRequester::new(urls.api.clone()).await; let request_path = urls.api.clone() + "/policies/instance/limits"; let request_builder = requester.http.get(request_path); let request = requester - .send_request(request_builder, LimitType::Channel) + .send_request(request_builder, LimitType::Channel, &mut limits) .await; let result = match request { Ok(result) => result, From aba42a68690b197798ccf04d595ab76549da86f7 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 24 Apr 2023 19:51:35 +0200 Subject: [PATCH 3/4] Give each user their own rate limits --- src/api/schemas.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/api/schemas.rs b/src/api/schemas.rs index 65b4fe1..d28e60e 100644 --- a/src/api/schemas.rs +++ b/src/api/schemas.rs @@ -3,7 +3,7 @@ pub mod schemas { use serde::{Deserialize, Serialize}; use std::{collections::HashMap, fmt}; - use crate::{errors::FieldFormatError, URLBundle}; + use crate::{api::limits::Limits, errors::FieldFormatError, URLBundle}; /** A struct that represents a well-formed email address. @@ -405,6 +405,7 @@ pub mod schemas { logged_in: bool, belongs_to: URLBundle, token: String, + rate_limits: Limits, pub settings: UserSettings, pub object: UserObject, } From a5943197d4bee6c435b4cfe8691a927249c2e812 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 24 Apr 2023 20:58:45 +0200 Subject: [PATCH 4/4] separate User and Instance limits. --- src/api/auth/login.rs | 11 +++- src/api/auth/register.rs | 13 ++++- src/api/policies/instance/limits.rs | 78 +++++++++++++++++++++++++++++ src/api/schemas.rs | 2 + src/limit.rs | 49 ++++++++++++++---- 5 files changed, 140 insertions(+), 13 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 2cb61bd..a11db3c 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -17,8 +17,17 @@ pub mod login { let client = Client::new(); let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + // We do not have a user yet, and the UserRateLimits will not be affected by a login + // request (since login is an instance wide limit), which is why we are just cloning the + // instances' limits to pass them on as user_rate_limits later. + let mut cloned_limits = self.limits.clone(); let response = requester - .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) + .send_request( + request_builder, + LimitType::AuthRegister, + &mut self.limits, + &mut cloned_limits, + ) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index ef1062d..2e0db06 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -28,8 +28,17 @@ pub mod register { let client = Client::new(); let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); + // We do not have a user yet, and the UserRateLimits will not be affected by a login + // request (since register is an instance wide limit), which is why we are just cloning + // the instances' limits to pass them on as user_rate_limits later. + let mut cloned_limits = self.limits.clone(); let response = limited_requester - .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) + .send_request( + request_builder, + LimitType::AuthRegister, + &mut self.limits, + &mut cloned_limits, + ) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); @@ -111,7 +120,7 @@ mod test { AuthUsername::new("Hiiii".to_string()).unwrap(), Some(AuthPassword::new("mysupersecurepass123!".to_string()).unwrap()), true, - Some(AuthEmail::new("flori@aaaa.xyz".to_string()).unwrap()), + Some(AuthEmail::new("random978234@aaaa.xyz".to_string()).unwrap()), None, None, Some("2000-01-01".to_string()), diff --git a/src/api/policies/instance/limits.rs b/src/api/policies/instance/limits.rs index 299164b..c3c76a3 100644 --- a/src/api/policies/instance/limits.rs +++ b/src/api/policies/instance/limits.rs @@ -166,6 +166,69 @@ pub mod limits { } } + pub struct LimitsMutRef<'a> { + pub limit_absolute_messages: &'a mut Limit, + pub limit_absolute_register: &'a mut Limit, + pub limit_auth_login: &'a mut Limit, + pub limit_auth_register: &'a mut Limit, + pub limit_ip: &'a mut Limit, + pub limit_global: &'a mut Limit, + pub limit_error: &'a mut Limit, + pub limit_guild: &'a mut Limit, + pub limit_webhook: &'a mut Limit, + pub limit_channel: &'a mut Limit, + } + + impl LimitsMutRef<'_> { + pub fn combine_mut_ref<'a>( + instance_rate_limits: &'a mut Limits, + user_rate_limits: &'a mut Limits, + ) -> LimitsMutRef<'a> { + return LimitsMutRef { + limit_absolute_messages: &mut instance_rate_limits.limit_absolute_messages, + limit_absolute_register: &mut instance_rate_limits.limit_absolute_register, + limit_auth_login: &mut instance_rate_limits.limit_auth_login, + limit_auth_register: &mut instance_rate_limits.limit_auth_register, + limit_channel: &mut user_rate_limits.limit_channel, + limit_error: &mut user_rate_limits.limit_error, + limit_global: &mut instance_rate_limits.limit_global, + limit_guild: &mut user_rate_limits.limit_guild, + limit_ip: &mut instance_rate_limits.limit_ip, + limit_webhook: &mut user_rate_limits.limit_webhook, + }; + } + + pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { + match limit_type { + &LimitType::AbsoluteMessage => &self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &self.limit_absolute_register, + &LimitType::AuthLogin => &self.limit_auth_login, + &LimitType::AuthRegister => &self.limit_auth_register, + &LimitType::Channel => &self.limit_channel, + &LimitType::Error => &self.limit_error, + &LimitType::Global => &self.limit_global, + &LimitType::Guild => &self.limit_guild, + &LimitType::Ip => &self.limit_ip, + &LimitType::Webhook => &self.limit_webhook, + } + } + + pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit { + match limit_type { + &LimitType::AbsoluteMessage => &mut self.limit_absolute_messages, + &LimitType::AbsoluteRegister => &mut self.limit_absolute_register, + &LimitType::AuthLogin => &mut self.limit_auth_login, + &LimitType::AuthRegister => &mut self.limit_auth_register, + &LimitType::Channel => &mut self.limit_channel, + &LimitType::Error => &mut self.limit_error, + &LimitType::Global => &mut self.limit_global, + &LimitType::Guild => &mut self.limit_guild, + &LimitType::Ip => &mut self.limit_ip, + &LimitType::Webhook => &mut self.limit_webhook, + } + } + } + #[derive(Debug, Clone)] pub struct Limits { pub limit_absolute_messages: Limit, @@ -181,6 +244,21 @@ pub mod limits { } impl Limits { + pub fn combine(instance_rate_limits: &Limits, user_rate_limits: &Limits) -> Limits { + return Limits { + limit_absolute_messages: instance_rate_limits.limit_absolute_messages, + limit_absolute_register: instance_rate_limits.limit_absolute_register, + limit_auth_login: instance_rate_limits.limit_auth_login, + limit_auth_register: instance_rate_limits.limit_auth_register, + limit_channel: user_rate_limits.limit_channel, + limit_error: user_rate_limits.limit_error, + limit_global: instance_rate_limits.limit_global, + limit_guild: user_rate_limits.limit_guild, + limit_ip: instance_rate_limits.limit_ip, + limit_webhook: user_rate_limits.limit_webhook, + }; + } + pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit { match limit_type { &LimitType::AbsoluteMessage => &self.limit_absolute_messages, diff --git a/src/api/schemas.rs b/src/api/schemas.rs index d28e60e..0feb280 100644 --- a/src/api/schemas.rs +++ b/src/api/schemas.rs @@ -439,6 +439,7 @@ pub mod schemas { logged_in: bool, belongs_to: URLBundle, token: String, + rate_limits: Limits, settings: UserSettings, object: UserObject, ) -> User { @@ -446,6 +447,7 @@ pub mod schemas { logged_in, belongs_to, token, + rate_limits, settings, object, } diff --git a/src/limit.rs b/src/limit.rs index b69ad86..f7f89d5 100644 --- a/src/limit.rs +++ b/src/limit.rs @@ -1,5 +1,5 @@ use crate::{ - api::limits::{Limit, LimitType, Limits}, + api::limits::{Limit, LimitType, Limits, LimitsMutRef}, errors::InstanceServerError, }; @@ -68,9 +68,10 @@ impl LimitedRequester { &mut self, request: RequestBuilder, limit_type: LimitType, - rate_limits: &mut Limits, + instance_rate_limits: &mut Limits, + user_rate_limits: &mut Limits, ) -> Result { - if self.can_send_request(limit_type, rate_limits) { + if self.can_send_request(limit_type, instance_rate_limits, user_rate_limits) { let built_request = request .build() .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); @@ -79,7 +80,12 @@ impl LimitedRequester { Ok(is_response) => is_response, Err(e) => panic!("An error occured while processing the response: {}", e), }; - self.update_limits(&response, limit_type, rate_limits); + self.update_limits( + &response, + limit_type, + instance_rate_limits, + user_rate_limits, + ); return Ok(response); } else { self.requests.push_back(TypedRequest { @@ -101,8 +107,16 @@ impl LimitedRequester { } } - fn can_send_request(&mut self, limit_type: LimitType, rate_limits: &Limits) -> bool { + fn can_send_request( + &mut self, + limit_type: LimitType, + instance_rate_limits: &Limits, + user_rate_limits: &Limits, + ) -> bool { // Check if all of the limits in this vec have at least one remaining request + + let rate_limits = Limits::combine(instance_rate_limits, user_rate_limits); + let constant_limits: Vec<&LimitType> = [ &LimitType::Error, &LimitType::Global, @@ -148,8 +162,11 @@ impl LimitedRequester { &mut self, response: &Response, limit_type: LimitType, - rate_limits: &mut Limits, + instance_rate_limits: &mut Limits, + user_rate_limits: &mut Limits, ) { + let mut rate_limits = LimitsMutRef::combine_mut_ref(instance_rate_limits, user_rate_limits); + let remaining = match response.headers().get("X-RateLimit-Remaining") { Some(remaining) => remaining.to_str().unwrap().parse::().unwrap(), None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1, @@ -261,14 +278,20 @@ mod rate_limit { ); let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut request: Option> = None; - let mut limits = Limits::check_limits(urls.api.clone()).await; + let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await; + let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; for _ in 0..=50 { let request_path = urls.api.clone() + "/some/random/nonexisting/path"; let request_builder = requester.http.get(request_path); request = Some( requester - .send_request(request_builder, LimitType::Channel, &mut limits) + .send_request( + request_builder, + LimitType::Channel, + &mut instance_rate_limits, + &mut user_rate_limits, + ) .await, ); } @@ -289,12 +312,18 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); - let mut limits = Limits::check_limits(urls.api.clone()).await; + let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await; + let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await; let mut requester = LimitedRequester::new(urls.api.clone()).await; let request_path = urls.api.clone() + "/policies/instance/limits"; let request_builder = requester.http.get(request_path); let request = requester - .send_request(request_builder, LimitType::Channel, &mut limits) + .send_request( + request_builder, + LimitType::Channel, + &mut instance_rate_limits, + &mut user_rate_limits, + ) .await; let result = match request { Ok(result) => result,