From e8cad2a4ff1d5a2613070aa07b3abace6ac6e318 Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Tue, 11 Apr 2023 21:27:06 +0200 Subject: [PATCH] fix can_send_request --- src/limit.rs | 174 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 149 insertions(+), 25 deletions(-) diff --git a/src/limit.rs b/src/limit.rs index 412cf75..9340f69 100644 --- a/src/limit.rs +++ b/src/limit.rs @@ -1,6 +1,6 @@ use crate::api::limits::Config; -use reqwest::{Body, Client, Request, RequestBuilder}; +use reqwest::{Client, Request, RequestBuilder}; use serde_json::from_str; use std::collections::VecDeque; @@ -45,16 +45,135 @@ impl LimitedRequester { } } - /// `can_send_request` checks if a request can be sent. It returns a `bool` that indicates if - /// the request can be sent. - fn can_send_request(&mut self) -> bool { - let mut can_send = true; - for limit in self.limits.iter_mut() { - if limit.remaining == 0 { - can_send = false; + fn get_limit(&self, bucket: &str) -> Option<&Limit> { + for limit in self.limits.iter() { + if limit.bucket == bucket { + return Some(limit.to_owned()); + } + } + None + } + + /// `last_match` returns the last match of a `Vec<&str>` in a `String`. It returns a `Option<(usize, String)>` + /// that contains the index of the match and the match itself. + /// If no match is found, it returns `None`. + /// # Example + /// ```rs + /// let string = "https://discord.com/api/v8/channels/1234567890/messages"; + /// let matches = ["channels", "messages"]; + /// let index_match = last_match(string, matches.to_vec()); + /// assert_eq!(index_match, Some((43, String::from("messages")))); + /// ``` + fn last_match(string: &str, search_for: Vec<&str>) -> Option<(usize, String)> { + let mut index_match: (usize, &str) = (0, ""); + for _match in search_for.iter() { + if !string.contains(_match) { + continue; + } + // Get the index of the match + let temp_index_match = string.match_indices(_match).next().unwrap(); + // As only the last match is relevant, we only update the index_match if the index is + // higher than the previous one. + if temp_index_match.0 > index_match.0 { + index_match = temp_index_match; + } + } + if index_match.0 == 0 { + return None; + } else { + return Some((index_match.0, String::from(index_match.1))); + } + } + + /// `can_send_request` checks if a request can be sent. It returns a `bool` that indicates if + /// the request can be sent. + fn can_send_request(&mut self, request: RequestBuilder) -> bool { + // get the url from request + let global_limit = self.get_limit("global").unwrap(); + let ip_limit = self.get_limit("ip").unwrap(); + + if ip_limit.remaining == 0 || global_limit.remaining == 0 { + return false; + } + + let url_path = request + .try_clone() + .unwrap() + .build() + .unwrap() + .url() + .path() + .to_string(); + // Define the different rate limit buckets as they would appear in the URL + let matches = [ + "login", "register", "webhooks", "channels", "messages", "guilds", + ]; + let index_match_string = LimitedRequester::last_match(&url_path, matches.to_vec()) + .unwrap_or_else(|| (0, String::from(""))); + if index_match_string.0 == 0 { + return true; + } + let index_match = (index_match_string.0, index_match_string.1.as_str()); + match index_match.1 { + "login" => { + let auth_limit = self.get_limit("login").unwrap(); + if auth_limit.remaining != 0 { + return true; + } + return false; + } + "register" => { + let auth_limit = self.get_limit("auth.register").unwrap(); + let absolute_limit = self.get_limit("absoluteRate.register").unwrap(); + if auth_limit.remaining != 0 && absolute_limit.remaining != 0 { + return true; + } + return false; + } + "messages" => { + let absolute_limit = self.get_limit("absoluteRate.sendMessages").unwrap(); + let request_method = request + .try_clone() + .unwrap() + .build() + .unwrap() + .method() + .as_str() + .to_owned(); + if absolute_limit.remaining != 0 + || request_method != "POST" + || request_method != "PUT" + || request_method != "PATCH" + { + return true; + } + return false; + } + "webhooks" => { + let auth_limit = self.get_limit("webhooks").unwrap(); + if auth_limit.remaining != 0 { + return true; + } + return false; + } + "channels" => { + let auth_limit = self.get_limit("channels").unwrap(); + if auth_limit.remaining != 0 { + return true; + } + return false; + } + "guilds" => { + let auth_limit = self.get_limit("guilds").unwrap(); + if auth_limit.remaining != 0 { + return true; + } + return false; + } + &_ => { + panic!(); } } - can_send } /// `update_limits` updates the `Limit`s of the `LimitedRequester` based on the response headers @@ -73,15 +192,13 @@ impl LimitedRequester { if reset_epoch != 0 { self.last_reset_epoch = reset_epoch; } - - let httpclient = reqwest::Client::new(); } /// `send_request` sends a request to the server, if `can_send_request()` is true. /// It will then update the `Limit`s by calling `update_limits()`. /// # Example pub async fn send_request(&mut self, request: RequestBuilder) -> reqwest::Response { - if !self.can_send_request() { + if !self.can_send_request(request.try_clone().unwrap()) { panic!("429: Rate limited"); } let response = request.send().await.unwrap(); @@ -112,17 +229,24 @@ impl LimitedRequester { let mut limit_vector = Vec::new(); if !config.rate.enabled { + // The different rate limit buckets, except for the absoluteRate ones. These will be + // handled seperately. let types = [ - "rate.ip", - "rate.routes.auth.login", - "rate.routes.auth.register", + "ip", + "auth.login", + "auth.register", + "global", + "error", + "guild", + "webhook", + "channel", ]; for type_ in types.iter() { limit_vector.push(Limit { limit: u64::MAX, remaining: u64::MAX, reset: 1, - bucket: String::from(*type_), + bucket: type_.to_string(), }); } } else { @@ -130,50 +254,50 @@ impl LimitedRequester { limit: config.rate.ip.count, remaining: config.rate.ip.count, reset: config.rate.ip.window, - bucket: String::from("rate.ip"), + bucket: String::from("ip"), }); limit_vector.push(Limit { limit: config.rate.global.count, remaining: config.rate.global.count, reset: config.rate.global.window, - bucket: String::from("rate.global"), + bucket: String::from("global"), }); limit_vector.push(Limit { limit: config.rate.error.count, remaining: config.rate.error.count, reset: config.rate.error.window, - bucket: String::from("rate.error"), + bucket: String::from("error"), }); limit_vector.push(Limit { limit: config.rate.routes.guild.count, remaining: config.rate.routes.guild.count, reset: config.rate.routes.guild.window, - bucket: String::from("rate.routes.guild"), + bucket: String::from("guild"), }); limit_vector.push(Limit { limit: config.rate.routes.webhook.count, remaining: config.rate.routes.webhook.count, reset: config.rate.routes.webhook.window, - bucket: String::from("rate.routes.webhook"), + bucket: String::from("webhook"), }); limit_vector.push(Limit { limit: config.rate.routes.channel.count, remaining: config.rate.routes.channel.count, reset: config.rate.routes.channel.window, - bucket: String::from("rate.routes.channel"), + bucket: String::from("channel"), }); limit_vector.push(Limit { limit: config.rate.routes.auth.login.count, remaining: config.rate.routes.auth.login.count, reset: config.rate.routes.auth.login.window, - bucket: String::from("rate.routes.auth.login"), + bucket: String::from("auth.login"), }); limit_vector.push(Limit { limit: config.rate.routes.auth.register.count, remaining: config.rate.routes.auth.register.count, reset: config.rate.routes.auth.register.window, - bucket: String::from("rate.routes.auth.register"), + bucket: String::from("auth.register"), }); } @@ -190,7 +314,7 @@ impl LimitedRequester { limit: config.absoluteRate.sendMessage.limit, remaining: config.absoluteRate.sendMessage.limit, reset: config.absoluteRate.sendMessage.window, - bucket: String::from("absoluteRate.sendMessage"), + bucket: String::from("absoluteRate.messages"), }); }