fix can_send_request
This commit is contained in:
parent
d9b6d77c69
commit
e8cad2a4ff
170
src/limit.rs
170
src/limit.rs
|
@ -1,6 +1,6 @@
|
||||||
use crate::api::limits::Config;
|
use crate::api::limits::Config;
|
||||||
|
|
||||||
use reqwest::{Body, Client, Request, RequestBuilder};
|
use reqwest::{Client, Request, RequestBuilder};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
|
|
||||||
|
@ -45,16 +45,135 @@ impl LimitedRequester {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_limit(&self, bucket: &str) -> Option<&Limit> {
|
||||||
|
for limit in self.limits.iter() {
|
||||||
|
if limit.bucket == bucket {
|
||||||
|
return Some(limit.to_owned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `last_match` returns the last match of a `Vec<&str>` in a `String`. It returns a `Option<(usize, String)>`
|
||||||
|
/// that contains the index of the match and the match itself.
|
||||||
|
/// If no match is found, it returns `None`.
|
||||||
|
/// # Example
|
||||||
|
/// ```rs
|
||||||
|
/// let string = "https://discord.com/api/v8/channels/1234567890/messages";
|
||||||
|
/// let matches = ["channels", "messages"];
|
||||||
|
/// let index_match = last_match(string, matches.to_vec());
|
||||||
|
/// assert_eq!(index_match, Some((43, String::from("messages"))));
|
||||||
|
/// ```
|
||||||
|
fn last_match(string: &str, search_for: Vec<&str>) -> Option<(usize, String)> {
|
||||||
|
let mut index_match: (usize, &str) = (0, "");
|
||||||
|
for _match in search_for.iter() {
|
||||||
|
if !string.contains(_match) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Get the index of the match
|
||||||
|
let temp_index_match = string.match_indices(_match).next().unwrap();
|
||||||
|
// As only the last match is relevant, we only update the index_match if the index is
|
||||||
|
// higher than the previous one.
|
||||||
|
if temp_index_match.0 > index_match.0 {
|
||||||
|
index_match = temp_index_match;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if index_match.0 == 0 {
|
||||||
|
return None;
|
||||||
|
} else {
|
||||||
|
return Some((index_match.0, String::from(index_match.1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// `can_send_request` checks if a request can be sent. It returns a `bool` that indicates if
|
/// `can_send_request` checks if a request can be sent. It returns a `bool` that indicates if
|
||||||
/// the request can be sent.
|
/// the request can be sent.
|
||||||
fn can_send_request(&mut self) -> bool {
|
fn can_send_request(&mut self, request: RequestBuilder) -> bool {
|
||||||
let mut can_send = true;
|
// get the url from request
|
||||||
for limit in self.limits.iter_mut() {
|
let global_limit = self.get_limit("global").unwrap();
|
||||||
if limit.remaining == 0 {
|
let ip_limit = self.get_limit("ip").unwrap();
|
||||||
can_send = false;
|
|
||||||
|
if ip_limit.remaining == 0 || global_limit.remaining == 0 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let url_path = request
|
||||||
|
.try_clone()
|
||||||
|
.unwrap()
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.url()
|
||||||
|
.path()
|
||||||
|
.to_string();
|
||||||
|
// Define the different rate limit buckets as they would appear in the URL
|
||||||
|
let matches = [
|
||||||
|
"login", "register", "webhooks", "channels", "messages", "guilds",
|
||||||
|
];
|
||||||
|
let index_match_string = LimitedRequester::last_match(&url_path, matches.to_vec())
|
||||||
|
.unwrap_or_else(|| (0, String::from("")));
|
||||||
|
if index_match_string.0 == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
let index_match = (index_match_string.0, index_match_string.1.as_str());
|
||||||
|
match index_match.1 {
|
||||||
|
"login" => {
|
||||||
|
let auth_limit = self.get_limit("login").unwrap();
|
||||||
|
if auth_limit.remaining != 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
"register" => {
|
||||||
|
let auth_limit = self.get_limit("auth.register").unwrap();
|
||||||
|
let absolute_limit = self.get_limit("absoluteRate.register").unwrap();
|
||||||
|
if auth_limit.remaining != 0 && absolute_limit.remaining != 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
"messages" => {
|
||||||
|
let absolute_limit = self.get_limit("absoluteRate.sendMessages").unwrap();
|
||||||
|
let request_method = request
|
||||||
|
.try_clone()
|
||||||
|
.unwrap()
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.method()
|
||||||
|
.as_str()
|
||||||
|
.to_owned();
|
||||||
|
if absolute_limit.remaining != 0
|
||||||
|
|| request_method != "POST"
|
||||||
|
|| request_method != "PUT"
|
||||||
|
|| request_method != "PATCH"
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
"webhooks" => {
|
||||||
|
let auth_limit = self.get_limit("webhooks").unwrap();
|
||||||
|
if auth_limit.remaining != 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
"channels" => {
|
||||||
|
let auth_limit = self.get_limit("channels").unwrap();
|
||||||
|
if auth_limit.remaining != 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
"guilds" => {
|
||||||
|
let auth_limit = self.get_limit("guilds").unwrap();
|
||||||
|
if auth_limit.remaining != 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
&_ => {
|
||||||
|
panic!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
can_send
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `update_limits` updates the `Limit`s of the `LimitedRequester` based on the response headers
|
/// `update_limits` updates the `Limit`s of the `LimitedRequester` based on the response headers
|
||||||
|
@ -73,15 +192,13 @@ impl LimitedRequester {
|
||||||
if reset_epoch != 0 {
|
if reset_epoch != 0 {
|
||||||
self.last_reset_epoch = reset_epoch;
|
self.last_reset_epoch = reset_epoch;
|
||||||
}
|
}
|
||||||
|
|
||||||
let httpclient = reqwest::Client::new();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `send_request` sends a request to the server, if `can_send_request()` is true.
|
/// `send_request` sends a request to the server, if `can_send_request()` is true.
|
||||||
/// It will then update the `Limit`s by calling `update_limits()`.
|
/// It will then update the `Limit`s by calling `update_limits()`.
|
||||||
/// # Example
|
/// # Example
|
||||||
pub async fn send_request(&mut self, request: RequestBuilder) -> reqwest::Response {
|
pub async fn send_request(&mut self, request: RequestBuilder) -> reqwest::Response {
|
||||||
if !self.can_send_request() {
|
if !self.can_send_request(request.try_clone().unwrap()) {
|
||||||
panic!("429: Rate limited");
|
panic!("429: Rate limited");
|
||||||
}
|
}
|
||||||
let response = request.send().await.unwrap();
|
let response = request.send().await.unwrap();
|
||||||
|
@ -112,17 +229,24 @@ impl LimitedRequester {
|
||||||
|
|
||||||
let mut limit_vector = Vec::new();
|
let mut limit_vector = Vec::new();
|
||||||
if !config.rate.enabled {
|
if !config.rate.enabled {
|
||||||
|
// The different rate limit buckets, except for the absoluteRate ones. These will be
|
||||||
|
// handled seperately.
|
||||||
let types = [
|
let types = [
|
||||||
"rate.ip",
|
"ip",
|
||||||
"rate.routes.auth.login",
|
"auth.login",
|
||||||
"rate.routes.auth.register",
|
"auth.register",
|
||||||
|
"global",
|
||||||
|
"error",
|
||||||
|
"guild",
|
||||||
|
"webhook",
|
||||||
|
"channel",
|
||||||
];
|
];
|
||||||
for type_ in types.iter() {
|
for type_ in types.iter() {
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: u64::MAX,
|
limit: u64::MAX,
|
||||||
remaining: u64::MAX,
|
remaining: u64::MAX,
|
||||||
reset: 1,
|
reset: 1,
|
||||||
bucket: String::from(*type_),
|
bucket: type_.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -130,50 +254,50 @@ impl LimitedRequester {
|
||||||
limit: config.rate.ip.count,
|
limit: config.rate.ip.count,
|
||||||
remaining: config.rate.ip.count,
|
remaining: config.rate.ip.count,
|
||||||
reset: config.rate.ip.window,
|
reset: config.rate.ip.window,
|
||||||
bucket: String::from("rate.ip"),
|
bucket: String::from("ip"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.global.count,
|
limit: config.rate.global.count,
|
||||||
remaining: config.rate.global.count,
|
remaining: config.rate.global.count,
|
||||||
reset: config.rate.global.window,
|
reset: config.rate.global.window,
|
||||||
bucket: String::from("rate.global"),
|
bucket: String::from("global"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.error.count,
|
limit: config.rate.error.count,
|
||||||
remaining: config.rate.error.count,
|
remaining: config.rate.error.count,
|
||||||
reset: config.rate.error.window,
|
reset: config.rate.error.window,
|
||||||
bucket: String::from("rate.error"),
|
bucket: String::from("error"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.routes.guild.count,
|
limit: config.rate.routes.guild.count,
|
||||||
remaining: config.rate.routes.guild.count,
|
remaining: config.rate.routes.guild.count,
|
||||||
reset: config.rate.routes.guild.window,
|
reset: config.rate.routes.guild.window,
|
||||||
bucket: String::from("rate.routes.guild"),
|
bucket: String::from("guild"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.routes.webhook.count,
|
limit: config.rate.routes.webhook.count,
|
||||||
remaining: config.rate.routes.webhook.count,
|
remaining: config.rate.routes.webhook.count,
|
||||||
reset: config.rate.routes.webhook.window,
|
reset: config.rate.routes.webhook.window,
|
||||||
bucket: String::from("rate.routes.webhook"),
|
bucket: String::from("webhook"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.routes.channel.count,
|
limit: config.rate.routes.channel.count,
|
||||||
remaining: config.rate.routes.channel.count,
|
remaining: config.rate.routes.channel.count,
|
||||||
reset: config.rate.routes.channel.window,
|
reset: config.rate.routes.channel.window,
|
||||||
bucket: String::from("rate.routes.channel"),
|
bucket: String::from("channel"),
|
||||||
});
|
});
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.routes.auth.login.count,
|
limit: config.rate.routes.auth.login.count,
|
||||||
remaining: config.rate.routes.auth.login.count,
|
remaining: config.rate.routes.auth.login.count,
|
||||||
reset: config.rate.routes.auth.login.window,
|
reset: config.rate.routes.auth.login.window,
|
||||||
bucket: String::from("rate.routes.auth.login"),
|
bucket: String::from("auth.login"),
|
||||||
});
|
});
|
||||||
|
|
||||||
limit_vector.push(Limit {
|
limit_vector.push(Limit {
|
||||||
limit: config.rate.routes.auth.register.count,
|
limit: config.rate.routes.auth.register.count,
|
||||||
remaining: config.rate.routes.auth.register.count,
|
remaining: config.rate.routes.auth.register.count,
|
||||||
reset: config.rate.routes.auth.register.window,
|
reset: config.rate.routes.auth.register.window,
|
||||||
bucket: String::from("rate.routes.auth.register"),
|
bucket: String::from("auth.register"),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,7 +314,7 @@ impl LimitedRequester {
|
||||||
limit: config.absoluteRate.sendMessage.limit,
|
limit: config.absoluteRate.sendMessage.limit,
|
||||||
remaining: config.absoluteRate.sendMessage.limit,
|
remaining: config.absoluteRate.sendMessage.limit,
|
||||||
reset: config.absoluteRate.sendMessage.window,
|
reset: config.absoluteRate.sendMessage.window,
|
||||||
bucket: String::from("absoluteRate.sendMessage"),
|
bucket: String::from("absoluteRate.messages"),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue