Merge remote-tracking branch 'origin/main' into feature/gateway-observer

This commit is contained in:
bitfl0wer 2023-04-25 18:00:44 +02:00
commit 27206e4a0f
10 changed files with 734 additions and 572 deletions

View File

@ -3,7 +3,8 @@ pub mod login {
use serde_json::{from_str, json}; use serde_json::{from_str, json};
use crate::api::limits::LimitType; use crate::api::limits::LimitType;
use crate::api::schemas::schemas::{ErrorResponse, LoginResult, LoginSchema}; use crate::api::schemas::LoginSchema;
use crate::api::types::{ErrorResponse, LoginResult};
use crate::errors::InstanceServerError; use crate::errors::InstanceServerError;
use crate::instance::Instance; use crate::instance::Instance;
@ -17,10 +18,19 @@ pub mod login {
let client = Client::new(); let client = Client::new();
let endpoint_url = self.urls.get_api().to_string() + "/auth/login"; let endpoint_url = self.urls.get_api().to_string() + "/auth/login";
let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let request_builder = client.post(endpoint_url).body(json_schema.to_string());
// We do not have a user yet, and the UserRateLimits will not be affected by a login
// request (since login is an instance wide limit), which is why we are just cloning the
// instances' limits to pass them on as user_rate_limits later.
let mut cloned_limits = self.limits.clone();
let response = requester let response = requester
.send_request(request_builder, LimitType::AuthRegister) .send_request(
request_builder,
LimitType::AuthRegister,
&mut self.limits,
&mut cloned_limits,
)
.await; .await;
if !response.is_ok() { if response.is_err() {
return Err(InstanceServerError::NoResponse); return Err(InstanceServerError::NoResponse);
} }
@ -41,7 +51,7 @@ pub mod login {
let login_result: LoginResult = from_str(&response_text_string).unwrap(); let login_result: LoginResult = from_str(&response_text_string).unwrap();
return Ok(login_result); Ok(login_result)
} }
} }
} }

View File

@ -3,10 +3,7 @@ pub mod register {
use serde_json::json; use serde_json::json;
use crate::{ use crate::{
api::{ api::{limits::LimitType, schemas::RegisterSchema, types::ErrorResponse},
limits::LimitType,
schemas::schemas::{ErrorResponse, RegisterSchema},
},
errors::InstanceServerError, errors::InstanceServerError,
instance::{Instance, Token}, instance::{Instance, Token},
}; };
@ -28,10 +25,19 @@ pub mod register {
let client = Client::new(); let client = Client::new();
let endpoint_url = self.urls.get_api().to_string() + "/auth/register"; let endpoint_url = self.urls.get_api().to_string() + "/auth/register";
let request_builder = client.post(endpoint_url).body(json_schema.to_string()); let request_builder = client.post(endpoint_url).body(json_schema.to_string());
// We do not have a user yet, and the UserRateLimits will not be affected by a login
// request (since register is an instance wide limit), which is why we are just cloning
// the instances' limits to pass them on as user_rate_limits later.
let mut cloned_limits = self.limits.clone();
let response = limited_requester let response = limited_requester
.send_request(request_builder, LimitType::AuthRegister) .send_request(
request_builder,
LimitType::AuthRegister,
&mut self.limits,
&mut cloned_limits,
)
.await; .await;
if !response.is_ok() { if response.is_err() {
return Err(InstanceServerError::NoResponse); return Err(InstanceServerError::NoResponse);
} }
@ -49,16 +55,16 @@ pub mod register {
} }
return Err(InstanceServerError::InvalidFormBodyError { error_type, error }); return Err(InstanceServerError::InvalidFormBodyError { error_type, error });
} }
return Ok(Token { Ok(Token {
token: response_text_string, token: response_text_string,
}); })
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::api::schemas::schemas::{AuthEmail, AuthPassword, AuthUsername, RegisterSchema}; use crate::api::schemas::{AuthEmail, AuthPassword, AuthUsername, RegisterSchema};
use crate::errors::InstanceServerError; use crate::errors::InstanceServerError;
use crate::instance::Instance; use crate::instance::Instance;
use crate::limit::LimitedRequester; use crate::limit::LimitedRequester;
@ -70,7 +76,7 @@ mod test {
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
); );
let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; let limited_requester = LimitedRequester::new().await;
let mut test_instance = Instance::new(urls.clone(), limited_requester) let mut test_instance = Instance::new(urls.clone(), limited_requester)
.await .await
.unwrap(); .unwrap();
@ -103,7 +109,7 @@ mod test {
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
); );
let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; let limited_requester = LimitedRequester::new().await;
let mut test_instance = Instance::new(urls.clone(), limited_requester) let mut test_instance = Instance::new(urls.clone(), limited_requester)
.await .await
.unwrap(); .unwrap();
@ -111,7 +117,7 @@ mod test {
AuthUsername::new("Hiiii".to_string()).unwrap(), AuthUsername::new("Hiiii".to_string()).unwrap(),
Some(AuthPassword::new("mysupersecurepass123!".to_string()).unwrap()), Some(AuthPassword::new("mysupersecurepass123!".to_string()).unwrap()),
true, true,
Some(AuthEmail::new("flori@aaaa.xyz".to_string()).unwrap()), Some(AuthEmail::new("random978234@aaaa.xyz".to_string()).unwrap()),
None, None,
None, None,
Some("2000-01-01".to_string()), Some("2000-01-01".to_string()),

View File

@ -1,7 +1,9 @@
pub mod auth; pub mod auth;
pub mod policies; pub mod policies;
pub mod schemas; pub mod schemas;
pub mod types;
pub use policies::instance::instance::*; pub use policies::instance::instance::*;
pub use policies::instance::limits::*; pub use policies::instance::limits::*;
pub use schemas::*; pub use schemas::*;
pub use types::*;

View File

@ -3,7 +3,7 @@ pub mod instance {
use serde_json::from_str; use serde_json::from_str;
use crate::errors::InstanceServerError; use crate::errors::InstanceServerError;
use crate::{api::schemas::schemas::InstancePoliciesSchema, instance::Instance}; use crate::{api::types::InstancePolicies, instance::Instance};
impl Instance { impl Instance {
/** /**
@ -13,7 +13,7 @@ pub mod instance {
*/ */
pub async fn instance_policies_schema( pub async fn instance_policies_schema(
&self, &self,
) -> Result<InstancePoliciesSchema, InstanceServerError> { ) -> Result<InstancePolicies, InstanceServerError> {
let client = Client::new(); let client = Client::new();
let endpoint_url = self.urls.get_api().to_string() + "/policies/instance/"; let endpoint_url = self.urls.get_api().to_string() + "/policies/instance/";
let request = match client.get(&endpoint_url).send().await { let request = match client.get(&endpoint_url).send().await {
@ -26,14 +26,14 @@ pub mod instance {
} }
}; };
if request.status().as_str().chars().next().unwrap() != '2' { if !request.status().as_str().starts_with('2') {
return Err(InstanceServerError::ReceivedErrorCodeError { return Err(InstanceServerError::ReceivedErrorCodeError {
error_code: request.status().to_string(), error_code: request.status().to_string(),
}); });
} }
let body = request.text().await.unwrap(); let body = request.text().await.unwrap();
let instance_policies_schema: InstancePoliciesSchema = from_str(&body).unwrap(); let instance_policies_schema: InstancePolicies = from_str(&body).unwrap();
Ok(instance_policies_schema) Ok(instance_policies_schema)
} }
} }
@ -50,7 +50,7 @@ mod instance_policies_schema_test {
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
); );
let limited_requester = LimitedRequester::new(urls.get_api().to_string()).await; let limited_requester = LimitedRequester::new().await;
let test_instance = Instance::new(urls.clone(), limited_requester) let test_instance = Instance::new(urls.clone(), limited_requester)
.await .await
.unwrap(); .unwrap();

