Change HashMap<> to Limits

This commit is contained in:
bitfl0wer 2023-04-24 19:49:26 +02:00
parent 21c7bdcf8b
commit 30742380b9
4 changed files with 57 additions and 60 deletions

View File

@ -18,7 +18,7 @@ pub mod login {
let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let endpoint_url = self.urls.get_api().to_string() + "/auth/login";
let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let request_builder = client.post(endpoint_url).body(json_schema.to_string());
let response = requester let response = requester
.send_request(request_builder, LimitType::AuthRegister) .send_request(request_builder, LimitType::AuthRegister, &mut self.limits)
.await; .await;
if !response.is_ok() { if !response.is_ok() {
return Err(InstanceServerError::NoResponse); return Err(InstanceServerError::NoResponse);

View File

@ -29,7 +29,7 @@ pub mod register {
let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let endpoint_url = self.urls.get_api().to_string() + "/auth/register";
let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let request_builder = client.post(endpoint_url).body(json_schema.to_string());
let response = limited_requester let response = limited_requester
.send_request(request_builder, LimitType::AuthRegister) .send_request(request_builder, LimitType::AuthRegister, &mut self.limits)
.await; .await;
if !response.is_ok() { if !response.is_ok() {
return Err(InstanceServerError::NoResponse); return Err(InstanceServerError::NoResponse);

View File

@ -1,3 +1,4 @@
use crate::api::limits::{Limit, LimitType, Limits};
use crate::api::schemas::schemas::InstancePoliciesSchema; use crate::api::schemas::schemas::InstancePoliciesSchema;
use crate::errors::{FieldFormatError, InstanceServerError}; use crate::errors::{FieldFormatError, InstanceServerError};
use crate::limit::LimitedRequester; use crate::limit::LimitedRequester;
@ -14,6 +15,7 @@ pub struct Instance {
pub urls: URLBundle, pub urls: URLBundle,
pub instance_info: InstancePoliciesSchema, pub instance_info: InstancePoliciesSchema,
pub requester: LimitedRequester, pub requester: LimitedRequester,
pub limits: Limits,
//pub gateway: Gateway, //pub gateway: Gateway,
pub users: HashMap<Token, Username>, pub users: HashMap<Token, Username>,
} }
@ -31,7 +33,7 @@ impl Instance {
) -> Result<Instance, InstanceServerError> { ) -> Result<Instance, InstanceServerError> {
let users: HashMap<Token, Username> = HashMap::new(); let users: HashMap<Token, Username> = HashMap::new();
let mut instance = Instance { let mut instance = Instance {
urls, urls: urls.clone(),
instance_info: InstancePoliciesSchema::new( instance_info: InstancePoliciesSchema::new(
// This is okay, because the instance_info will be overwritten by the instance_policies_schema() function. // This is okay, because the instance_info will be overwritten by the instance_policies_schema() function.
"".to_string(), "".to_string(),
@ -43,6 +45,7 @@ impl Instance {
None, None,
None, None,
), ),
limits: Limits::check_limits(urls.api).await,
requester, requester,
users, users,
}; };

View File

@ -4,7 +4,7 @@ use crate::{
}; };
use reqwest::{Client, RequestBuilder, Response}; use reqwest::{Client, RequestBuilder, Response};
use std::collections::{HashMap, VecDeque}; use std::collections::VecDeque;
// Note: There seem to be some overlapping request limiters. We need to make sure that sending a // Note: There seem to be some overlapping request limiters. We need to make sure that sending a
// request checks for all the request limiters that apply, and blocks if any of the limiters are 0 // request checks for all the request limiters that apply, and blocks if any of the limiters are 0
@ -20,7 +20,6 @@ pub struct TypedRequest {
pub struct LimitedRequester { pub struct LimitedRequester {
http: Client, http: Client,
requests: VecDeque<TypedRequest>, requests: VecDeque<TypedRequest>,
limits_rate: HashMap<LimitType, Limit>,
} }
impl LimitedRequester { impl LimitedRequester {
@ -33,7 +32,6 @@ impl LimitedRequester {
LimitedRequester { LimitedRequester {
http: Client::new(), http: Client::new(),
requests: VecDeque::new(), requests: VecDeque::new(),
limits_rate: Limits::check_limits(api_url).await,
} }
} }
@ -70,8 +68,9 @@ impl LimitedRequester {
&mut self, &mut self,
request: RequestBuilder, request: RequestBuilder,
limit_type: LimitType, limit_type: LimitType,
rate_limits: &mut Limits,
) -> Result<Response, InstanceServerError> { ) -> Result<Response, InstanceServerError> {
if self.can_send_request(limit_type) { if self.can_send_request(limit_type, rate_limits) {
let built_request = request let built_request = request
.build() .build()
.unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e));
@ -80,7 +79,7 @@ impl LimitedRequester {
Ok(is_response) => is_response, Ok(is_response) => is_response,
Err(e) => panic!("An error occured while processing the response: {}", e), Err(e) => panic!("An error occured while processing the response: {}", e),
}; };
self.update_limits(&response, limit_type); self.update_limits(&response, limit_type, rate_limits);
return Ok(response); return Ok(response);
} else { } else {
self.requests.push_back(TypedRequest { self.requests.push_back(TypedRequest {
@ -102,8 +101,7 @@ impl LimitedRequester {
} }
} }
fn can_send_request(&mut self, limit_type: LimitType) -> bool { fn can_send_request(&mut self, limit_type: LimitType, rate_limits: &Limits) -> bool {
let limits = &self.limits_rate.clone();
// Check if all of the limits in this vec have at least one remaining request // Check if all of the limits in this vec have at least one remaining request
let constant_limits: Vec<&LimitType> = [ let constant_limits: Vec<&LimitType> = [
&LimitType::Error, &LimitType::Error,
@ -113,19 +111,29 @@ impl LimitedRequester {
] ]
.to_vec(); .to_vec();
for limit in constant_limits.iter() { for limit in constant_limits.iter() {
match limits.get(&limit) { match rate_limits.to_hash_map().get(&limit) {
Some(limit) => { Some(limit) => {
if limit.remaining == 0 { if limit.remaining == 0 {
return false; return false;
} }
// AbsoluteRegister and AuthRegister can cancel each other out. // AbsoluteRegister and AuthRegister can cancel each other out.
if limit.bucket == LimitType::AbsoluteRegister if limit.bucket == LimitType::AbsoluteRegister
&& limits.get(&LimitType::AuthRegister).unwrap().remaining == 0 && rate_limits
.to_hash_map()
.get(&LimitType::AuthRegister)
.unwrap()
.remaining
== 0
{ {
return false; return false;
} }
if limit.bucket == LimitType::AuthRegister if limit.bucket == LimitType::AuthRegister
&& limits.get(&LimitType::AbsoluteRegister).unwrap().remaining == 0 && rate_limits
.to_hash_map()
.get(&LimitType::AbsoluteRegister)
.unwrap()
.remaining
== 0
{ {
return false; return false;
} }
@ -136,97 +144,91 @@ impl LimitedRequester {
return true; return true;
} }
fn update_limits(&mut self, response: &Response, limit_type: LimitType) { fn update_limits(
&mut self,
response: &Response,
limit_type: LimitType,
rate_limits: &mut Limits,
) {
let remaining = match response.headers().get("X-RateLimit-Remaining") { let remaining = match response.headers().get("X-RateLimit-Remaining") {
Some(remaining) => remaining.to_str().unwrap().parse::<u64>().unwrap(), Some(remaining) => remaining.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().remaining - 1, None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1,
}; };
let limit = match response.headers().get("X-RateLimit-Limit") { let limit = match response.headers().get("X-RateLimit-Limit") {
Some(limit) => limit.to_str().unwrap().parse::<u64>().unwrap(), Some(limit) => limit.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().limit, None => rate_limits.get_limit_mut_ref(&limit_type).limit,
}; };
let reset = match response.headers().get("X-RateLimit-Reset") { let reset = match response.headers().get("X-RateLimit-Reset") {
Some(reset) => reset.to_str().unwrap().parse::<u64>().unwrap(), Some(reset) => reset.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().reset, None => rate_limits.get_limit_mut_ref(&limit_type).reset,
}; };
let status = response.status(); let status = response.status();
let status_str = status.as_str(); let status_str = status.as_str();
if status_str.chars().next().unwrap() == '4' { if status_str.chars().next().unwrap() == '4' {
self.limits_rate rate_limits
.get_mut(&LimitType::Error) .get_limit_mut_ref(&LimitType::Error)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
} }
self.limits_rate rate_limits
.get_mut(&LimitType::Global) .get_limit_mut_ref(&LimitType::Global)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
self.limits_rate rate_limits
.get_mut(&LimitType::Ip) .get_limit_mut_ref(&LimitType::Ip)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
let mut_limits_rate = &mut self.limits_rate;
match limit_type { match limit_type {
LimitType::Error => { LimitType::Error => {
let entry = mut_limits_rate.get_mut(&LimitType::Error).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Error);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Global => { LimitType::Global => {
let entry = mut_limits_rate.get_mut(&LimitType::Global).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Global);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Ip => { LimitType::Ip => {
let entry = mut_limits_rate.get_mut(&LimitType::Ip).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::AuthLogin => { LimitType::AuthLogin => {
let entry = mut_limits_rate.get_mut(&LimitType::AuthLogin).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::AbsoluteRegister => { LimitType::AbsoluteRegister => {
let entry = mut_limits_rate let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister);
.get_mut(&LimitType::AbsoluteRegister)
.unwrap();
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event // AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens. // happens.
mut_limits_rate rate_limits
.get_mut(&LimitType::AuthRegister) .get_limit_mut_ref(&LimitType::AuthRegister)
.unwrap()
.remaining -= 1; .remaining -= 1;
} }
LimitType::AuthRegister => { LimitType::AuthRegister => {
let entry = mut_limits_rate.get_mut(&LimitType::AuthRegister).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event // AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens. // happens.
mut_limits_rate rate_limits
.get_mut(&LimitType::AbsoluteRegister) .get_limit_mut_ref(&LimitType::AbsoluteRegister)
.unwrap()
.remaining -= 1; .remaining -= 1;
} }
LimitType::AbsoluteMessage => { LimitType::AbsoluteMessage => {
let entry = mut_limits_rate let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage);
.get_mut(&LimitType::AbsoluteMessage)
.unwrap();
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Channel => { LimitType::Channel => {
let entry = mut_limits_rate.get_mut(&LimitType::Channel).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Guild => { LimitType::Guild => {
let entry = mut_limits_rate.get_mut(&LimitType::Guild).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Webhook => { LimitType::Webhook => {
let entry = mut_limits_rate.get_mut(&LimitType::Webhook).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
} }
@ -248,15 +250,6 @@ mod rate_limit {
String::from("http://localhost:3001/cdn"), String::from("http://localhost:3001/cdn"),
); );
let requester = LimitedRequester::new(urls.api).await; let requester = LimitedRequester::new(urls.api).await;
assert_eq!(
requester.limits_rate.get(&LimitType::Ip).unwrap(),
&Limit {
bucket: LimitType::Ip,
limit: 500,
remaining: 500,
reset: 5
}
);
} }
#[tokio::test] #[tokio::test]
@ -268,14 +261,14 @@ mod rate_limit {
); );
let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut requester = LimitedRequester::new(urls.api.clone()).await;
let mut request: Option<Result<Response, InstanceServerError>> = None; let mut request: Option<Result<Response, InstanceServerError>> = None;
let mut limits = Limits::check_limits(urls.api.clone()).await;
for _ in 0..=50 { for _ in 0..=50 {
let request_path = urls.api.clone() + "/some/random/nonexisting/path"; let request_path = urls.api.clone() + "/some/random/nonexisting/path";
let request_builder = requester.http.get(request_path); let request_builder = requester.http.get(request_path);
request = Some( request = Some(
requester requester
.send_request(request_builder, LimitType::Channel) .send_request(request_builder, LimitType::Channel, &mut limits)
.await, .await,
); );
} }
@ -296,11 +289,12 @@ mod rate_limit {
String::from("wss://localhost:3001/"), String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"), String::from("http://localhost:3001/cdn"),
); );
let mut limits = Limits::check_limits(urls.api.clone()).await;
let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut requester = LimitedRequester::new(urls.api.clone()).await;
let request_path = urls.api.clone() + "/policies/instance/limits"; let request_path = urls.api.clone() + "/policies/instance/limits";
let request_builder = requester.http.get(request_path); let request_builder = requester.http.get(request_path);
let request = requester let request = requester
.send_request(request_builder, LimitType::Channel) .send_request(request_builder, LimitType::Channel, &mut limits)
.await; .await;
let result = match request { let result = match request {
Ok(result) => result, Ok(result) => result,