From 17eab8169e024d46b4173be9bcbdb29aba61cca0 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 24 Apr 2023 19:49:26 +0200 Subject: [PATCH] Change HashMap<> to Limits --- src/api/auth/login.rs | 2 +- src/api/auth/register.rs | 2 +- src/instance.rs | 5 +- src/limit.rs | 108 ++++++++++++++++++--------------------- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 63f5a71..2cb61bd 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -18,7 +18,7 @@ pub mod login { let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let response = requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index d25f7c5..ef1062d 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -29,7 +29,7 @@ pub mod register { let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let response = limited_requester - .send_request(request_builder, LimitType::AuthRegister) + .send_request(request_builder, LimitType::AuthRegister, &mut self.limits) .await; if !response.is_ok() { return Err(InstanceServerError::NoResponse); diff --git a/src/instance.rs b/src/instance.rs index a655735..702d25b 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,3 +1,4 @@ +use crate::api::limits::{Limit, LimitType, Limits}; use crate::api::schemas::schemas::InstancePoliciesSchema; use crate::errors::{FieldFormatError, InstanceServerError}; use crate::limit::LimitedRequester; @@ -14,6 +15,7 @@ pub struct Instance { pub urls: URLBundle, pub instance_info: InstancePoliciesSchema, pub requester: LimitedRequester, + pub limits: Limits, //pub gateway: Gateway, pub users: HashMap, } @@ -31,7 +33,7 @@ impl Instance { ) -> Result { let users: HashMap = HashMap::new(); let mut instance = Instance { - urls, + urls: urls.clone(), instance_info: InstancePoliciesSchema::new( // This is okay, because the instance_info will be overwritten by the instance_policies_schema() function. "".to_string(), @@ -43,6 +45,7 @@ impl Instance { None, None, ), + limits: Limits::check_limits(urls.api).await, requester, users, }; diff --git a/src/limit.rs b/src/limit.rs index f5ac4b6..b69ad86 100644 --- a/src/limit.rs +++ b/src/limit.rs @@ -4,7 +4,7 @@ use crate::{ }; use reqwest::{Client, RequestBuilder, Response}; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; // Note: There seem to be some overlapping request limiters. We need to make sure that sending a // request checks for all the request limiters that apply, and blocks if any of the limiters are 0 @@ -20,7 +20,6 @@ pub struct TypedRequest { pub struct LimitedRequester { http: Client, requests: VecDeque, - limits_rate: HashMap, } impl LimitedRequester { @@ -33,7 +32,6 @@ impl LimitedRequester { LimitedRequester { http: Client::new(), requests: VecDeque::new(), - limits_rate: Limits::check_limits(api_url).await, } } @@ -70,8 +68,9 @@ impl LimitedRequester { &mut self, request: RequestBuilder, limit_type: LimitType, + rate_limits: &mut Limits, ) -> Result { - if self.can_send_request(limit_type) { + if self.can_send_request(limit_type, rate_limits) { let built_request = request .build() .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); @@ -80,7 +79,7 @@ impl LimitedRequester { Ok(is_response) => is_response, Err(e) => panic!("An error occured while processing the response: {}", e), }; - self.update_limits(&response, limit_type); + self.update_limits(&response, limit_type, rate_limits); return Ok(response); } else { self.requests.push_back(TypedRequest { @@ -102,8 +101,7 @@ impl LimitedRequester { } } - fn can_send_request(&mut self, limit_type: LimitType) -> bool { - let limits = &self.limits_rate.clone(); + fn can_send_request(&mut self, limit_type: LimitType, rate_limits: &Limits) -> bool { // Check if all of the limits in this vec have at least one remaining request let constant_limits: Vec<&LimitType> = [ &LimitType::Error, @@ -113,19 +111,29 @@ impl LimitedRequester { ] .to_vec(); for limit in constant_limits.iter() { - match limits.get(&limit) { + match rate_limits.to_hash_map().get(&limit) { Some(limit) => { if limit.remaining == 0 { return false; } // AbsoluteRegister and AuthRegister can cancel each other out. if limit.bucket == LimitType::AbsoluteRegister - && limits.get(&LimitType::AuthRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AuthRegister) + .unwrap() + .remaining + == 0 { return false; } if limit.bucket == LimitType::AuthRegister - && limits.get(&LimitType::AbsoluteRegister).unwrap().remaining == 0 + && rate_limits + .to_hash_map() + .get(&LimitType::AbsoluteRegister) + .unwrap() + .remaining + == 0 { return false; } @@ -136,97 +144,91 @@ impl LimitedRequester { return true; } - fn update_limits(&mut self, response: &Response, limit_type: LimitType) { + fn update_limits( + &mut self, + response: &Response, + limit_type: LimitType, + rate_limits: &mut Limits, + ) { let remaining = match response.headers().get("X-RateLimit-Remaining") { Some(remaining) => remaining.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().remaining - 1, + None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1, }; let limit = match response.headers().get("X-RateLimit-Limit") { Some(limit) => limit.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().limit, + None => rate_limits.get_limit_mut_ref(&limit_type).limit, }; let reset = match response.headers().get("X-RateLimit-Reset") { Some(reset) => reset.to_str().unwrap().parse::().unwrap(), - None => self.limits_rate.get(&limit_type).unwrap().reset, + None => rate_limits.get_limit_mut_ref(&limit_type).reset, }; let status = response.status(); let status_str = status.as_str(); if status_str.chars().next().unwrap() == '4' { - self.limits_rate - .get_mut(&LimitType::Error) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Error) .add_remaining(-1); } - self.limits_rate - .get_mut(&LimitType::Global) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Global) .add_remaining(-1); - self.limits_rate - .get_mut(&LimitType::Ip) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::Ip) .add_remaining(-1); - let mut_limits_rate = &mut self.limits_rate; - match limit_type { LimitType::Error => { - let entry = mut_limits_rate.get_mut(&LimitType::Error).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Error); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Global => { - let entry = mut_limits_rate.get_mut(&LimitType::Global).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Global); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Ip => { - let entry = mut_limits_rate.get_mut(&LimitType::Ip).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AuthLogin => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthLogin).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::AbsoluteRegister => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AuthRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AuthRegister) .remaining -= 1; } LimitType::AuthRegister => { - let entry = mut_limits_rate.get_mut(&LimitType::AuthRegister).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); // AbsoluteRegister and AuthRegister both need to be updated, if a Register event // happens. - mut_limits_rate - .get_mut(&LimitType::AbsoluteRegister) - .unwrap() + rate_limits + .get_limit_mut_ref(&LimitType::AbsoluteRegister) .remaining -= 1; } LimitType::AbsoluteMessage => { - let entry = mut_limits_rate - .get_mut(&LimitType::AbsoluteMessage) - .unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Channel => { - let entry = mut_limits_rate.get_mut(&LimitType::Channel).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Guild => { - let entry = mut_limits_rate.get_mut(&LimitType::Guild).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } LimitType::Webhook => { - let entry = mut_limits_rate.get_mut(&LimitType::Webhook).unwrap(); + let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook); LimitedRequester::update_limit_entry(entry, reset, remaining, limit); } } @@ -248,15 +250,6 @@ mod rate_limit { String::from("http://localhost:3001/cdn"), ); let requester = LimitedRequester::new(urls.api).await; - assert_eq!( - requester.limits_rate.get(&LimitType::Ip).unwrap(), - &Limit { - bucket: LimitType::Ip, - limit: 500, - remaining: 500, - reset: 5 - } - ); } #[tokio::test] @@ -268,14 +261,14 @@ mod rate_limit { ); let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut request: Option> = None; + let mut limits = Limits::check_limits(urls.api.clone()).await; for _ in 0..=50 { let request_path = urls.api.clone() + "/some/random/nonexisting/path"; - let request_builder = requester.http.get(request_path); request = Some( requester - .send_request(request_builder, LimitType::Channel) + .send_request(request_builder, LimitType::Channel, &mut limits) .await, ); } @@ -296,11 +289,12 @@ mod rate_limit { String::from("wss://localhost:3001/"), String::from("http://localhost:3001/cdn"), ); + let mut limits = Limits::check_limits(urls.api.clone()).await; let mut requester = LimitedRequester::new(urls.api.clone()).await; let request_path = urls.api.clone() + "/policies/instance/limits"; let request_builder = requester.http.get(request_path); let request = requester - .send_request(request_builder, LimitType::Channel) + .send_request(request_builder, LimitType::Channel, &mut limits) .await; let result = match request { Ok(result) => result,