View File

@ -155,17 +155,81 @@ pub mod limits {
impl Limit { impl Limit {
pub fn add_remaining(&mut self, remaining: i64) { pub fn add_remaining(&mut self, remaining: i64) {
if remaining < 0 { if remaining < 0 {
if ((self.remaining as i64 + remaining) as i64) <= 0 { if (self.remaining as i64 + remaining) <= 0 {
self.remaining = 0; self.remaining = 0;
return; return;
} }
self.remaining -= remaining.abs() as u64; self.remaining -= remaining.unsigned_abs();
return; return;
} }
self.remaining += remaining.abs() as u64; self.remaining += remaining.unsigned_abs();
} }
} }
pub struct LimitsMutRef<'a> {
pub limit_absolute_messages: &'a mut Limit,
pub limit_absolute_register: &'a mut Limit,
pub limit_auth_login: &'a mut Limit,
pub limit_auth_register: &'a mut Limit,
pub limit_ip: &'a mut Limit,
pub limit_global: &'a mut Limit,
pub limit_error: &'a mut Limit,
pub limit_guild: &'a mut Limit,
pub limit_webhook: &'a mut Limit,
pub limit_channel: &'a mut Limit,
}
impl LimitsMutRef<'_> {
pub fn combine_mut_ref<'a>(
instance_rate_limits: &'a mut Limits,
user_rate_limits: &'a mut Limits,
) -> LimitsMutRef<'a> {
LimitsMutRef {
limit_absolute_messages: &mut instance_rate_limits.limit_absolute_messages,
limit_absolute_register: &mut instance_rate_limits.limit_absolute_register,
limit_auth_login: &mut instance_rate_limits.limit_auth_login,
limit_auth_register: &mut instance_rate_limits.limit_auth_register,
limit_channel: &mut user_rate_limits.limit_channel,
limit_error: &mut user_rate_limits.limit_error,
limit_global: &mut instance_rate_limits.limit_global,
limit_guild: &mut user_rate_limits.limit_guild,
limit_ip: &mut instance_rate_limits.limit_ip,
limit_webhook: &mut user_rate_limits.limit_webhook,
}
}
pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit {
match limit_type {
&LimitType::AbsoluteMessage => self.limit_absolute_messages,
&LimitType::AbsoluteRegister => self.limit_absolute_register,
&LimitType::AuthLogin => self.limit_auth_login,
&LimitType::AuthRegister => self.limit_auth_register,
&LimitType::Channel => self.limit_channel,
&LimitType::Error => self.limit_error,
&LimitType::Global => self.limit_global,
&LimitType::Guild => self.limit_guild,
&LimitType::Ip => self.limit_ip,
&LimitType::Webhook => self.limit_webhook,
}
}
pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit {
match limit_type {
&LimitType::AbsoluteMessage => self.limit_absolute_messages,
&LimitType::AbsoluteRegister => self.limit_absolute_register,
&LimitType::AuthLogin => self.limit_auth_login,
&LimitType::AuthRegister => self.limit_auth_register,
&LimitType::Channel => self.limit_channel,
&LimitType::Error => self.limit_error,
&LimitType::Global => self.limit_global,
&LimitType::Guild => self.limit_guild,
&LimitType::Ip => self.limit_ip,
&LimitType::Webhook => self.limit_webhook,
}
}
}
#[derive(Debug, Clone)]
pub struct Limits { pub struct Limits {
pub limit_absolute_messages: Limit, pub limit_absolute_messages: Limit,
pub limit_absolute_register: Limit, pub limit_absolute_register: Limit,
@ -180,19 +244,64 @@ pub mod limits {
} }
impl Limits { impl Limits {
pub fn iter(&self) -> std::vec::IntoIter<Limit> { pub fn combine(instance_rate_limits: &Limits, user_rate_limits: &Limits) -> Limits {
let mut limits: Vec<Limit> = Vec::new(); Limits {
limits.push(self.limit_absolute_messages.clone()); limit_absolute_messages: instance_rate_limits.limit_absolute_messages,
limits.push(self.limit_absolute_register.clone()); limit_absolute_register: instance_rate_limits.limit_absolute_register,
limits.push(self.limit_auth_login.clone()); limit_auth_login: instance_rate_limits.limit_auth_login,
limits.push(self.limit_auth_register.clone()); limit_auth_register: instance_rate_limits.limit_auth_register,
limits.push(self.limit_ip.clone()); limit_channel: user_rate_limits.limit_channel,
limits.push(self.limit_global.clone()); limit_error: user_rate_limits.limit_error,
limits.push(self.limit_error.clone()); limit_global: instance_rate_limits.limit_global,
limits.push(self.limit_guild.clone()); limit_guild: user_rate_limits.limit_guild,
limits.push(self.limit_webhook.clone()); limit_ip: instance_rate_limits.limit_ip,
limits.push(self.limit_channel.clone()); limit_webhook: user_rate_limits.limit_webhook,
limits.into_iter() }
}
pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit {
match limit_type {
&LimitType::AbsoluteMessage => &self.limit_absolute_messages,
&LimitType::AbsoluteRegister => &self.limit_absolute_register,
&LimitType::AuthLogin => &self.limit_auth_login,
&LimitType::AuthRegister => &self.limit_auth_register,
&LimitType::Channel => &self.limit_channel,
&LimitType::Error => &self.limit_error,
&LimitType::Global => &self.limit_global,
&LimitType::Guild => &self.limit_guild,
&LimitType::Ip => &self.limit_ip,
&LimitType::Webhook => &self.limit_webhook,
}
}
pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit {
match limit_type {
&LimitType::AbsoluteMessage => &mut self.limit_absolute_messages,
&LimitType::AbsoluteRegister => &mut self.limit_absolute_register,
&LimitType::AuthLogin => &mut self.limit_auth_login,
&LimitType::AuthRegister => &mut self.limit_auth_register,
&LimitType::Channel => &mut self.limit_channel,
&LimitType::Error => &mut self.limit_error,
&LimitType::Global => &mut self.limit_global,
&LimitType::Guild => &mut self.limit_guild,
&LimitType::Ip => &mut self.limit_ip,
&LimitType::Webhook => &mut self.limit_webhook,
}
}
pub fn to_hash_map(&self) -> HashMap<LimitType, Limit> {
let mut map: HashMap<LimitType, Limit> = HashMap::new();
map.insert(LimitType::AbsoluteMessage, self.limit_absolute_messages);
map.insert(LimitType::AbsoluteRegister, self.limit_absolute_register);
map.insert(LimitType::AuthLogin, self.limit_auth_login);
map.insert(LimitType::AuthRegister, self.limit_auth_register);
map.insert(LimitType::Ip, self.limit_ip);
map.insert(LimitType::Global, self.limit_global);
map.insert(LimitType::Error, self.limit_error);
map.insert(LimitType::Guild, self.limit_guild);
map.insert(LimitType::Webhook, self.limit_webhook);
map.insert(LimitType::Channel, self.limit_channel);
map
} }
/// check_limits uses the API to get the current request limits of the instance. /// check_limits uses the API to get the current request limits of the instance.
@ -201,7 +310,7 @@ pub mod limits {
/// # Errors /// # Errors
/// This function will panic if the request fails or if the response body cannot be parsed. /// This function will panic if the request fails or if the response body cannot be parsed.
/// TODO: Change this to return a Result and handle the errors properly. /// TODO: Change this to return a Result and handle the errors properly.
pub async fn check_limits(api_url: String) -> HashMap<LimitType, Limit> { pub async fn check_limits(api_url: String) -> Limits {
let client = Client::new(); let client = Client::new();
let url_parsed = crate::URLBundle::parse_url(api_url) + "/policies/instance/limits"; let url_parsed = crate::URLBundle::parse_url(api_url) + "/policies/instance/limits";
let result = client let result = client
@ -219,216 +328,154 @@ pub mod limits {
}); });
let config: Config = from_str(&result).unwrap(); let config: Config = from_str(&result).unwrap();
// If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX // If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX
let mut limits: HashMap<LimitType, Limit> = HashMap::new(); let mut limits: Limits;
if config.rate.enabled == false { if !config.rate.enabled {
limits.insert( limits = Limits {
LimitType::AbsoluteMessage, limit_absolute_messages: Limit {
Limit {
bucket: LimitType::AbsoluteMessage, bucket: LimitType::AbsoluteMessage,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_absolute_register: Limit {
limits.insert(
LimitType::AbsoluteRegister,
Limit {
bucket: LimitType::AbsoluteRegister, bucket: LimitType::AbsoluteRegister,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_auth_login: Limit {
limits.insert(
LimitType::AuthLogin,
Limit {
bucket: LimitType::AuthLogin, bucket: LimitType::AuthLogin,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_auth_register: Limit {
limits.insert(
LimitType::AuthRegister,
Limit {
bucket: LimitType::AuthRegister, bucket: LimitType::AuthRegister,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_ip: Limit {
limits.insert(
LimitType::Ip,
Limit {
bucket: LimitType::Ip, bucket: LimitType::Ip,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_global: Limit {
limits.insert(
LimitType::Global,
Limit {
bucket: LimitType::Global, bucket: LimitType::Global,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_error: Limit {
limits.insert(
LimitType::Error,
Limit {
bucket: LimitType::Error, bucket: LimitType::Error,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_guild: Limit {
limits.insert(
LimitType::Guild,
Limit {
bucket: LimitType::Guild, bucket: LimitType::Guild,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_webhook: Limit {
limits.insert(
LimitType::Webhook,
Limit {
bucket: LimitType::Webhook, bucket: LimitType::Webhook,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); limit_channel: Limit {
limits.insert(
LimitType::Channel,
Limit {
bucket: LimitType::Channel, bucket: LimitType::Channel,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, },
); };
} else { } else {
limits.insert( limits = Limits {
LimitType::AbsoluteMessage, limit_absolute_messages: Limit {
Limit {
bucket: LimitType::AbsoluteMessage, bucket: LimitType::AbsoluteMessage,
limit: config.absoluteRate.sendMessage.limit, limit: config.absoluteRate.sendMessage.limit,
remaining: config.absoluteRate.sendMessage.limit, remaining: config.absoluteRate.sendMessage.limit,
reset: config.absoluteRate.sendMessage.window, reset: config.absoluteRate.sendMessage.window,
}, },
); limit_absolute_register: Limit {
limits.insert(
LimitType::AbsoluteRegister,
Limit {
bucket: LimitType::AbsoluteRegister, bucket: LimitType::AbsoluteRegister,
limit: config.absoluteRate.register.limit, limit: config.absoluteRate.register.limit,
remaining: config.absoluteRate.register.limit, remaining: config.absoluteRate.register.limit,
reset: config.absoluteRate.register.window, reset: config.absoluteRate.register.window,
}, },
); limit_auth_login: Limit {
limits.insert(
LimitType::AuthLogin,
Limit {
bucket: LimitType::AuthLogin, bucket: LimitType::AuthLogin,
limit: config.rate.routes.auth.login.count, limit: config.rate.routes.auth.login.count,
remaining: config.rate.routes.auth.login.count, remaining: config.rate.routes.auth.login.count,
reset: config.rate.routes.auth.login.window, reset: config.rate.routes.auth.login.window,
}, },
); limit_auth_register: Limit {
limits.insert(
LimitType::AuthRegister,
Limit {
bucket: LimitType::AuthRegister, bucket: LimitType::AuthRegister,
limit: config.rate.routes.auth.register.count, limit: config.rate.routes.auth.register.count,
remaining: config.rate.routes.auth.register.count, remaining: config.rate.routes.auth.register.count,
reset: config.rate.routes.auth.register.window, reset: config.rate.routes.auth.register.window,
}, },
); limit_ip: Limit {
limits.insert(
LimitType::Guild,
Limit {
bucket: LimitType::Guild,
limit: config.rate.routes.guild.count,
remaining: config.rate.routes.guild.count,
reset: config.rate.routes.guild.window,
},
);
limits.insert(
LimitType::Webhook,
Limit {
bucket: LimitType::Webhook,
limit: config.rate.routes.webhook.count,
remaining: config.rate.routes.webhook.count,
reset: config.rate.routes.webhook.window,
},
);
limits.insert(
LimitType::Channel,
Limit {
bucket: LimitType::Channel,
limit: config.rate.routes.channel.count,
remaining: config.rate.routes.channel.count,
reset: config.rate.routes.channel.window,
},
);
limits.insert(
LimitType::Ip,
Limit {
bucket: LimitType::Ip, bucket: LimitType::Ip,
limit: config.rate.ip.count, limit: config.rate.ip.count,
remaining: config.rate.ip.count, remaining: config.rate.ip.count,
reset: config.rate.ip.window, reset: config.rate.ip.window,
}, },
); limit_global: Limit {
limits.insert(
LimitType::Global,
Limit {
bucket: LimitType::Global, bucket: LimitType::Global,
limit: config.rate.global.count, limit: config.rate.global.count,
remaining: config.rate.global.count, remaining: config.rate.global.count,
reset: config.rate.global.window, reset: config.rate.global.window,
}, },
); limit_error: Limit {
limits.insert(
LimitType::Error,
Limit {
bucket: LimitType::Error, bucket: LimitType::Error,
limit: config.rate.error.count, limit: config.rate.error.count,
remaining: config.rate.error.count, remaining: config.rate.error.count,
reset: config.rate.error.window, reset: config.rate.error.window,
}, },
); limit_guild: Limit {
bucket: LimitType::Guild,
limit: config.rate.routes.guild.count,
remaining: config.rate.routes.guild.count,
reset: config.rate.routes.guild.window,
},
limit_webhook: Limit {
bucket: LimitType::Webhook,
limit: config.rate.routes.webhook.count,
remaining: config.rate.routes.webhook.count,
reset: config.rate.routes.webhook.window,
},
limit_channel: Limit {
bucket: LimitType::Channel,
limit: config.rate.routes.channel.count,
remaining: config.rate.routes.channel.count,
reset: config.rate.routes.channel.window,
},
};
} }
if !config.absoluteRate.register.enabled { if !config.absoluteRate.register.enabled {
limits.insert( limits.limit_absolute_register = Limit {
LimitType::AbsoluteRegister,
Limit {
bucket: LimitType::AbsoluteRegister, bucket: LimitType::AbsoluteRegister,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, };
);
} }
if !config.absoluteRate.sendMessage.enabled { if !config.absoluteRate.sendMessage.enabled {
limits.insert( limits.limit_absolute_messages = Limit {
LimitType::AbsoluteMessage,
Limit {
bucket: LimitType::AbsoluteMessage, bucket: LimitType::AbsoluteMessage,
limit: u64::MAX, limit: u64::MAX,
remaining: u64::MAX, remaining: u64::MAX,
reset: u64::MAX, reset: u64::MAX,
}, };
);
} }
return limits; limits
} }
} }
} }
@ -446,8 +493,8 @@ mod instance_limits {
reset: 0, reset: 0,
}; };
limit.add_remaining(-2); limit.add_remaining(-2);
assert_eq!(0 as u64, limit.remaining); assert_eq!(0_u64, limit.remaining);
limit.add_remaining(-2123123); limit.add_remaining(-2123123);
assert_eq!(0 as u64, limit.remaining); assert_eq!(0_u64, limit.remaining);
} }
} }

View File

@ -1,19 +1,17 @@
pub mod schemas { use regex::Regex;
use regex::Regex; use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt};
use crate::errors::FieldFormatError; use crate::errors::FieldFormatError;
/** /**
A struct that represents a well-formed email address. A struct that represents a well-formed email address.
*/ */
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthEmail { pub struct AuthEmail {
pub email: String, pub email: String,
} }
impl AuthEmail { impl AuthEmail {
/** /**
Returns a new [`Result<AuthEmail, FieldFormatError>`]. Returns a new [`Result<AuthEmail, FieldFormatError>`].
## Arguments ## Arguments
@ -25,27 +23,27 @@ pub mod schemas {
*/ */
pub fn new(email: String) -> Result<AuthEmail, FieldFormatError> { pub fn new(email: String) -> Result<AuthEmail, FieldFormatError> {
let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap(); let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
if !regex.is_match(email.clone().as_str()) { if !regex.is_match(email.as_str()) {
return Err(FieldFormatError::EmailError); return Err(FieldFormatError::EmailError);
} }
return Ok(AuthEmail { email }); Ok(AuthEmail { email })
}
} }
}
/** /**
A struct that represents a well-formed username. A struct that represents a well-formed username.
## Arguments ## Arguments
Please use new() to create a new instance of this struct. Please use new() to create a new instance of this struct.
## Errors ## Errors
You will receive a [`FieldFormatError`], if: You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters. - The username is not between 2 and 32 characters.
*/ */
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthUsername { pub struct AuthUsername {
pub username: String, pub username: String,
} }
impl AuthUsername { impl AuthUsername {
/** /**
Returns a new [`Result<AuthUsername, FieldFormatError>`]. Returns a new [`Result<AuthUsername, FieldFormatError>`].
## Arguments ## Arguments
@ -56,27 +54,27 @@ pub mod schemas {
*/ */
pub fn new(username: String) -> Result<AuthUsername, FieldFormatError> { pub fn new(username: String) -> Result<AuthUsername, FieldFormatError> {
if username.len() < 2 || username.len() > 32 { if username.len() < 2 || username.len() > 32 {
return Err(FieldFormatError::UsernameError); Err(FieldFormatError::UsernameError)
} else { } else {
return Ok(AuthUsername { username }); Ok(AuthUsername { username })
}
} }
} }
}
/** /**
A struct that represents a well-formed password. A struct that represents a well-formed password.
## Arguments ## Arguments
Please use new() to create a new instance of this struct. Please use new() to create a new instance of this struct.
## Errors ## Errors
You will receive a [`FieldFormatError`], if: You will receive a [`FieldFormatError`], if:
- The password is not between 1 and 72 characters. - The password is not between 1 and 72 characters.
*/ */
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthPassword { pub struct AuthPassword {
pub password: String, pub password: String,
} }
impl AuthPassword { impl AuthPassword {
/** /**
Returns a new [`Result<AuthPassword, FieldFormatError>`]. Returns a new [`Result<AuthPassword, FieldFormatError>`].
## Arguments ## Arguments
@ -86,27 +84,27 @@ pub mod schemas {
- The password is not between 1 and 72 characters. - The password is not between 1 and 72 characters.
*/ */
pub fn new(password: String) -> Result<AuthPassword, FieldFormatError> { pub fn new(password: String) -> Result<AuthPassword, FieldFormatError> {
if password.len() < 1 || password.len() > 72 { if password.is_empty() || password.len() > 72 {
return Err(FieldFormatError::PasswordError); Err(FieldFormatError::PasswordError)
} else { } else {
return Ok(AuthPassword { password }); Ok(AuthPassword { password })
}
} }
} }
}
/** /**
A struct that represents a well-formed register request. A struct that represents a well-formed register request.
## Arguments ## Arguments
Please use new() to create a new instance of this struct. Please use new() to create a new instance of this struct.
## Errors ## Errors
You will receive a [`FieldFormatError`], if: You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters. - The username is not between 2 and 32 characters.
- The password is not between 1 and 72 characters. - The password is not between 1 and 72 characters.
*/ */
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct RegisterSchema { pub struct RegisterSchema {
username: String, username: String,
password: Option<String>, password: Option<String>,
consent: bool, consent: bool,
@ -117,9 +115,9 @@ pub mod schemas {
gift_code_sku_id: Option<String>, gift_code_sku_id: Option<String>,
captcha_key: Option<String>, captcha_key: Option<String>,
promotional_email_opt_in: Option<bool>, promotional_email_opt_in: Option<bool>,
} }
impl RegisterSchema { impl RegisterSchema {
/** /**
Returns a new [`Result<RegisterSchema, FieldFormatError>`]. Returns a new [`Result<RegisterSchema, FieldFormatError>`].
## Arguments ## Arguments
@ -164,7 +162,7 @@ pub mod schemas {
return Err(FieldFormatError::ConsentError); return Err(FieldFormatError::ConsentError);
} }
return Ok(RegisterSchema { Ok(RegisterSchema {
username, username,
password: has_password, password: has_password,
consent, consent,
@ -175,31 +173,31 @@ pub mod schemas {
gift_code_sku_id, gift_code_sku_id,
captcha_key, captcha_key,
promotional_email_opt_in, promotional_email_opt_in,
}); })
}
} }
}
/** /**
A struct that represents a well-formed login request. A struct that represents a well-formed login request.
## Arguments ## Arguments
Please use new() to create a new instance of this struct. Please use new() to create a new instance of this struct.
## Errors ## Errors
You will receive a [`FieldFormatError`], if: You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters. - The username is not between 2 and 32 characters.
- The password is not between 1 and 72 characters. - The password is not between 1 and 72 characters.
*/ */
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct LoginSchema { pub struct LoginSchema {
login: String, login: String,
password: String, password: String,
undelete: Option<bool>, undelete: Option<bool>,
captcha_key: Option<String>, captcha_key: Option<String>,
login_source: Option<String>, login_source: Option<String>,
gift_code_sku_id: Option<String>, gift_code_sku_id: Option<String>,
} }
impl LoginSchema { impl LoginSchema {
/** /**
Returns a new [`Result<LoginSchema, FieldFormatError>`]. Returns a new [`Result<LoginSchema, FieldFormatError>`].
## Arguments ## Arguments
@ -222,171 +220,31 @@ pub mod schemas {
gift_code_sku_id: Option<String>, gift_code_sku_id: Option<String>,
) -> Result<LoginSchema, FieldFormatError> { ) -> Result<LoginSchema, FieldFormatError> {
let login = login.username; let login = login.username;
return Ok(LoginSchema { Ok(LoginSchema {
login, login,
password, password,
undelete, undelete,
captcha_key, captcha_key,
login_source, login_source,
gift_code_sku_id, gift_code_sku_id,
}); })
}
} }
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct LoginResult { #[serde(rename_all = "snake_case")]
token: String, pub struct TotpSchema {
settings: UserSettings,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UserSettings {
afk_timeout: i32,
allow_accessibility_detection: bool,
animate_emoji: bool,
animate_stickers: i32,
contact_sync_enabled: bool,
convert_emoticons: bool,
custom_status: Option<String>,
default_guilds_restricted: bool,
detect_platform_accounts: bool,
developer_mode: bool,
disable_games_tab: bool,
enable_tts_command: bool,
explicit_content_filter: i32,
friend_source_flags: FriendSourceFlags,
friend_discovery_flags: Option<i32>,
gateway_connected: bool,
gif_auto_play: bool,
guild_folders: Vec<GuildFolder>,
guild_positions: Vec<i64>,
inline_attachment_media: bool,
inline_embed_media: bool,
locale: String,
message_display_compact: bool,
native_phone_integration_enabled: bool,
render_embeds: bool,
render_reactions: bool,
restricted_guilds: Vec<i64>,
show_current_game: bool,
status: String,
stream_notifications_enabled: bool,
theme: String,
timezone_offset: i32,
view_nsfw_guilds: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FriendSourceFlags {
all: Option<bool>,
mutual_friends: Option<bool>,
mutual_guilds: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GuildFolder {
id: String,
guild_ids: Vec<i64>,
name: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TotpSchema {
code: String, code: String,
ticket: String, ticket: String,
gift_code_sku_id: Option<String>, gift_code_sku_id: Option<String>,
login_source: Option<String>, login_source: Option<String>,
}
/**
Represents the result you get from GET: /api/instance/policies/.
*/
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct InstancePoliciesSchema {
instance_name: String,
instance_description: Option<String>,
front_page: Option<String>,
tos_page: Option<String>,
correspondence_email: Option<String>,
correspondence_user_id: Option<String>,
image: Option<String>,
instance_id: Option<String>,
}
impl InstancePoliciesSchema {
pub fn new(
instance_name: String,
instance_description: Option<String>,
front_page: Option<String>,
tos_page: Option<String>,
correspondence_email: Option<String>,
correspondence_user_id: Option<String>,
image: Option<String>,
instance_id: Option<String>,
) -> Self {
InstancePoliciesSchema {
instance_name,
instance_description,
front_page,
tos_page,
correspondence_email,
correspondence_user_id,
image,
instance_id,
}
}
}
impl fmt::Display for InstancePoliciesSchema {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"InstancePoliciesSchema {{ instance_name: {}, instance_description: {}, front_page: {}, tos_page: {}, correspondence_email: {}, correspondence_user_id: {}, image: {}, instance_id: {} }}",
self.instance_name,
self.instance_description.clone().unwrap_or("None".to_string()),
self.front_page.clone().unwrap_or("None".to_string()),
self.tos_page.clone().unwrap_or("None".to_string()),
self.correspondence_email.clone().unwrap_or("None".to_string()),
self.correspondence_user_id.clone().unwrap_or("None".to_string()),
self.image.clone().unwrap_or("None".to_string()),
self.instance_id.clone().unwrap_or("None".to_string()),
)
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
pub code: i32,
pub message: String,
pub errors: IntermittentError,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct IntermittentError {
#[serde(flatten)]
pub errors: std::collections::HashMap<String, ErrorField>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct ErrorField {
#[serde(default)]
pub _errors: Vec<Error>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Error {
pub message: String,
pub code: String,
}
} }
// I know that some of these tests are... really really basic and unneccessary, but sometimes, I // I know that some of these tests are... really really basic and unneccessary, but sometimes, I
// just feel like writing tests, so there you go :) -@bitfl0wer // just feel like writing tests, so there you go :) -@bitfl0wer
#[cfg(test)] #[cfg(test)]
mod schemas_tests { mod schemas_tests {
use super::schemas::*; use super::*;
use crate::errors::FieldFormatError; use crate::errors::FieldFormatError;
#[test] #[test]
@ -401,7 +259,7 @@ mod schemas_tests {
fn password_too_long() { fn password_too_long() {
let mut long_pw = String::new(); let mut long_pw = String::new();
for _ in 0..73 { for _ in 0..73 {
long_pw = long_pw + "a"; long_pw += "a";
} }
assert_eq!( assert_eq!(
AuthPassword::new(long_pw), AuthPassword::new(long_pw),
@ -421,7 +279,7 @@ mod schemas_tests {
fn username_too_long() { fn username_too_long() {
let mut long_un = String::new(); let mut long_un = String::new();
for _ in 0..33 { for _ in 0..33 {
long_un = long_un + "a"; long_un += "a";
} }
assert_eq!( assert_eq!(
AuthUsername::new(long_un), AuthUsername::new(long_un),

213
src/api/types.rs Normal file
View File

@ -0,0 +1,213 @@
use std::fmt;
use serde::{Deserialize, Serialize};
use crate::{api::limits::Limits, URLBundle};
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResult {
token: String,
settings: UserSettings,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UserSettings {
afk_timeout: i32,
allow_accessibility_detection: bool,
animate_emoji: bool,
animate_stickers: i32,
contact_sync_enabled: bool,
convert_emoticons: bool,
custom_status: Option<String>,
default_guilds_restricted: bool,
detect_platform_accounts: bool,
developer_mode: bool,
disable_games_tab: bool,
enable_tts_command: bool,
explicit_content_filter: i32,
friend_source_flags: FriendSourceFlags,
friend_discovery_flags: Option<i32>,
gateway_connected: bool,
gif_auto_play: bool,
guild_folders: Vec<GuildFolder>,
guild_positions: Vec<i64>,
inline_attachment_media: bool,
inline_embed_media: bool,
locale: String,
message_display_compact: bool,
native_phone_integration_enabled: bool,
render_embeds: bool,
render_reactions: bool,
restricted_guilds: Vec<i64>,
show_current_game: bool,
status: String,
stream_notifications_enabled: bool,
theme: String,
timezone_offset: i32,
view_nsfw_guilds: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FriendSourceFlags {
all: Option<bool>,
mutual_friends: Option<bool>,
mutual_guilds: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GuildFolder {
id: String,
guild_ids: Vec<i64>,
name: String,
}
/**
Represents the result you get from GET: /api/instance/policies/.
*/
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct InstancePolicies {
instance_name: String,
instance_description: Option<String>,
front_page: Option<String>,
tos_page: Option<String>,
correspondence_email: Option<String>,
correspondence_user_id: Option<String>,
image: Option<String>,
instance_id: Option<String>,
}
impl InstancePolicies {
pub fn new(
instance_name: String,
instance_description: Option<String>,
front_page: Option<String>,
tos_page: Option<String>,
correspondence_email: Option<String>,
correspondence_user_id: Option<String>,
image: Option<String>,
instance_id: Option<String>,
) -> Self {
InstancePolicies {
instance_name,
instance_description,
front_page,
tos_page,
correspondence_email,
correspondence_user_id,
image,
instance_id,
}
}
}
impl fmt::Display for InstancePolicies {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"InstancePoliciesSchema {{ instance_name: {}, instance_description: {}, front_page: {}, tos_page: {}, correspondence_email: {}, correspondence_user_id: {}, image: {}, instance_id: {} }}",
self.instance_name,
self.instance_description.clone().unwrap_or("None".to_string()),
self.front_page.clone().unwrap_or("None".to_string()),
self.tos_page.clone().unwrap_or("None".to_string()),
self.correspondence_email.clone().unwrap_or("None".to_string()),
self.correspondence_user_id.clone().unwrap_or("None".to_string()),
self.image.clone().unwrap_or("None".to_string()),
self.instance_id.clone().unwrap_or("None".to_string()),
)
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
pub code: i32,
pub message: String,
pub errors: IntermittentError,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct IntermittentError {
#[serde(flatten)]
pub errors: std::collections::HashMap<String, ErrorField>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct ErrorField {
#[serde(default)]
pub _errors: Vec<Error>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Error {
pub message: String,
pub code: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct UserObject {
id: String,
username: String,
discriminator: String,
avatar: Option<String>,
bot: Option<bool>,
system: Option<bool>,
mfa_enabled: Option<bool>,
banner: Option<bool>,
accent_color: Option<String>,
locale: String,
verified: Option<bool>,
email: Option<String>,
flags: i8,
premium_type: Option<i8>,
public_flags: Option<i8>,
}
#[derive(Debug)]
pub struct User {
logged_in: bool,
belongs_to: URLBundle,
token: String,
rate_limits: Limits,
pub settings: UserSettings,
pub object: UserObject,
}
impl User {
pub fn is_logged_in(&self) -> bool {
self.logged_in
}
pub fn belongs_to(&self) -> URLBundle {
self.belongs_to.clone()
}
pub fn token(&self) -> String {
self.token.clone()
}
pub fn set_logged_in(&mut self, bool: bool) {
self.logged_in = bool;
}
pub fn set_token(&mut self, token: String) {
self.token = token;
}
pub fn new(
logged_in: bool,
belongs_to: URLBundle,
token: String,
rate_limits: Limits,
settings: UserSettings,
object: UserObject,
) -> User {
User {
logged_in,
belongs_to,
token,
rate_limits,
settings,
object,
}
}
}

View File

@ -1,4 +1,5 @@
use crate::api::schemas::schemas::InstancePoliciesSchema; use crate::api::limits::Limits;
use crate::api::types::{InstancePolicies, User};
use crate::errors::{FieldFormatError, InstanceServerError}; use crate::errors::{FieldFormatError, InstanceServerError};
use crate::limit::LimitedRequester; use crate::limit::LimitedRequester;
use crate::URLBundle; use crate::URLBundle;
@ -12,10 +13,11 @@ The [`Instance`] what you will be using to perform all sorts of actions on the S
*/ */
pub struct Instance { pub struct Instance {
pub urls: URLBundle, pub urls: URLBundle,
pub instance_info: InstancePoliciesSchema, pub instance_info: InstancePolicies,
pub requester: LimitedRequester, pub requester: LimitedRequester,
pub limits: Limits,
//pub gateway: Gateway, //pub gateway: Gateway,
pub users: HashMap<Token, Username>, pub users: HashMap<Token, User>,
} }
impl Instance { impl Instance {
@ -29,10 +31,10 @@ impl Instance {
urls: URLBundle, urls: URLBundle,
requester: LimitedRequester, requester: LimitedRequester,
) -> Result<Instance, InstanceServerError> { ) -> Result<Instance, InstanceServerError> {
let users: HashMap<Token, Username> = HashMap::new(); let users: HashMap<Token, User> = HashMap::new();
let mut instance = Instance { let mut instance = Instance {
urls, urls: urls.clone(),
instance_info: InstancePoliciesSchema::new( instance_info: InstancePolicies::new(
// This is okay, because the instance_info will be overwritten by the instance_policies_schema() function. // This is okay, because the instance_info will be overwritten by the instance_policies_schema() function.
"".to_string(), "".to_string(),
None, None,
@ -43,6 +45,7 @@ impl Instance {
None, None,
None, None,
), ),
limits: Limits::check_limits(urls.api).await,
requester, requester,
users, users,
}; };
@ -84,6 +87,6 @@ impl Username {
if username.len() < 2 || username.len() > 32 { if username.len() < 2 || username.len() > 32 {
return Err(FieldFormatError::UsernameError); return Err(FieldFormatError::UsernameError);
} }
return Ok(Username { username }); Ok(Username { username })
} }
} }

View File

@ -49,10 +49,10 @@ impl URLBundle {
}; };
// if the last character of the string is a slash, remove it. // if the last character of the string is a slash, remove it.
let mut url_string = url.to_string(); let mut url_string = url.to_string();
if url_string.chars().last().unwrap() == '/' { if url_string.ends_with('/') {
url_string.pop(); url_string.pop();
} }
return url_string; url_string
} }
pub fn get_api(&self) -> &str { pub fn get_api(&self) -> &str {

View File

@ -1,10 +1,10 @@
use crate::{ use crate::{
api::limits::{Limit, LimitType, Limits}, api::limits::{Limit, LimitType, Limits, LimitsMutRef},
errors::InstanceServerError, errors::InstanceServerError,
}; };
use reqwest::{Client, RequestBuilder, Response}; use reqwest::{Client, RequestBuilder, Response};
use std::collections::{HashMap, VecDeque}; use std::collections::VecDeque;
// Note: There seem to be some overlapping request limiters. We need to make sure that sending a // Note: There seem to be some overlapping request limiters. We need to make sure that sending a
// request checks for all the request limiters that apply, and blocks if any of the limiters are 0 // request checks for all the request limiters that apply, and blocks if any of the limiters are 0
@ -20,7 +20,6 @@ pub struct TypedRequest {
pub struct LimitedRequester { pub struct LimitedRequester {
http: Client, http: Client,
requests: VecDeque<TypedRequest>, requests: VecDeque<TypedRequest>,
limits_rate: HashMap<LimitType, Limit>,
} }
impl LimitedRequester { impl LimitedRequester {
@ -29,11 +28,10 @@ impl LimitedRequester {
/// be send within the `Limit` of an external API Ratelimiter, and looks at the returned request /// be send within the `Limit` of an external API Ratelimiter, and looks at the returned request
/// headers to see if it can find Ratelimit info to update itself. /// headers to see if it can find Ratelimit info to update itself.
#[allow(dead_code)] #[allow(dead_code)]
pub async fn new(api_url: String) -> Self { pub async fn new() -> Self {
LimitedRequester { LimitedRequester {
http: Client::new(), http: Client::new(),
requests: VecDeque::new(), requests: VecDeque::new(),
limits_rate: Limits::check_limits(api_url).await,
} }
} }
@ -70,8 +68,10 @@ impl LimitedRequester {
&mut self, &mut self,
request: RequestBuilder, request: RequestBuilder,
limit_type: LimitType, limit_type: LimitType,
instance_rate_limits: &mut Limits,
user_rate_limits: &mut Limits,
) -> Result<Response, InstanceServerError> { ) -> Result<Response, InstanceServerError> {
if self.can_send_request(limit_type) { if self.can_send_request(limit_type, instance_rate_limits, user_rate_limits) {
let built_request = request let built_request = request
.build() .build()
.unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e)); .unwrap_or_else(|e| panic!("Error while building the Request for sending: {}", e));
@ -80,14 +80,19 @@ impl LimitedRequester {
Ok(is_response) => is_response, Ok(is_response) => is_response,
Err(e) => panic!("An error occured while processing the response: {}", e), Err(e) => panic!("An error occured while processing the response: {}", e),
}; };
self.update_limits(&response, limit_type); self.update_limits(
return Ok(response); &response,
limit_type,
instance_rate_limits,
user_rate_limits,
);
Ok(response)
} else { } else {
self.requests.push_back(TypedRequest { self.requests.push_back(TypedRequest {
request: request, request,
limit_type: limit_type, limit_type,
}); });
return Err(InstanceServerError::RateLimited); Err(InstanceServerError::RateLimited)
} }
} }
@ -102,9 +107,16 @@ impl LimitedRequester {
} }
} }
fn can_send_request(&mut self, limit_type: LimitType) -> bool { fn can_send_request(
let limits = &self.limits_rate.clone(); &mut self,
limit_type: LimitType,
instance_rate_limits: &Limits,
user_rate_limits: &Limits,
) -> bool {
// Check if all of the limits in this vec have at least one remaining request // Check if all of the limits in this vec have at least one remaining request
let rate_limits = Limits::combine(instance_rate_limits, user_rate_limits);
let constant_limits: Vec<&LimitType> = [ let constant_limits: Vec<&LimitType> = [
&LimitType::Error, &LimitType::Error,
&LimitType::Global, &LimitType::Global,
@ -113,19 +125,29 @@ impl LimitedRequester {
] ]
.to_vec(); .to_vec();
for limit in constant_limits.iter() { for limit in constant_limits.iter() {
match limits.get(&limit) { match rate_limits.to_hash_map().get(limit) {
Some(limit) => { Some(limit) => {
if limit.remaining == 0 { if limit.remaining == 0 {
return false; return false;
} }
// AbsoluteRegister and AuthRegister can cancel each other out. // AbsoluteRegister and AuthRegister can cancel each other out.
if limit.bucket == LimitType::AbsoluteRegister if limit.bucket == LimitType::AbsoluteRegister
&& limits.get(&LimitType::AuthRegister).unwrap().remaining == 0 && rate_limits
.to_hash_map()
.get(&LimitType::AuthRegister)
.unwrap()
.remaining
== 0
{ {
return false; return false;
} }
if limit.bucket == LimitType::AuthRegister if limit.bucket == LimitType::AuthRegister
&& limits.get(&LimitType::AbsoluteRegister).unwrap().remaining == 0 && rate_limits
.to_hash_map()
.get(&LimitType::AbsoluteRegister)
.unwrap()
.remaining
== 0
{ {
return false; return false;
} }
@ -133,100 +155,97 @@ impl LimitedRequester {
None => return false, None => return false,
} }
} }
return true; true
} }
fn update_limits(&mut self, response: &Response, limit_type: LimitType) { fn update_limits(
&mut self,
response: &Response,
limit_type: LimitType,
instance_rate_limits: &mut Limits,
user_rate_limits: &mut Limits,
) {
let mut rate_limits = LimitsMutRef::combine_mut_ref(instance_rate_limits, user_rate_limits);
let remaining = match response.headers().get("X-RateLimit-Remaining") { let remaining = match response.headers().get("X-RateLimit-Remaining") {
Some(remaining) => remaining.to_str().unwrap().parse::<u64>().unwrap(), Some(remaining) => remaining.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().remaining - 1, None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1,
}; };
let limit = match response.headers().get("X-RateLimit-Limit") { let limit = match response.headers().get("X-RateLimit-Limit") {
Some(limit) => limit.to_str().unwrap().parse::<u64>().unwrap(), Some(limit) => limit.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().limit, None => rate_limits.get_limit_mut_ref(&limit_type).limit,
}; };
let reset = match response.headers().get("X-RateLimit-Reset") { let reset = match response.headers().get("X-RateLimit-Reset") {
Some(reset) => reset.to_str().unwrap().parse::<u64>().unwrap(), Some(reset) => reset.to_str().unwrap().parse::<u64>().unwrap(),
None => self.limits_rate.get(&limit_type).unwrap().reset, None => rate_limits.get_limit_mut_ref(&limit_type).reset,
}; };
let status = response.status(); let status = response.status();
let status_str = status.as_str(); let status_str = status.as_str();
if status_str.chars().next().unwrap() == '4' { if status_str.starts_with('4') {
self.limits_rate rate_limits
.get_mut(&LimitType::Error) .get_limit_mut_ref(&LimitType::Error)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
} }
self.limits_rate rate_limits
.get_mut(&LimitType::Global) .get_limit_mut_ref(&LimitType::Global)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
self.limits_rate rate_limits
.get_mut(&LimitType::Ip) .get_limit_mut_ref(&LimitType::Ip)
.unwrap()
.add_remaining(-1); .add_remaining(-1);
let mut_limits_rate = &mut self.limits_rate;
match limit_type { match limit_type {
LimitType::Error => { LimitType::Error => {
let entry = mut_limits_rate.get_mut(&LimitType::Error).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Error);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Global => { LimitType::Global => {
let entry = mut_limits_rate.get_mut(&LimitType::Global).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Global);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Ip => { LimitType::Ip => {
let entry = mut_limits_rate.get_mut(&LimitType::Ip).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::AuthLogin => { LimitType::AuthLogin => {
let entry = mut_limits_rate.get_mut(&LimitType::AuthLogin).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::AbsoluteRegister => { LimitType::AbsoluteRegister => {
let entry = mut_limits_rate let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister);
.get_mut(&LimitType::AbsoluteRegister)
.unwrap();
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event // AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens. // happens.
mut_limits_rate rate_limits
.get_mut(&LimitType::AuthRegister) .get_limit_mut_ref(&LimitType::AuthRegister)
.unwrap()
.remaining -= 1; .remaining -= 1;
} }
LimitType::AuthRegister => { LimitType::AuthRegister => {
let entry = mut_limits_rate.get_mut(&LimitType::AuthRegister).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event // AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens. // happens.
mut_limits_rate rate_limits
.get_mut(&LimitType::AbsoluteRegister) .get_limit_mut_ref(&LimitType::AbsoluteRegister)
.unwrap()
.remaining -= 1; .remaining -= 1;
} }
LimitType::AbsoluteMessage => { LimitType::AbsoluteMessage => {
let entry = mut_limits_rate let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage);
.get_mut(&LimitType::AbsoluteMessage)
.unwrap();
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Channel => { LimitType::Channel => {
let entry = mut_limits_rate.get_mut(&LimitType::Channel).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Guild => { LimitType::Guild => {
let entry = mut_limits_rate.get_mut(&LimitType::Guild).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
LimitType::Webhook => { LimitType::Webhook => {
let entry = mut_limits_rate.get_mut(&LimitType::Webhook).unwrap(); let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit); LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
} }
} }
@ -247,16 +266,7 @@ mod rate_limit {
String::from("wss://localhost:3001/"), String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"), String::from("http://localhost:3001/cdn"),
); );
let requester = LimitedRequester::new(urls.api).await; let _requester = LimitedRequester::new().await;
assert_eq!(
requester.limits_rate.get(&LimitType::Ip).unwrap(),
&Limit {
bucket: LimitType::Ip,
limit: 500,
remaining: 500,
reset: 5
}
);
} }
#[tokio::test] #[tokio::test]
@ -266,16 +276,22 @@ mod rate_limit {
String::from("wss://localhost:3001/"), String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"), String::from("http://localhost:3001/cdn"),
); );
let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut requester = LimitedRequester::new().await;
let mut request: Option<Result<Response, InstanceServerError>> = None; let mut request: Option<Result<Response, InstanceServerError>> = None;
let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await;
let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await;
for _ in 0..=50 { for _ in 0..=50 {
let request_path = urls.api.clone() + "/some/random/nonexisting/path"; let request_path = urls.api.clone() + "/some/random/nonexisting/path";
let request_builder = requester.http.get(request_path); let request_builder = requester.http.get(request_path);
request = Some( request = Some(
requester requester
.send_request(request_builder, LimitType::Channel) .send_request(
request_builder,
LimitType::Channel,
&mut instance_rate_limits,
&mut user_rate_limits,
)
.await, .await,
); );
} }
@ -296,11 +312,18 @@ mod rate_limit {
String::from("wss://localhost:3001/"), String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"), String::from("http://localhost:3001/cdn"),
); );
let mut requester = LimitedRequester::new(urls.api.clone()).await; let mut instance_rate_limits = Limits::check_limits(urls.api.clone()).await;
let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await;
let mut requester = LimitedRequester::new().await;
let request_path = urls.api.clone() + "/policies/instance/limits"; let request_path = urls.api.clone() + "/policies/instance/limits";
let request_builder = requester.http.get(request_path); let request_builder = requester.http.get(request_path);
let request = requester let request = requester
.send_request(request_builder, LimitType::Channel) .send_request(
request_builder,
LimitType::Channel,
&mut instance_rate_limits,
&mut user_rate_limits,
)
.await; .await;
let result = match request { let result = match request {
Ok(result) => result, Ok(result) => result,