Merge branch 'main' into perpetual/gateway-dev

This commit is contained in:
kozabrada123 2023-07-11 18:35:12 +02:00
commit 7de62f0152
35 changed files with 1258 additions and 1924 deletions

View File

@ -10,7 +10,7 @@ backend = ["poem", "sqlx"]
client = []
[dependencies]
tokio = {version = "1.28.1", features = ["rt", "macros", "rt-multi-thread", "full"]}
tokio = {version = "1.28.1"}
serde = {version = "1.0.163", features = ["derive"]}
serde_json = {version= "1.0.96", features = ["raw_value"]}
serde-aux = "4.2.0"
@ -34,7 +34,9 @@ poem = { version = "1.3.55", optional = true }
sqlx = { git = "https://github.com/zert3x/sqlx", branch="feature/skip", features = ["mysql", "sqlite", "json", "chrono", "ipnetwork", "runtime-tokio-native-tls", "any"], optional = true }
thiserror = "1.0.40"
jsonwebtoken = "8.3.0"
log = "0.4.19"
[dev-dependencies]
tokio = {version = "1.28.1", features = ["full"]}
lazy_static = "1.4.0"
rusty-hook = "0.11.2"
rusty-hook = "0.11.2"

View File

@ -2,64 +2,41 @@ use std::cell::RefCell;
use std::rc::Rc;
use reqwest::Client;
use serde_json::{from_str, json};
use serde_json::to_string;
use crate::api::limits::LimitType;
use crate::errors::{ChorusLibError, ChorusResult};
use crate::api::LimitType;
use crate::errors::ChorusResult;
use crate::instance::{Instance, UserMeta};
use crate::limit::LimitedRequester;
use crate::types::{ErrorResponse, LoginResult, LoginSchema};
use crate::ratelimiter::ChorusRequest;
use crate::types::{LoginResult, LoginSchema};
impl Instance {
pub async fn login_account(&mut self, login_schema: &LoginSchema) -> ChorusResult<UserMeta> {
let json_schema = json!(login_schema);
let client = Client::new();
let endpoint_url = self.urls.api.clone() + "/auth/login";
let request_builder = client.post(endpoint_url).body(json_schema.to_string());
let chorus_request = ChorusRequest {
request: Client::new()
.post(endpoint_url)
.body(to_string(login_schema).unwrap()),
limit_type: LimitType::AuthLogin,
};
// We do not have a user yet, and the UserRateLimits will not be affected by a login
// request (since login is an instance wide limit), which is why we are just cloning the
// instances' limits to pass them on as user_rate_limits later.
let mut cloned_limits = self.limits.clone();
let response = LimitedRequester::send_request(
request_builder,
LimitType::AuthRegister,
self,
&mut cloned_limits,
)
.await;
if response.is_err() {
return Err(ChorusLibError::NoResponse);
let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string());
let login_result = chorus_request
.deserialize_response::<LoginResult>(&mut shell)
.await?;
let object = self.get_user(login_result.token.clone(), None).await?;
if self.limits_information.is_some() {
self.limits_information.as_mut().unwrap().ratelimits = shell.limits.clone().unwrap();
}
let response_unwrap = response.unwrap();
let status = response_unwrap.status();
let response_text_string = response_unwrap.text().await.unwrap();
if status.is_client_error() {
let json: ErrorResponse = serde_json::from_str(&response_text_string).unwrap();
let error_type = json.errors.errors.iter().next().unwrap().0.to_owned();
let mut error = "".to_string();
for (_, value) in json.errors.errors.iter() {
for error_item in value._errors.iter() {
error += &(error_item.message.to_string() + " (" + &error_item.code + ")");
}
}
return Err(ChorusLibError::InvalidFormBodyError { error_type, error });
}
let cloned_limits = self.limits.clone();
let login_result: LoginResult = from_str(&response_text_string).unwrap();
let object = self
.get_user(login_result.token.clone(), None)
.await
.unwrap();
let user = UserMeta::new(
Rc::new(RefCell::new(self.clone())),
login_result.token,
cloned_limits,
self.clone_limits_if_some(),
login_result.settings,
object,
);
Ok(user)
}
}

View File

@ -1,14 +1,14 @@
use std::{cell::RefCell, rc::Rc};
use reqwest::Client;
use serde_json::{from_str, json};
use serde_json::to_string;
use crate::{
api::limits::LimitType,
errors::{ChorusLibError, ChorusResult},
api::policies::instance::LimitType,
errors::ChorusResult,
instance::{Instance, Token, UserMeta},
limit::LimitedRequester,
types::{ErrorResponse, RegisterSchema},
ratelimiter::ChorusRequest,
types::RegisterSchema,
};
impl Instance {
@ -25,49 +25,30 @@ impl Instance {
&mut self,
register_schema: &RegisterSchema,
) -> ChorusResult<UserMeta> {
let json_schema = json!(register_schema);
let client = Client::new();
let endpoint_url = self.urls.api.clone() + "/auth/register";
let request_builder = client.post(endpoint_url).body(json_schema.to_string());
let chorus_request = ChorusRequest {
request: Client::new()
.post(endpoint_url)
.body(to_string(register_schema).unwrap()),
limit_type: LimitType::AuthRegister,
};
// We do not have a user yet, and the UserRateLimits will not be affected by a login
// request (since register is an instance wide limit), which is why we are just cloning
// the instances' limits to pass them on as user_rate_limits later.
let mut cloned_limits = self.limits.clone();
let response = LimitedRequester::send_request(
request_builder,
LimitType::AuthRegister,
self,
&mut cloned_limits,
)
.await;
if response.is_err() {
return Err(ChorusLibError::NoResponse);
}
let response_unwrap = response.unwrap();
let status = response_unwrap.status();
let response_unwrap_text = response_unwrap.text().await.unwrap();
let token = from_str::<Token>(&response_unwrap_text).unwrap();
let token = token.token;
if status.is_client_error() {
let json: ErrorResponse = serde_json::from_str(&token).unwrap();
let error_type = json.errors.errors.iter().next().unwrap().0.to_owned();
let mut error = "".to_string();
for (_, value) in json.errors.errors.iter() {
for error_item in value._errors.iter() {
error += &(error_item.message.to_string() + " (" + &error_item.code + ")");
}
}
return Err(ChorusLibError::InvalidFormBodyError { error_type, error });
let mut shell = UserMeta::shell(Rc::new(RefCell::new(self.clone())), "None".to_string());
let token = chorus_request
.deserialize_response::<Token>(&mut shell)
.await?
.token;
if self.limits_information.is_some() {
self.limits_information.as_mut().unwrap().ratelimits = shell.limits.unwrap();
}
let user_object = self.get_user(token.clone(), None).await.unwrap();
let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self)
.await
.unwrap();
let settings = UserMeta::get_settings(&token, &self.urls.api.clone(), self).await?;
let user = UserMeta::new(
Rc::new(RefCell::new(self.clone())),
token.clone(),
cloned_limits,
self.clone_limits_if_some(),
settings,
user_object,
);

View File

@ -2,32 +2,23 @@ use reqwest::Client;
use serde_json::to_string;
use crate::{
api::common,
errors::{ChorusLibError, ChorusResult},
api::LimitType,
errors::{ChorusError, ChorusResult},
instance::UserMeta,
ratelimiter::ChorusRequest,
types::{Channel, ChannelModifySchema, GetChannelMessagesSchema, Message, Snowflake},
};
impl Channel {
pub async fn get(user: &mut UserMeta, channel_id: Snowflake) -> ChorusResult<Channel> {
let url = user.belongs_to.borrow_mut().urls.api.clone();
let request = Client::new()
.get(format!("{}/channels/{}/", url, channel_id))
.bearer_auth(user.token());
let result = common::deserialize_response::<Channel>(
request,
user,
crate::api::limits::LimitType::Channel,
)
.await;
if result.is_err() {
return Err(ChorusLibError::RequestErrorError {
url: format!("{}/channels/{}/", url, channel_id),
error: result.err().unwrap().to_string(),
});
}
Ok(result.unwrap())
let url = user.belongs_to.borrow().urls.api.clone();
let chorus_request = ChorusRequest {
request: Client::new()
.get(format!("{}/channels/{}/", url, channel_id))
.bearer_auth(user.token()),
limit_type: LimitType::Channel(channel_id),
};
chorus_request.deserialize_response::<Channel>(user).await
}
/// Deletes a channel.
@ -44,15 +35,17 @@ impl Channel {
///
/// A `Result` that contains a `ChorusLibError` if an error occurred during the request, or `()` if the request was successful.
pub async fn delete(self, user: &mut UserMeta) -> ChorusResult<()> {
let request = Client::new()
.delete(format!(
"{}/channels/{}/",
user.belongs_to.borrow_mut().urls.api,
self.id
))
.bearer_auth(user.token());
common::handle_request_as_result(request, user, crate::api::limits::LimitType::Channel)
.await
let chorus_request = ChorusRequest {
request: Client::new()
.delete(format!(
"{}/channels/{}/",
user.belongs_to.borrow().urls.api,
self.id
))
.bearer_auth(user.token()),
limit_type: LimitType::Channel(self.id),
};
chorus_request.handle_request_as_result(user).await
}
/// Modifies a channel.
@ -75,20 +68,18 @@ impl Channel {
channel_id: Snowflake,
user: &mut UserMeta,
) -> ChorusResult<()> {
let request = Client::new()
.patch(format!(
"{}/channels/{}/",
user.belongs_to.borrow().urls.api,
channel_id
))
.bearer_auth(user.token())
.body(to_string(&modify_data).unwrap());
let new_channel = common::deserialize_response::<Channel>(
request,
user,
crate::api::limits::LimitType::Channel,
)
.await?;
let chorus_request = ChorusRequest {
request: Client::new()
.patch(format!(
"{}/channels/{}/",
user.belongs_to.borrow().urls.api,
channel_id
))
.bearer_auth(user.token())
.body(to_string(&modify_data).unwrap()),
limit_type: LimitType::Channel(channel_id),
};
let new_channel = chorus_request.deserialize_response::<Channel>(user).await?;
let _ = std::mem::replace(self, new_channel);
Ok(())
}
@ -97,16 +88,21 @@ impl Channel {
range: GetChannelMessagesSchema,
channel_id: Snowflake,
user: &mut UserMeta,
) -> Result<Vec<Message>, ChorusLibError> {
let request = Client::new()
.get(format!(
"{}/channels/{}/messages",
user.belongs_to.borrow().urls.api,
channel_id
))
.bearer_auth(user.token())
.query(&range);
) -> Result<Vec<Message>, ChorusError> {
let chorus_request = ChorusRequest {
request: Client::new()
.get(format!(
"{}/channels/{}/messages",
user.belongs_to.borrow().urls.api,
channel_id
))
.bearer_auth(user.token())
.query(&range),
limit_type: Default::default(),
};
common::deserialize_response::<Vec<Message>>(request, user, Default::default()).await
chorus_request
.deserialize_response::<Vec<Message>>(user)
.await
}
}

View File

@ -3,8 +3,9 @@ use http::HeaderMap;
use reqwest::{multipart, Client};
use serde_json::to_string;
use crate::api::deserialize_response;
use crate::api::LimitType;
use crate::instance::UserMeta;
use crate::ratelimiter::ChorusRequest;
use crate::types::{Message, MessageSendSchema, PartialDiscordFileAttachment, Snowflake};
impl Message {
@ -24,16 +25,18 @@ impl Message {
channel_id: Snowflake,
message: &mut MessageSendSchema,
files: Option<Vec<PartialDiscordFileAttachment>>,
) -> Result<Message, crate::errors::ChorusLibError> {
) -> Result<Message, crate::errors::ChorusError> {
let url_api = user.belongs_to.borrow().urls.api.clone();
if files.is_none() {
let request = Client::new()
.post(format!("{}/channels/{}/messages/", url_api, channel_id))
.bearer_auth(user.token())
.body(to_string(message).unwrap());
deserialize_response::<Message>(request, user, crate::api::limits::LimitType::Channel)
.await
let chorus_request = ChorusRequest {
request: Client::new()
.post(format!("{}/channels/{}/messages/", url_api, channel_id))
.bearer_auth(user.token())
.body(to_string(message).unwrap()),
limit_type: LimitType::Channel(channel_id),
};
chorus_request.deserialize_response::<Message>(user).await
} else {
for (index, attachment) in message.attachments.iter_mut().enumerate() {
attachment.get_mut(index).unwrap().set_id(index as i16);
@ -62,13 +65,14 @@ impl Message {
form = form.part(part_name, part);
}
let request = Client::new()
.post(format!("{}/channels/{}/messages/", url_api, channel_id))
.bearer_auth(user.token())
.multipart(form);
deserialize_response::<Message>(request, user, crate::api::limits::LimitType::Channel)
.await
let chorus_request = ChorusRequest {
request: Client::new()
.post(format!("{}/channels/{}/messages/", url_api, channel_id))
.bearer_auth(user.token())
.multipart(form),
limit_type: LimitType::Channel(channel_id),
};
chorus_request.deserialize_response::<Message>(user).await
}
}
}
@ -91,7 +95,7 @@ impl UserMeta {
message: &mut MessageSendSchema,
channel_id: Snowflake,
files: Option<Vec<PartialDiscordFileAttachment>>,
) -> Result<Message, crate::errors::ChorusLibError> {
) -> Result<Message, crate::errors::ChorusError> {
Message::send(self, channel_id, message, files).await
}
}

View File

@ -2,9 +2,10 @@ use reqwest::Client;
use serde_json::to_string;
use crate::{
api::handle_request_as_result,
errors::{ChorusLibError, ChorusResult},
api::LimitType,
errors::{ChorusError, ChorusResult},
instance::UserMeta,
ratelimiter::ChorusRequest,
types::{self, PermissionOverwrite, Snowflake},
};
@ -25,24 +26,25 @@ impl types::Channel {
channel_id: Snowflake,
overwrite: PermissionOverwrite,
) -> ChorusResult<()> {
let url = {
format!(
"{}/channels/{}/permissions/{}",
user.belongs_to.borrow_mut().urls.api,
channel_id,
overwrite.id
)
};
let url = format!(
"{}/channels/{}/permissions/{}",
user.belongs_to.borrow_mut().urls.api,
channel_id,
overwrite.id
);
let body = match to_string(&overwrite) {
Ok(string) => string,
Err(e) => {
return Err(ChorusLibError::FormCreationError {
return Err(ChorusError::FormCreation {
error: e.to_string(),
});
}
};
let request = Client::new().put(url).bearer_auth(user.token()).body(body);
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().put(url).bearer_auth(user.token()).body(body),
limit_type: LimitType::Channel(channel_id),
};
chorus_request.handle_request_as_result(user).await
}
/// Deletes a permission overwrite for a channel.
@ -67,7 +69,10 @@ impl types::Channel {
channel_id,
overwrite_id
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(channel_id),
};
chorus_request.handle_request_as_result(user).await
}
}

View File

@ -1,10 +1,11 @@
use reqwest::Client;
use crate::{
api::handle_request_as_result,
api::LimitType,
errors::ChorusResult,
instance::UserMeta,
types::{self, Snowflake},
ratelimiter::ChorusRequest,
types::{self, PublicUser, Snowflake},
};
/**
@ -16,20 +17,15 @@ pub struct ReactionMeta {
}
impl ReactionMeta {
/**
Deletes all reactions for a message.
This endpoint requires the `MANAGE_MESSAGES` permission to be present on the current user.
# Arguments
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A `Result` [`()`] [`crate::errors::ChorusLibError`] if something went wrong.
Fires a `Message Reaction Remove All` Gateway event.
# Reference
See [https://discord.com/developers/docs/resources/channel#delete-all-reactions](https://discord.com/developers/docs/resources/channel#delete-all-reactions)
*/
/// Deletes all reactions for a message.
/// This endpoint requires the `MANAGE_MESSAGES` permission to be present on the current user.
/// # Arguments
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A `Result` [`()`] [`crate::errors::ChorusLibError`] if something went wrong.
/// Fires a `Message Reaction Remove All` Gateway event.
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#delete-all-reactions](https://discord.com/developers/docs/resources/channel#delete-all-reactions)
pub async fn delete_all(&self, user: &mut UserMeta) -> ChorusResult<()> {
let url = format!(
"{}/channels/{}/messages/{}/reactions/",
@ -37,26 +33,24 @@ impl ReactionMeta {
self.channel_id,
self.message_id
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request.handle_request_as_result(user).await
}
/**
Gets a list of users that reacted with a specific emoji to a message.
# Arguments
* `emoji` - A string slice containing the emoji to search for. The emoji must be URL Encoded or
the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
format name:id with the emoji name and emoji id.
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong.
# Reference
See [https://discord.com/developers/docs/resources/channel#get-reactions](https://discord.com/developers/docs/resources/channel#get-reactions)
*/
pub async fn get(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> {
/// Gets a list of users that reacted with a specific emoji to a message.
/// # Arguments
/// * `emoji` - A string slice containing the emoji to search for. The emoji must be URL Encoded or
/// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
/// format name:id with the emoji name and emoji id.
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong.
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#get-reactions](https://discord.com/developers/docs/resources/channel#get-reactions)
pub async fn get(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<Vec<PublicUser>> {
let url = format!(
"{}/channels/{}/messages/{}/reactions/{}/",
user.belongs_to.borrow().urls.api,
@ -64,27 +58,27 @@ impl ReactionMeta {
self.message_id,
emoji
);
let request = Client::new().get(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request
.deserialize_response::<Vec<PublicUser>>(user)
.await
}
/**
Deletes all the reactions for a given `emoji` on a message. This endpoint requires the
MANAGE_MESSAGES permission to be present on the current user.
# Arguments
* `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
format name:id with the emoji name and emoji id.
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong.
Fires a `Message Reaction Remove Emoji` Gateway event.
# Reference
See [https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji](https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji)
*/
/// Deletes all the reactions for a given `emoji` on a message. This endpoint requires the
/// MANAGE_MESSAGES permission to be present on the current user.
/// # Arguments
/// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
/// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
/// format name:id with the emoji name and emoji id.
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A Result that is [`Err(crate::errors::ChorusLibError)`] if something went wrong.
/// Fires a `Message Reaction Remove Emoji` Gateway event.
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji](https://discord.com/developers/docs/resources/channel#delete-all-reactions-for-emoji)
pub async fn delete_emoji(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> {
let url = format!(
"{}/channels/{}/messages/{}/reactions/{}/",
@ -93,29 +87,28 @@ impl ReactionMeta {
self.message_id,
emoji
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request.handle_request_as_result(user).await
}
/**
Create a reaction for the message.
This endpoint requires the READ_MESSAGE_HISTORY permission
to be present on the current user. Additionally, if nobody else has reacted to the message using
this emoji, this endpoint requires the ADD_REACTIONS permission to be present on the current
user.
# Arguments
* `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
format name:id with the emoji name and emoji id.
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
# Reference
See [https://discord.com/developers/docs/resources/channel#create-reaction](https://discord.com/developers/docs/resources/channel#create-reaction)
*/
/// Create a reaction for the message.
/// This endpoint requires the READ_MESSAGE_HISTORY permission
/// to be present on the current user. Additionally, if nobody else has reacted to the message using
/// this emoji, this endpoint requires the ADD_REACTIONS permission to be present on the current
/// user.
/// # Arguments
/// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
/// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
/// format name:id with the emoji name and emoji id.
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#create-reaction](https://discord.com/developers/docs/resources/channel#create-reaction)
///
pub async fn create(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> {
let url = format!(
"{}/channels/{}/messages/{}/reactions/{}/@me/",
@ -124,26 +117,24 @@ impl ReactionMeta {
self.message_id,
emoji
);
let request = Client::new().put(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().put(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request.handle_request_as_result(user).await
}
/**
Delete a reaction the current user has made for the message.
# Arguments
* `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
format name:id with the emoji name and emoji id.
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
Fires a `Message Reaction Remove` Gateway event.
# Reference
See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction)
*/
/// Delete a reaction the current user has made for the message.
/// # Arguments
/// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
/// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
/// format name:id with the emoji name and emoji id.
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
/// Fires a `Message Reaction Remove` Gateway event.
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction)
pub async fn remove(&self, emoji: &str, user: &mut UserMeta) -> ChorusResult<()> {
let url = format!(
"{}/channels/{}/messages/{}/reactions/{}/@me/",
@ -152,29 +143,26 @@ impl ReactionMeta {
self.message_id,
emoji
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request.handle_request_as_result(user).await
}
/**
Delete a user's reaction to a message.
This endpoint requires the MANAGE_MESSAGES permission to be present on the current user.
# Arguments
* `user_id` - ID of the user whose reaction is to be deleted.
* `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
format name:id with the emoji name and emoji id.
* `user` - A mutable reference to a [`UserMeta`] instance.
# Returns
A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
Fires a Message Reaction Remove Gateway event.
# Reference
See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction)
*/
/// Delete a user's reaction to a message.
/// This endpoint requires the MANAGE_MESSAGES permission to be present on the current user.
/// # Arguments
/// * `user_id` - ID of the user whose reaction is to be deleted.
/// * `emoji` - A string slice containing the emoji to delete. The `emoji` must be URL Encoded or
/// the request will fail with 10014: Unknown Emoji. To use custom emoji, you must encode it in the
/// format name:id with the emoji name and emoji id.
/// * `user` - A mutable reference to a [`UserMeta`] instance.
/// # Returns
/// A `Result` containing [`()`] or a [`crate::errors::ChorusLibError`].
/// Fires a Message Reaction Remove Gateway event.
/// # Reference
/// See [https://discord.com/developers/docs/resources/channel#delete-own-reaction](https://discord.com/developers/docs/resources/channel#delete-own-reaction)
pub async fn delete_user(
&self,
user_id: Snowflake,
@ -189,7 +177,10 @@ impl ReactionMeta {
emoji,
user_id
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Channel).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Channel(self.channel_id),
};
chorus_request.handle_request_as_result(user).await
}
}

View File

@ -1,72 +0,0 @@
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::from_str;
use crate::{
errors::{ChorusLibError, ChorusResult},
instance::UserMeta,
limit::LimitedRequester,
};
use super::limits::LimitType;
/// Sends a request to wherever it needs to go and performs some basic error handling.
pub async fn handle_request(
request: RequestBuilder,
user: &mut UserMeta,
limit_type: LimitType,
) -> Result<reqwest::Response, crate::errors::ChorusLibError> {
LimitedRequester::send_request(
request,
limit_type,
&mut user.belongs_to.borrow_mut(),
&mut user.limits,
)
.await
}
/// Sends a request to wherever it needs to go. Returns [`Ok(())`] on success and
/// [`Err(ChorusLibError)`] on failure.
pub async fn handle_request_as_result(
request: RequestBuilder,
user: &mut UserMeta,
limit_type: LimitType,
) -> ChorusResult<()> {
match handle_request(request, user, limit_type).await {
Ok(_) => Ok(()),
Err(e) => Err(ChorusLibError::InvalidResponseError {
error: e.to_string(),
}),
}
}
pub async fn deserialize_response<T: for<'a> Deserialize<'a>>(
request: RequestBuilder,
user: &mut UserMeta,
limit_type: LimitType,
) -> ChorusResult<T> {
let response = handle_request(request, user, limit_type).await.unwrap();
let response_text = match response.text().await {
Ok(string) => string,
Err(e) => {
return Err(ChorusLibError::InvalidResponseError {
error: format!(
"Error while trying to process the HTTP response into a String: {}",
e
),
});
}
};
let object = match from_str::<T>(&response_text) {
Ok(object) => object,
Err(e) => {
return Err(ChorusLibError::InvalidResponseError {
error: format!(
"Error while trying to deserialize the JSON response into T: {}",
e
),
})
}
};
Ok(object)
}

View File

@ -2,15 +2,11 @@ use reqwest::Client;
use serde_json::from_str;
use serde_json::to_string;
use crate::api::deserialize_response;
use crate::api::handle_request;
use crate::api::handle_request_as_result;
use crate::api::limits::Limits;
use crate::errors::ChorusLibError;
use crate::api::LimitType;
use crate::errors::ChorusError;
use crate::errors::ChorusResult;
use crate::instance::Instance;
use crate::instance::UserMeta;
use crate::limit::LimitedRequester;
use crate::ratelimiter::ChorusRequest;
use crate::types::Snowflake;
use crate::types::{Channel, ChannelCreateSchema, Guild, GuildCreateSchema};
@ -36,11 +32,14 @@ impl Guild {
guild_create_schema: GuildCreateSchema,
) -> ChorusResult<Guild> {
let url = format!("{}/guilds/", user.belongs_to.borrow().urls.api);
let request = reqwest::Client::new()
.post(url.clone())
.bearer_auth(user.token.clone())
.body(to_string(&guild_create_schema).unwrap());
deserialize_response::<Guild>(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new()
.post(url.clone())
.bearer_auth(user.token.clone())
.body(to_string(&guild_create_schema).unwrap()),
limit_type: LimitType::Global,
};
chorus_request.deserialize_response::<Guild>(user).await
}
/// Deletes a guild.
@ -73,10 +72,13 @@ impl Guild {
user.belongs_to.borrow().urls.api,
guild_id
);
let request = reqwest::Client::new()
.post(url.clone())
.bearer_auth(user.token.clone());
handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new()
.post(url.clone())
.bearer_auth(user.token.clone()),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(user).await
}
/// Sends a request to create a new channel in the guild.
@ -97,14 +99,7 @@ impl Guild {
user: &mut UserMeta,
schema: ChannelCreateSchema,
) -> ChorusResult<Channel> {
Channel::_create(
&user.token,
self.id,
schema,
&mut user.limits,
&mut user.belongs_to.borrow_mut(),
)
.await
Channel::create(user, self.id, schema).await
}
/// Returns a `Result` containing a vector of `Channel` structs if the request was successful, or an `ChorusLibError` if there was an error.
@ -117,20 +112,21 @@ impl Guild {
/// * `limits_instance` - A mutable reference to a `Limits` struct containing the instance's rate limits.
///
pub async fn channels(&self, user: &mut UserMeta) -> ChorusResult<Vec<Channel>> {
let request = Client::new()
.get(format!(
"{}/guilds/{}/channels/",
user.belongs_to.borrow().urls.api,
self.id
))
.bearer_auth(user.token());
let result = handle_request(request, user, crate::api::limits::LimitType::Channel)
.await
.unwrap();
let chorus_request = ChorusRequest {
request: Client::new()
.get(format!(
"{}/guilds/{}/channels/",
user.belongs_to.borrow().urls.api,
self.id
))
.bearer_auth(user.token()),
limit_type: LimitType::Channel(self.id),
};
let result = chorus_request.send_request(user).await?;
let stringed_response = match result.text().await {
Ok(value) => value,
Err(e) => {
return Err(ChorusLibError::InvalidResponseError {
return Err(ChorusError::InvalidResponse {
error: e.to_string(),
});
}
@ -138,7 +134,7 @@ impl Guild {
let _: Vec<Channel> = match from_str(&stringed_response) {
Ok(result) => return Ok(result),
Err(e) => {
return Err(ChorusLibError::InvalidResponseError {
return Err(ChorusError::InvalidResponse {
error: e.to_string(),
});
}
@ -155,35 +151,19 @@ impl Guild {
/// * `limits_user` - A mutable reference to a `Limits` struct containing the user's rate limits.
/// * `limits_instance` - A mutable reference to a `Limits` struct containing the instance's rate limits.
///
pub async fn get(user: &mut UserMeta, guild_id: Snowflake) -> ChorusResult<Guild> {
let mut belongs_to = user.belongs_to.borrow_mut();
Guild::_get(guild_id, &user.token, &mut user.limits, &mut belongs_to).await
}
/// For internal use. Does the same as the public get method, but does not require a second, mutable
/// borrow of `UserMeta::belongs_to`, when used in conjunction with other methods, which borrow `UserMeta::belongs_to`.
async fn _get(
guild_id: Snowflake,
token: &str,
limits_user: &mut Limits,
instance: &mut Instance,
) -> ChorusResult<Guild> {
let request = Client::new()
.get(format!("{}/guilds/{}/", instance.urls.api, guild_id))
.bearer_auth(token);
let response = match LimitedRequester::send_request(
request,
crate::api::limits::LimitType::Guild,
instance,
limits_user,
)
.await
{
Ok(response) => response,
Err(e) => return Err(e),
pub async fn get(guild_id: Snowflake, user: &mut UserMeta) -> ChorusResult<Guild> {
let chorus_request = ChorusRequest {
request: Client::new()
.get(format!(
"{}/guilds/{}/",
user.belongs_to.borrow().urls.api,
guild_id
))
.bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
let guild: Guild = from_str(&response.text().await.unwrap()).unwrap();
Ok(guild)
let response = chorus_request.deserialize_response::<Guild>(user).await?;
Ok(response)
}
}
@ -207,48 +187,17 @@ impl Channel {
guild_id: Snowflake,
schema: ChannelCreateSchema,
) -> ChorusResult<Channel> {
let mut belongs_to = user.belongs_to.borrow_mut();
Channel::_create(
&user.token,
guild_id,
schema,
&mut user.limits,
&mut belongs_to,
)
.await
}
async fn _create(
token: &str,
guild_id: Snowflake,
schema: ChannelCreateSchema,
limits_user: &mut Limits,
instance: &mut Instance,
) -> ChorusResult<Channel> {
let request = Client::new()
.post(format!(
"{}/guilds/{}/channels/",
instance.urls.api, guild_id
))
.bearer_auth(token)
.body(to_string(&schema).unwrap());
let result = match LimitedRequester::send_request(
request,
crate::api::limits::LimitType::Guild,
instance,
limits_user,
)
.await
{
Ok(result) => result,
Err(e) => return Err(e),
let chorus_request = ChorusRequest {
request: Client::new()
.post(format!(
"{}/guilds/{}/channels/",
user.belongs_to.borrow().urls.api,
guild_id
))
.bearer_auth(user.token())
.body(to_string(&schema).unwrap()),
limit_type: LimitType::Guild(guild_id),
};
match from_str::<Channel>(&result.text().await.unwrap()) {
Ok(object) => Ok(object),
Err(e) => Err(ChorusLibError::RequestErrorError {
url: format!("{}/guilds/{}/channels/", instance.urls.api, guild_id),
error: e.to_string(),
}),
}
chorus_request.deserialize_response::<Channel>(user).await
}
}

View File

@ -1,9 +1,10 @@
use reqwest::Client;
use crate::{
api::{deserialize_response, handle_request_as_result},
api::LimitType,
errors::ChorusResult,
instance::UserMeta,
ratelimiter::ChorusRequest,
types::{self, Snowflake},
};
@ -30,13 +31,13 @@ impl types::GuildMember {
guild_id,
member_id
);
let request = Client::new().get(url).bearer_auth(user.token());
deserialize_response::<types::GuildMember>(
request,
user,
crate::api::limits::LimitType::Guild,
)
.await
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
chorus_request
.deserialize_response::<types::GuildMember>(user)
.await
}
/// Adds a role to a guild member.
@ -64,8 +65,11 @@ impl types::GuildMember {
member_id,
role_id
);
let request = Client::new().put(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new().put(url).bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
chorus_request.handle_request_as_result(user).await
}
/// Removes a role from a guild member.
@ -85,7 +89,7 @@ impl types::GuildMember {
guild_id: Snowflake,
member_id: Snowflake,
role_id: Snowflake,
) -> Result<(), crate::errors::ChorusLibError> {
) -> Result<(), crate::errors::ChorusError> {
let url = format!(
"{}/guilds/{}/members/{}/roles/{}/",
user.belongs_to.borrow().urls.api,
@ -93,7 +97,10 @@ impl types::GuildMember {
member_id,
role_id
);
let request = Client::new().delete(url).bearer_auth(user.token());
handle_request_as_result(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
chorus_request.handle_request_as_result(user).await
}
}

View File

@ -2,9 +2,10 @@ use reqwest::Client;
use serde_json::to_string;
use crate::{
api::deserialize_response,
errors::{ChorusLibError, ChorusResult},
api::LimitType,
errors::{ChorusError, ChorusResult},
instance::UserMeta,
ratelimiter::ChorusRequest,
types::{self, RoleCreateModifySchema, RoleObject, Snowflake},
};
@ -32,14 +33,14 @@ impl types::RoleObject {
user.belongs_to.borrow().urls.api,
guild_id
);
let request = Client::new().get(url).bearer_auth(user.token());
let roles = deserialize_response::<Vec<RoleObject>>(
request,
user,
crate::api::limits::LimitType::Guild,
)
.await
.unwrap();
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
let roles = chorus_request
.deserialize_response::<Vec<RoleObject>>(user)
.await
.unwrap();
if roles.is_empty() {
return Ok(None);
}
@ -72,8 +73,13 @@ impl types::RoleObject {
guild_id,
role_id
);
let request = Client::new().get(url).bearer_auth(user.token());
deserialize_response(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(user.token()),
limit_type: LimitType::Guild(guild_id),
};
chorus_request
.deserialize_response::<RoleObject>(user)
.await
}
/// Creates a new role for a given guild.
@ -102,12 +108,17 @@ impl types::RoleObject {
guild_id
);
let body = to_string::<RoleCreateModifySchema>(&role_create_schema).map_err(|e| {
ChorusLibError::FormCreationError {
ChorusError::FormCreation {
error: e.to_string(),
}
})?;
let request = Client::new().post(url).bearer_auth(user.token()).body(body);
deserialize_response(request, user, crate::api::limits::LimitType::Guild).await
let chorus_request = ChorusRequest {
request: Client::new().post(url).bearer_auth(user.token()).body(body),
limit_type: LimitType::Guild(guild_id),
};
chorus_request
.deserialize_response::<RoleObject>(user)
.await
}
/// Updates the position of a role in the guild's hierarchy.
@ -135,16 +146,19 @@ impl types::RoleObject {
user.belongs_to.borrow().urls.api,
guild_id
);
let body = to_string(&role_position_update_schema).map_err(|e| {
ChorusLibError::FormCreationError {
let body =
to_string(&role_position_update_schema).map_err(|e| ChorusError::FormCreation {
error: e.to_string(),
}
})?;
let request = Client::new()
.patch(url)
.bearer_auth(user.token())
.body(body);
deserialize_response::<RoleObject>(request, user, crate::api::limits::LimitType::Guild)
})?;
let chorus_request = ChorusRequest {
request: Client::new()
.patch(url)
.bearer_auth(user.token())
.body(body),
limit_type: LimitType::Guild(guild_id),
};
chorus_request
.deserialize_response::<RoleObject>(user)
.await
}
@ -177,15 +191,19 @@ impl types::RoleObject {
role_id
);
let body = to_string::<RoleCreateModifySchema>(&role_create_schema).map_err(|e| {
ChorusLibError::FormCreationError {
ChorusError::FormCreation {
error: e.to_string(),
}
})?;
let request = Client::new()
.patch(url)
.bearer_auth(user.token())
.body(body);
deserialize_response::<RoleObject>(request, user, crate::api::limits::LimitType::Guild)
let chorus_request = ChorusRequest {
request: Client::new()
.patch(url)
.bearer_auth(user.token())
.body(body),
limit_type: LimitType::Guild(guild_id),
};
chorus_request
.deserialize_response::<RoleObject>(user)
.await
}
}

View File

@ -1,12 +1,10 @@
pub use channels::messages::*;
pub use common::*;
pub use guilds::*;
pub use policies::instance::instance::*;
pub use policies::instance::limits::*;
pub use policies::instance::ratelimits::*;
pub mod auth;
pub mod channels;
pub mod common;
pub mod guilds;
pub mod policies;
pub mod users;

View File

@ -1,7 +1,6 @@
use reqwest::Client;
use serde_json::from_str;
use crate::errors::{ChorusLibError, ChorusResult};
use crate::errors::{ChorusError, ChorusResult};
use crate::instance::Instance;
use crate::types::GeneralConfiguration;
@ -10,21 +9,21 @@ impl Instance {
/// # Errors
/// [`ChorusLibError`] - If the request fails.
pub async fn general_configuration_schema(&self) -> ChorusResult<GeneralConfiguration> {
let client = Client::new();
let endpoint_url = self.urls.api.clone() + "/policies/instance/";
let request = match client.get(&endpoint_url).send().await {
let request = match self.client.get(&endpoint_url).send().await {
Ok(result) => result,
Err(e) => {
return Err(ChorusLibError::RequestErrorError {
return Err(ChorusError::RequestFailed {
url: endpoint_url,
error: e.to_string(),
error: e,
});
}
};
if !request.status().as_str().starts_with('2') {
return Err(ChorusLibError::ReceivedErrorCodeError {
error_code: request.status().to_string(),
return Err(ChorusError::ReceivedErrorCode {
error_code: request.status().as_u16(),
error: request.text().await.unwrap(),
});
}

View File

@ -1,499 +0,0 @@
pub mod limits {
use std::collections::HashMap;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
#[derive(Clone, Copy, Eq, Hash, PartialEq, Debug, Default)]
pub enum LimitType {
AuthRegister,
AuthLogin,
AbsoluteMessage,
AbsoluteRegister,
#[default]
Global,
Ip,
Channel,
Error,
Guild,
Webhook,
}
impl ToString for LimitType {
fn to_string(&self) -> String {
match self {
LimitType::AuthRegister => "AuthRegister".to_string(),
LimitType::AuthLogin => "AuthLogin".to_string(),
LimitType::AbsoluteMessage => "AbsoluteMessage".to_string(),
LimitType::AbsoluteRegister => "AbsoluteRegister".to_string(),
LimitType::Global => "Global".to_string(),
LimitType::Ip => "Ip".to_string(),
LimitType::Channel => "Channel".to_string(),
LimitType::Error => "Error".to_string(),
LimitType::Guild => "Guild".to_string(),
LimitType::Webhook => "Webhook".to_string(),
}
}
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct User {
pub maxGuilds: u64,
pub maxUsername: u64,
pub maxFriends: u64,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct Guild {
pub maxRoles: u64,
pub maxEmojis: u64,
pub maxMembers: u64,
pub maxChannels: u64,
pub maxChannelsInCategory: u64,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct Message {
pub maxCharacters: u64,
pub maxTTSCharacters: u64,
pub maxReactions: u64,
pub maxAttachmentSize: u64,
pub maxBulkDelete: u64,
pub maxEmbedDownloadSize: u64,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct Channel {
pub maxPins: u64,
pub maxTopic: u64,
pub maxWebhooks: u64,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Rate {
pub enabled: bool,
pub ip: Window,
pub global: Window,
pub error: Window,
pub routes: Routes,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Window {
pub count: u64,
pub window: u64,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Routes {
pub guild: Window,
pub webhook: Window,
pub channel: Window,
pub auth: AuthRoutes,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct AuthRoutes {
pub login: Window,
pub register: Window,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct AbsoluteRate {
pub register: AbsoluteWindow,
pub sendMessage: AbsoluteWindow,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct AbsoluteWindow {
pub limit: u64,
pub window: u64,
pub enabled: bool,
}
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct Config {
pub user: User,
pub guild: Guild,
pub message: Message,
pub channel: Channel,
pub rate: Rate,
pub absoluteRate: AbsoluteRate,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct Limit {
pub bucket: LimitType,
pub limit: u64,
pub remaining: u64,
pub reset: u64,
}
impl std::fmt::Display for Limit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Bucket: {:?}, Limit: {}, Remaining: {}, Reset: {}",
self.bucket, self.limit, self.remaining, self.reset
)
}
}
impl Limit {
pub fn add_remaining(&mut self, remaining: i64) {
if remaining < 0 {
if (self.remaining as i64 + remaining) <= 0 {
self.remaining = 0;
return;
}
self.remaining -= remaining.unsigned_abs();
return;
}
self.remaining += remaining.unsigned_abs();
}
}
pub struct LimitsMutRef<'a> {
pub limit_absolute_messages: &'a mut Limit,
pub limit_absolute_register: &'a mut Limit,
pub limit_auth_login: &'a mut Limit,
pub limit_auth_register: &'a mut Limit,
pub limit_ip: &'a mut Limit,
pub limit_global: &'a mut Limit,
pub limit_error: &'a mut Limit,
pub limit_guild: &'a mut Limit,
pub limit_webhook: &'a mut Limit,
pub limit_channel: &'a mut Limit,
}
impl LimitsMutRef<'_> {
pub fn combine_mut_ref<'a>(
instance_rate_limits: &'a mut Limits,
user_rate_limits: &'a mut Limits,
) -> LimitsMutRef<'a> {
LimitsMutRef {
limit_absolute_messages: &mut instance_rate_limits.limit_absolute_messages,
limit_absolute_register: &mut instance_rate_limits.limit_absolute_register,
limit_auth_login: &mut instance_rate_limits.limit_auth_login,
limit_auth_register: &mut instance_rate_limits.limit_auth_register,
limit_channel: &mut user_rate_limits.limit_channel,
limit_error: &mut user_rate_limits.limit_error,
limit_global: &mut instance_rate_limits.limit_global,
limit_guild: &mut user_rate_limits.limit_guild,
limit_ip: &mut instance_rate_limits.limit_ip,
limit_webhook: &mut user_rate_limits.limit_webhook,
}
}
pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit {
match limit_type {
LimitType::AbsoluteMessage => self.limit_absolute_messages,
LimitType::AbsoluteRegister => self.limit_absolute_register,
LimitType::AuthLogin => self.limit_auth_login,
LimitType::AuthRegister => self.limit_auth_register,
LimitType::Channel => self.limit_channel,
LimitType::Error => self.limit_error,
LimitType::Global => self.limit_global,
LimitType::Guild => self.limit_guild,
LimitType::Ip => self.limit_ip,
LimitType::Webhook => self.limit_webhook,
}
}
pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit {
match limit_type {
LimitType::AbsoluteMessage => self.limit_absolute_messages,
LimitType::AbsoluteRegister => self.limit_absolute_register,
LimitType::AuthLogin => self.limit_auth_login,
LimitType::AuthRegister => self.limit_auth_register,
LimitType::Channel => self.limit_channel,
LimitType::Error => self.limit_error,
LimitType::Global => self.limit_global,
LimitType::Guild => self.limit_guild,
LimitType::Ip => self.limit_ip,
LimitType::Webhook => self.limit_webhook,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct Limits {
pub limit_absolute_messages: Limit,
pub limit_absolute_register: Limit,
pub limit_auth_login: Limit,
pub limit_auth_register: Limit,
pub limit_ip: Limit,
pub limit_global: Limit,
pub limit_error: Limit,
pub limit_guild: Limit,
pub limit_webhook: Limit,
pub limit_channel: Limit,
}
impl Limits {
pub fn combine(instance_rate_limits: &Limits, user_rate_limits: &Limits) -> Limits {
Limits {
limit_absolute_messages: instance_rate_limits.limit_absolute_messages,
limit_absolute_register: instance_rate_limits.limit_absolute_register,
limit_auth_login: instance_rate_limits.limit_auth_login,
limit_auth_register: instance_rate_limits.limit_auth_register,
limit_channel: user_rate_limits.limit_channel,
limit_error: user_rate_limits.limit_error,
limit_global: instance_rate_limits.limit_global,
limit_guild: user_rate_limits.limit_guild,
limit_ip: instance_rate_limits.limit_ip,
limit_webhook: user_rate_limits.limit_webhook,
}
}
pub fn get_limit_ref(&self, limit_type: &LimitType) -> &Limit {
match limit_type {
LimitType::AbsoluteMessage => &self.limit_absolute_messages,
LimitType::AbsoluteRegister => &self.limit_absolute_register,
LimitType::AuthLogin => &self.limit_auth_login,
LimitType::AuthRegister => &self.limit_auth_register,
LimitType::Channel => &self.limit_channel,
LimitType::Error => &self.limit_error,
LimitType::Global => &self.limit_global,
LimitType::Guild => &self.limit_guild,
LimitType::Ip => &self.limit_ip,
LimitType::Webhook => &self.limit_webhook,
}
}
pub fn get_limit_mut_ref(&mut self, limit_type: &LimitType) -> &mut Limit {
match limit_type {
LimitType::AbsoluteMessage => &mut self.limit_absolute_messages,
LimitType::AbsoluteRegister => &mut self.limit_absolute_register,
LimitType::AuthLogin => &mut self.limit_auth_login,
LimitType::AuthRegister => &mut self.limit_auth_register,
LimitType::Channel => &mut self.limit_channel,
LimitType::Error => &mut self.limit_error,
LimitType::Global => &mut self.limit_global,
LimitType::Guild => &mut self.limit_guild,
LimitType::Ip => &mut self.limit_ip,
LimitType::Webhook => &mut self.limit_webhook,
}
}
pub fn to_hash_map(&self) -> HashMap<LimitType, Limit> {
let mut map: HashMap<LimitType, Limit> = HashMap::new();
map.insert(LimitType::AbsoluteMessage, self.limit_absolute_messages);
map.insert(LimitType::AbsoluteRegister, self.limit_absolute_register);
map.insert(LimitType::AuthLogin, self.limit_auth_login);
map.insert(LimitType::AuthRegister, self.limit_auth_register);
map.insert(LimitType::Ip, self.limit_ip);
map.insert(LimitType::Global, self.limit_global);
map.insert(LimitType::Error, self.limit_error);
map.insert(LimitType::Guild, self.limit_guild);
map.insert(LimitType::Webhook, self.limit_webhook);
map.insert(LimitType::Channel, self.limit_channel);
map
}
pub fn get_as_mut(&mut self) -> &mut Limits {
self
}
/// check_limits uses the API to get the current request limits of the instance.
/// It returns a `Limits` struct containing all the limits.
/// If the rate limit is disabled, then the limit is set to `u64::MAX`.
/// # Errors
/// This function will panic if the request fails or if the response body cannot be parsed.
/// TODO: Change this to return a Result and handle the errors properly.
pub async fn check_limits(api_url: String) -> Limits {
let client = Client::new();
let url_parsed = crate::UrlBundle::parse_url(api_url) + "/policies/instance/limits";
let result = client
.get(url_parsed)
.send()
.await
.unwrap_or_else(|e| panic!("An error occured while performing the request: {}", e))
.text()
.await
.unwrap_or_else(|e| {
panic!(
"An error occured while parsing the request body string: {}",
e
)
});
let config: Config = from_str(&result).unwrap();
// If config.rate.enabled is false, then add return a Limits struct with all limits set to u64::MAX
let mut limits: Limits;
if !config.rate.enabled {
limits = Limits {
limit_absolute_messages: Limit {
bucket: LimitType::AbsoluteMessage,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_absolute_register: Limit {
bucket: LimitType::AbsoluteRegister,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_auth_login: Limit {
bucket: LimitType::AuthLogin,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_auth_register: Limit {
bucket: LimitType::AuthRegister,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_ip: Limit {
bucket: LimitType::Ip,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_global: Limit {
bucket: LimitType::Global,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_error: Limit {
bucket: LimitType::Error,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_guild: Limit {
bucket: LimitType::Guild,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_webhook: Limit {
bucket: LimitType::Webhook,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
limit_channel: Limit {
bucket: LimitType::Channel,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
},
};
} else {
limits = Limits {
limit_absolute_messages: Limit {
bucket: LimitType::AbsoluteMessage,
limit: config.absoluteRate.sendMessage.limit,
remaining: config.absoluteRate.sendMessage.limit,
reset: config.absoluteRate.sendMessage.window,
},
limit_absolute_register: Limit {
bucket: LimitType::AbsoluteRegister,
limit: config.absoluteRate.register.limit,
remaining: config.absoluteRate.register.limit,
reset: config.absoluteRate.register.window,
},
limit_auth_login: Limit {
bucket: LimitType::AuthLogin,
limit: config.rate.routes.auth.login.count,
remaining: config.rate.routes.auth.login.count,
reset: config.rate.routes.auth.login.window,
},
limit_auth_register: Limit {
bucket: LimitType::AuthRegister,
limit: config.rate.routes.auth.register.count,
remaining: config.rate.routes.auth.register.count,
reset: config.rate.routes.auth.register.window,
},
limit_ip: Limit {
bucket: LimitType::Ip,
limit: config.rate.ip.count,
remaining: config.rate.ip.count,
reset: config.rate.ip.window,
},
limit_global: Limit {
bucket: LimitType::Global,
limit: config.rate.global.count,
remaining: config.rate.global.count,
reset: config.rate.global.window,
},
limit_error: Limit {
bucket: LimitType::Error,
limit: config.rate.error.count,
remaining: config.rate.error.count,
reset: config.rate.error.window,
},
limit_guild: Limit {
bucket: LimitType::Guild,
limit: config.rate.routes.guild.count,
remaining: config.rate.routes.guild.count,
reset: config.rate.routes.guild.window,
},
limit_webhook: Limit {
bucket: LimitType::Webhook,
limit: config.rate.routes.webhook.count,
remaining: config.rate.routes.webhook.count,
reset: config.rate.routes.webhook.window,
},
limit_channel: Limit {
bucket: LimitType::Channel,
limit: config.rate.routes.channel.count,
remaining: config.rate.routes.channel.count,
reset: config.rate.routes.channel.window,
},
};
}
if !config.absoluteRate.register.enabled {
limits.limit_absolute_register = Limit {
bucket: LimitType::AbsoluteRegister,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
};
}
if !config.absoluteRate.sendMessage.enabled {
limits.limit_absolute_messages = Limit {
bucket: LimitType::AbsoluteMessage,
limit: u64::MAX,
remaining: u64::MAX,
reset: u64::MAX,
};
}
limits
}
}
}
#[cfg(test)]
mod instance_limits {
use crate::api::limits::{Limit, LimitType};
#[test]
fn limit_below_zero() {
let mut limit = Limit {
bucket: LimitType::AbsoluteMessage,
limit: 0,
remaining: 1,
reset: 0,
};
limit.add_remaining(-2);
assert_eq!(0_u64, limit.remaining);
limit.add_remaining(-2123123);
assert_eq!(0_u64, limit.remaining);
}
}

View File

@ -1,5 +1,5 @@
pub use instance::*;
pub use limits::*;
pub use ratelimits::*;
pub mod instance;
pub mod limits;
pub mod ratelimits;

View File

@ -0,0 +1,35 @@
use std::hash::Hash;
use crate::types::Snowflake;
/// The different types of ratelimits that can be applied to a request. Includes "Baseline"-variants
/// for when the Snowflake is not yet known.
/// See <https://discord.com/developers/docs/topics/rate-limits#rate-limits> for more information.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, Hash)]
pub enum LimitType {
AuthRegister,
AuthLogin,
#[default]
Global,
Ip,
Channel(Snowflake),
ChannelBaseline,
Error,
Guild(Snowflake),
GuildBaseline,
Webhook(Snowflake),
WebhookBaseline,
}
/// A struct that represents the current ratelimits, either instance-wide or user-wide.
/// Unlike [`RateLimits`], this struct shows the current ratelimits, not the rate limit
/// configuration for the instance.
/// See <https://discord.com/developers/docs/topics/rate-limits#rate-limits> for more information.
#[derive(Debug, Clone)]
pub struct Limit {
pub bucket: LimitType,
pub limit: u64,
pub remaining: u64,
pub reset: u64,
pub window: u64,
}

View File

@ -1,3 +1,3 @@
pub use instance::limits::*;
pub use instance::ratelimits::*;
pub mod instance;

View File

@ -2,9 +2,10 @@ use reqwest::Client;
use serde_json::to_string;
use crate::{
api::{deserialize_response, handle_request_as_result},
api::LimitType,
errors::ChorusResult,
instance::UserMeta,
ratelimiter::ChorusRequest,
types::{self, CreateUserRelationshipSchema, RelationshipType, Snowflake},
};
@ -26,13 +27,13 @@ impl UserMeta {
self.belongs_to.borrow().urls.api,
user_id
);
let request = Client::new().get(url).bearer_auth(self.token());
deserialize_response::<Vec<types::PublicUser>>(
request,
self,
crate::api::limits::LimitType::Global,
)
.await
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(self.token()),
limit_type: LimitType::Global,
};
chorus_request
.deserialize_response::<Vec<types::PublicUser>>(self)
.await
}
/// Retrieves the authenticated user's relationships.
@ -44,13 +45,13 @@ impl UserMeta {
"{}/users/@me/relationships/",
self.belongs_to.borrow().urls.api
);
let request = Client::new().get(url).bearer_auth(self.token());
deserialize_response::<Vec<types::Relationship>>(
request,
self,
crate::api::limits::LimitType::Global,
)
.await
let chorus_request = ChorusRequest {
request: Client::new().get(url).bearer_auth(self.token()),
limit_type: LimitType::Global,
};
chorus_request
.deserialize_response::<Vec<types::Relationship>>(self)
.await
}
/// Sends a friend request to a user.
@ -70,8 +71,11 @@ impl UserMeta {
self.belongs_to.borrow().urls.api
);
let body = to_string(&schema).unwrap();
let request = Client::new().post(url).bearer_auth(self.token()).body(body);
handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await
let chorus_request = ChorusRequest {
request: Client::new().post(url).bearer_auth(self.token()).body(body),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(self).await
}
/// Modifies the relationship between the authenticated user and the specified user.
@ -96,10 +100,13 @@ impl UserMeta {
let api_url = self.belongs_to.borrow().urls.api.clone();
match relationship_type {
RelationshipType::None => {
let request = Client::new()
.delete(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token());
handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await
let chorus_request = ChorusRequest {
request: Client::new()
.delete(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token()),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(self).await
}
RelationshipType::Friends | RelationshipType::Incoming | RelationshipType::Outgoing => {
let body = CreateUserRelationshipSchema {
@ -107,11 +114,14 @@ impl UserMeta {
from_friend_suggestion: None,
friend_token: None,
};
let request = Client::new()
.put(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token())
.body(to_string(&body).unwrap());
handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await
let chorus_request = ChorusRequest {
request: Client::new()
.put(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token())
.body(to_string(&body).unwrap()),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(self).await
}
RelationshipType::Blocked => {
let body = CreateUserRelationshipSchema {
@ -119,11 +129,14 @@ impl UserMeta {
from_friend_suggestion: None,
friend_token: None,
};
let request = Client::new()
.put(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token())
.body(to_string(&body).unwrap());
handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await
let chorus_request = ChorusRequest {
request: Client::new()
.put(format!("{}/users/@me/relationships/{}/", api_url, user_id))
.bearer_auth(self.token())
.body(to_string(&body).unwrap()),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(self).await
}
RelationshipType::Suggestion | RelationshipType::Implicit => Ok(()),
}
@ -143,7 +156,10 @@ impl UserMeta {
self.belongs_to.borrow().urls.api,
user_id
);
let request = Client::new().delete(url).bearer_auth(self.token());
handle_request_as_result(request, self, crate::api::limits::LimitType::Global).await
let chorus_request = ChorusRequest {
request: Client::new().delete(url).bearer_auth(self.token()),
limit_type: LimitType::Global,
};
chorus_request.handle_request_as_result(self).await
}
}

View File

@ -1,11 +1,13 @@
use std::{cell::RefCell, rc::Rc};
use reqwest::Client;
use serde_json::to_string;
use crate::{
api::{deserialize_response, handle_request_as_result},
errors::{ChorusLibError, ChorusResult},
api::LimitType,
errors::{ChorusError, ChorusResult},
instance::{Instance, UserMeta},
limit::LimitedRequester,
ratelimiter::ChorusRequest,
types::{User, UserModifySchema, UserSettings},
};
@ -48,16 +50,20 @@ impl UserMeta {
|| modify_schema.email.is_some()
|| modify_schema.code.is_some()
{
return Err(ChorusLibError::PasswordRequiredError);
return Err(ChorusError::PasswordRequired);
}
let request = Client::new()
.patch(format!("{}/users/@me/", self.belongs_to.borrow().urls.api))
.body(to_string(&modify_schema).unwrap())
.bearer_auth(self.token());
let user_updated =
deserialize_response::<User>(request, self, crate::api::limits::LimitType::Ip)
.await
.unwrap();
let chorus_request = ChorusRequest {
request,
limit_type: LimitType::default(),
};
let user_updated = chorus_request
.deserialize_response::<User>(self)
.await
.unwrap();
let _ = std::mem::replace(&mut self.object, user_updated.clone());
Ok(user_updated)
}
@ -78,43 +84,28 @@ impl UserMeta {
self.belongs_to.borrow().urls.api
))
.bearer_auth(self.token());
handle_request_as_result(request, &mut self, crate::api::limits::LimitType::Ip).await
let chorus_request = ChorusRequest {
request,
limit_type: LimitType::default(),
};
chorus_request.handle_request_as_result(&mut self).await
}
}
impl User {
pub async fn get(user: &mut UserMeta, id: Option<&String>) -> ChorusResult<User> {
let mut belongs_to = user.belongs_to.borrow_mut();
User::_get(
&user.token(),
&format!("{}", belongs_to.urls.api),
&mut belongs_to,
id,
)
.await
}
async fn _get(
token: &str,
url_api: &str,
instance: &mut Instance,
id: Option<&String>,
) -> ChorusResult<User> {
let url_api = user.belongs_to.borrow().urls.api.clone();
let url = if id.is_none() {
format!("{}/users/@me/", url_api)
} else {
format!("{}/users/{}", url_api, id.unwrap())
};
let request = reqwest::Client::new().get(url).bearer_auth(token);
let mut cloned_limits = instance.limits.clone();
match LimitedRequester::send_request(
let request = reqwest::Client::new().get(url).bearer_auth(user.token());
let chorus_request = ChorusRequest {
request,
crate::api::limits::LimitType::Ip,
instance,
&mut cloned_limits,
)
.await
{
limit_type: LimitType::Global,
};
match chorus_request.send_request(user).await {
Ok(result) => {
let result_text = result.text().await.unwrap();
Ok(serde_json::from_str::<User>(&result_text).unwrap())
@ -131,18 +122,20 @@ impl User {
let request: reqwest::RequestBuilder = Client::new()
.get(format!("{}/users/@me/settings/", url_api))
.bearer_auth(token);
let mut cloned_limits = instance.limits.clone();
match LimitedRequester::send_request(
let mut user = UserMeta::shell(Rc::new(RefCell::new(instance.clone())), token.clone());
let chorus_request = ChorusRequest {
request,
crate::api::limits::LimitType::Ip,
instance,
&mut cloned_limits,
)
.await
{
limit_type: LimitType::Global,
};
let result = match chorus_request.send_request(&mut user).await {
Ok(result) => Ok(serde_json::from_str(&result.text().await.unwrap()).unwrap()),
Err(e) => Err(e),
};
if instance.limits_information.is_some() {
instance.limits_information.as_mut().unwrap().ratelimits =
user.belongs_to.borrow().clone_limits_if_some().unwrap();
}
result
}
}
@ -158,6 +151,12 @@ impl Instance {
This function is a wrapper around [`User::get`].
*/
pub async fn get_user(&mut self, token: String, id: Option<&String>) -> ChorusResult<User> {
User::_get(&token, &self.urls.api.clone(), self, id).await
let mut user = UserMeta::shell(Rc::new(RefCell::new(self.clone())), token);
let result = User::get(&mut user, id).await;
if self.limits_information.is_some() {
self.limits_information.as_mut().unwrap().ratelimits =
user.belongs_to.borrow().clone_limits_if_some().unwrap();
}
result
}
}

View File

@ -1,39 +1,50 @@
use custom_error::custom_error;
use reqwest::Error;
custom_error! {
#[derive(PartialEq, Eq)]
pub FieldFormatError
PasswordError = "Password must be between 1 and 72 characters.",
UsernameError = "Username must be between 2 and 32 characters.",
ConsentError = "Consent must be 'true' to register.",
EmailError = "The provided email address is in an invalid format.",
pub RegistrationError
Consent = "Consent must be 'true' to register.",
}
pub type ChorusResult<T> = std::result::Result<T, ChorusLibError>;
pub type ChorusResult<T> = std::result::Result<T, ChorusError>;
custom_error! {
#[derive(PartialEq, Eq)]
pub ChorusLibError
pub ChorusError
/// Server did not respond.
NoResponse = "Did not receive a response from the Server.",
RequestErrorError{url:String, error:String} = "An error occured while trying to GET from {url}: {error}",
ReceivedErrorCodeError{error_code:String} = "Received the following error code while requesting from the route: {error_code}",
CantGetInfoError{error:String} = "Something seems to be wrong with the instance. Cannot get information about the instance: {error}",
InvalidFormBodyError{error_type: String, error:String} = "The server responded with: {error_type}: {error}",
/// Reqwest returned an Error instead of a Response object.
RequestFailed{url:String, error: Error} = "An error occured while trying to GET from {url}: {error}",
/// Response received, however, it was not of the successful responses type. Used when no other, special case applies.
ReceivedErrorCode{error_code: u16, error: String} = "Received the following error code while requesting from the route: {error_code}",
/// Used when there is likely something wrong with the instance, the request was directed to.
CantGetInformation{error:String} = "Something seems to be wrong with the instance. Cannot get information about the instance: {error}",
/// The requests form body was malformed/invalid.
InvalidFormBody{error_type: String, error:String} = "The server responded with: {error_type}: {error}",
/// The request has not been processed by the server due to a relevant rate limit bucket being exhausted.
RateLimited{bucket:String} = "Ratelimited on Bucket {bucket}",
MultipartCreationError{error: String} = "Got an error whilst creating the form: {error}",
FormCreationError{error: String} = "Got an error whilst creating the form: {error}",
/// The multipart form could not be created.
MultipartCreation{error: String} = "Got an error whilst creating the form: {error}",
/// The regular form could not be created.
FormCreation{error: String} = "Got an error whilst creating the form: {error}",
/// The token is invalid.
TokenExpired = "Token expired, invalid or not found.",
/// No permission
NoPermission = "You do not have the permissions needed to perform this action.",
/// Resource not found
NotFound{error: String} = "The provided resource hasn't been found: {error}",
PasswordRequiredError = "You need to provide your current password to authenticate for this action.",
InvalidResponseError{error: String} = "The response is malformed and cannot be processed. Error: {error}",
InvalidArgumentsError{error: String} = "Invalid arguments were provided. Error: {error}"
/// Used when you, for example, try to change your spacebar account password without providing your old password for verification.
PasswordRequired = "You need to provide your current password to authenticate for this action.",
/// Malformed or unexpected response.
InvalidResponse{error: String} = "The response is malformed and cannot be processed. Error: {error}",
/// Invalid, insufficient or too many arguments provided.
InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}"
}
custom_error! {
#[derive(PartialEq, Eq)]
pub ObserverError
AlreadySubscribedError = "Each event can only be subscribed to once."
AlreadySubscribed = "Each event can only be subscribed to once."
}
custom_error! {
@ -45,25 +56,25 @@ custom_error! {
#[derive(PartialEq, Eq)]
pub GatewayError
// Errors we have received from the gateway
UnknownError = "We're not sure what went wrong. Try reconnecting?",
UnknownOpcodeError = "You sent an invalid Gateway opcode or an invalid payload for an opcode",
DecodeError = "Gateway server couldn't decode payload",
NotAuthenticatedError = "You sent a payload prior to identifying",
AuthenticationFailedError = "The account token sent with your identify payload is invalid",
AlreadyAuthenticatedError = "You've already identified, no need to reauthenticate",
InvalidSequenceNumberError = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session",
RateLimitedError = "You are being rate limited!",
SessionTimedOutError = "Your session timed out. Reconnect and start a new one",
InvalidShardError = "You sent us an invalid shard when identifying",
ShardingRequiredError = "The session would have handled too many guilds - you are required to shard your connection in order to connect",
InvalidAPIVersionError = "You sent an invalid Gateway version",
InvalidIntentsError = "You sent an invalid intent",
DisallowedIntentsError = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for",
Unknown = "We're not sure what went wrong. Try reconnecting?",
UnknownOpcode = "You sent an invalid Gateway opcode or an invalid payload for an opcode",
Decode = "Gateway server couldn't decode payload",
NotAuthenticated = "You sent a payload prior to identifying",
AuthenticationFailed = "The account token sent with your identify payload is invalid",
AlreadyAuthenticated = "You've already identified, no need to reauthenticate",
InvalidSequenceNumber = "The sequence number sent when resuming the session was invalid. Reconnect and start a new session",
RateLimited = "You are being rate limited!",
SessionTimedOut = "Your session timed out. Reconnect and start a new one",
InvalidShard = "You sent us an invalid shard when identifying",
ShardingRequired = "The session would have handled too many guilds - you are required to shard your connection in order to connect",
InvalidAPIVersion = "You sent an invalid Gateway version",
InvalidIntents = "You sent an invalid intent",
DisallowedIntents = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for",
// Errors when initiating a gateway connection
CannotConnectError{error: String} = "Cannot connect due to a tungstenite error: {error}",
NonHelloOnInitiateError{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
CannotConnect{error: String} = "Cannot connect due to a tungstenite error: {error}",
NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
// Other misc errors
UnexpectedOpcodeReceivedError{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}",
UnexpectedOpcodeReceived{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}",
}

View File

@ -8,6 +8,7 @@ use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use log::{info, trace, warn};
use native_tls::TlsConnector;
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TryRecvError;
@ -94,25 +95,21 @@ impl GatewayMessage {
let processed_content = content.to_lowercase().replace('.', "");
match processed_content.as_str() {
"unknown error" | "4000" => Some(GatewayError::UnknownError),
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcodeError),
"decode error" | "error while decoding payload" | "4002" => {
Some(GatewayError::DecodeError)
}
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticatedError),
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailedError),
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticatedError),
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumberError),
"rate limited" | "4008" => Some(GatewayError::RateLimitedError),
"session timed out" | "4009" => Some(GatewayError::SessionTimedOutError),
"invalid shard" | "4010" => Some(GatewayError::InvalidShardError),
"sharding required" | "4011" => Some(GatewayError::ShardingRequiredError),
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersionError),
"invalid intent(s)" | "invalid intent" | "4013" => {
Some(GatewayError::InvalidIntentsError)
}
"unknown error" | "4000" => Some(GatewayError::Unknown),
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode),
"decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode),
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticated),
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed),
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated),
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber),
"rate limited" | "4008" => Some(GatewayError::RateLimited),
"session timed out" | "4009" => Some(GatewayError::SessionTimedOut),
"invalid shard" | "4010" => Some(GatewayError::InvalidShard),
"sharding required" | "4011" => Some(GatewayError::ShardingRequired),
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion),
"invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents),
"disallowed intent(s)" | "disallowed intents" | "4014" => {
Some(GatewayError::DisallowedIntentsError)
Some(GatewayError::DisallowedIntents)
}
_ => None,
}
@ -191,7 +188,7 @@ impl GatewayHandle {
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Identify..");
trace!("GW: Sending Identify..");
self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await;
}
@ -200,7 +197,7 @@ impl GatewayHandle {
pub async fn send_resume(&self, to_send: types::GatewayResume) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Resume..");
trace!("GW: Sending Resume..");
self.send_json_event(GATEWAY_RESUME, to_send_value).await;
}
@ -209,7 +206,7 @@ impl GatewayHandle {
pub async fn send_update_presence(&self, to_send: types::UpdatePresence) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Update Presence..");
trace!("GW: Sending Update Presence..");
self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value)
.await;
@ -219,7 +216,7 @@ impl GatewayHandle {
pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Request Guild Members..");
trace!("GW: Sending Request Guild Members..");
self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value)
.await;
@ -229,7 +226,7 @@ impl GatewayHandle {
pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Update Voice State..");
trace!("GW: Sending Update Voice State..");
self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value)
.await;
@ -239,7 +236,7 @@ impl GatewayHandle {
pub async fn send_call_sync(&self, to_send: types::CallSync) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Call Sync..");
trace!("GW: Sending Call Sync..");
self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await;
}
@ -248,7 +245,7 @@ impl GatewayHandle {
pub async fn send_lazy_request(&self, to_send: types::LazyRequest) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Lazy Request..");
trace!("GW: Sending Lazy Request..");
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
.await;
@ -293,7 +290,7 @@ impl Gateway {
{
Ok(websocket_stream) => websocket_stream,
Err(e) => {
return Err(GatewayError::CannotConnectError {
return Err(GatewayError::CannotConnect {
error: e.to_string(),
})
}
@ -313,12 +310,12 @@ impl Gateway {
serde_json::from_str(msg.to_text().unwrap()).unwrap();
if gateway_payload.op_code != GATEWAY_HELLO {
return Err(GatewayError::NonHelloOnInitiateError {
return Err(GatewayError::NonHelloOnInitiate {
opcode: gateway_payload.op_code,
});
}
println!("GW: Received Hello");
info!("GW: Received Hello");
let gateway_hello: types::HelloData =
serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap();
@ -367,7 +364,7 @@ impl Gateway {
}
// We couldn't receive the next message or it was an error, something is wrong with the websocket, close
println!("GW: Websocket is broken, stopping gateway");
warn!("GW: Websocket is broken, stopping gateway");
break;
}
}
@ -401,7 +398,7 @@ impl Gateway {
}
if !msg.is_error() && !msg.is_payload() {
println!(
warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github",
msg.message.to_string()
);
@ -410,12 +407,10 @@ impl Gateway {
// To:do: handle errors in a good way, maybe observers like events?
if msg.is_error() {
println!("GW: Received error, connection will close..");
warn!("GW: Received error, connection will close..");
let _error = msg.error();
{}
self.close().await;
return;
}
@ -428,7 +423,7 @@ impl Gateway {
GATEWAY_DISPATCH => {
let gateway_payload_t = gateway_payload.clone().event_name.unwrap();
println!("GW: Received {}..", gateway_payload_t);
trace!("GW: Received {}..", gateway_payload_t);
//println!("Event data dump: {}", gateway_payload.d.clone().unwrap().get());
@ -443,7 +438,7 @@ impl Gateway {
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -459,7 +454,7 @@ impl Gateway {
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -481,7 +476,7 @@ impl Gateway {
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -495,7 +490,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -509,7 +504,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -523,7 +518,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -537,7 +532,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -551,7 +546,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -565,7 +560,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -579,7 +574,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -593,7 +588,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -607,7 +602,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -621,7 +616,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -635,7 +630,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -649,7 +644,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -663,7 +658,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -677,7 +672,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -691,7 +686,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -705,7 +700,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -719,7 +714,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -733,7 +728,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -747,7 +742,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -761,7 +756,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -775,7 +770,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -789,7 +784,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -803,7 +798,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -817,7 +812,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -831,7 +826,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -845,7 +840,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -859,7 +854,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -873,7 +868,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -887,7 +882,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -901,7 +896,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -915,7 +910,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -929,7 +924,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -943,7 +938,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -957,7 +952,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -971,7 +966,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -985,7 +980,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -999,7 +994,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1014,7 +1009,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1033,7 +1028,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1047,7 +1042,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1061,7 +1056,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1075,7 +1070,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1089,7 +1084,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1103,7 +1098,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1117,7 +1112,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1131,7 +1126,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1145,7 +1140,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1159,7 +1154,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1173,7 +1168,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1187,7 +1182,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1201,7 +1196,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1215,7 +1210,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1229,7 +1224,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1243,7 +1238,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1257,7 +1252,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1271,7 +1266,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1285,7 +1280,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1299,7 +1294,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1313,7 +1308,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1327,7 +1322,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1341,7 +1336,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1353,7 +1348,7 @@ impl Gateway {
let result: Result<Vec<types::Session>, serde_json::Error> =
serde_json::from_str(gateway_payload.event_data.unwrap().get());
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1373,7 +1368,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1387,7 +1382,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1401,7 +1396,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1415,7 +1410,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1429,7 +1424,7 @@ impl Gateway {
Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event)
.await;
if result.is_err() {
println!(
warn!(
"Failed to parse gateway event {} ({})",
gateway_payload_t,
result.err().unwrap()
@ -1438,14 +1433,14 @@ impl Gateway {
}
}
_ => {
println!("Received unrecognized gateway event ({})! Please open an issue on the chorus github so we can implement it", &gateway_payload_t);
warn!("Received unrecognized gateway event ({})! Please open an issue on the chorus github so we can implement it", &gateway_payload_t);
}
}
}
// We received a heartbeat from the server
// "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately."
GATEWAY_HEARTBEAT => {
println!("GW: Received Heartbeat // Heartbeat Request");
trace!("GW: Received Heartbeat // Heartbeat Request");
// Tell the heartbeat handler it should send a heartbeat right away
@ -1469,10 +1464,10 @@ impl Gateway {
// Starts our heartbeat
// We should have already handled this in gateway init
GATEWAY_HELLO => {
panic!("Received hello when it was unexpected");
warn!("Received hello when it was unexpected");
}
GATEWAY_HEARTBEAT_ACK => {
println!("GW: Received Heartbeat ACK");
trace!("GW: Received Heartbeat ACK");
// Tell the heartbeat handler we received an ack
@ -1494,13 +1489,13 @@ impl Gateway {
| GATEWAY_REQUEST_GUILD_MEMBERS
| GATEWAY_CALL_SYNC
| GATEWAY_LAZY_REQUEST => {
let error = GatewayError::UnexpectedOpcodeReceivedError {
let error = GatewayError::UnexpectedOpcodeReceived {
opcode: gateway_payload.op_code,
};
Err::<(), GatewayError>(error).unwrap();
}
_ => {
println!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
}
}
@ -1522,6 +1517,7 @@ impl Gateway {
}
/// Handles sending heartbeats to the gateway in another thread
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
struct HeartbeatHandler {
/// The heartbeat interval in milliseconds
pub heartbeat_interval: u128,
@ -1588,6 +1584,7 @@ impl HeartbeatHandler {
loop {
let should_shutdown = kill_receive.try_recv().is_ok();
if should_shutdown {
trace!("GW: Closing heartbeat task");
break;
}
@ -1627,11 +1624,11 @@ impl HeartbeatHandler {
&& last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT
{
should_send = true;
println!("GW: Timed out waiting for a heartbeat ack, resending");
info!("GW: Timed out waiting for a heartbeat ack, resending");
}
if should_send {
println!("GW: Sending Heartbeat..");
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
@ -1645,7 +1642,7 @@ impl HeartbeatHandler {
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
println!("GW: Couldnt send heartbeat, websocket seems broken");
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
@ -1879,7 +1876,7 @@ mod example {
#[derive(Debug)]
struct Consumer {
name: String,
_name: String,
events_received: AtomicI32,
}
@ -1900,13 +1897,13 @@ mod example {
};
let consumer = Arc::new(Consumer {
name: "first".into(),
_name: "first".into(),
events_received: 0.into(),
});
event.subscribe(consumer.clone());
let second_consumer = Arc::new(Consumer {
name: "second".into(),
_name: "second".into(),
events_received: 0.into(),
});
event.subscribe(second_consumer.clone());

View File

@ -1,26 +1,36 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt;
use std::rc::Rc;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::api::limits::Limits;
use crate::errors::{ChorusLibError, ChorusResult, FieldFormatError};
use crate::api::{Limit, LimitType};
use crate::errors::ChorusResult;
use crate::ratelimiter::ChorusRequest;
use crate::types::types::subconfigs::limits::rates::RateLimits;
use crate::types::{GeneralConfiguration, User, UserSettings};
use crate::UrlBundle;
#[derive(Debug, Clone)]
/**
The [`Instance`] what you will be using to perform all sorts of actions on the Spacebar server.
If `limits_information` is `None`, then the instance will not be rate limited.
*/
pub struct Instance {
pub urls: UrlBundle,
pub instance_info: GeneralConfiguration,
pub limits: Limits,
pub limits_information: Option<LimitsInformation>,
pub client: Client,
}
#[derive(Debug, Clone)]
pub struct LimitsInformation {
pub ratelimits: HashMap<LimitType, Limit>,
pub configuration: RateLimits,
}
impl Instance {
/// Creates a new [`Instance`].
/// # Arguments
@ -28,24 +38,43 @@ impl Instance {
/// * `requester` - The [`LimitedRequester`] that will be used to make requests to the Spacebar server.
/// # Errors
/// * [`InstanceError`] - If the instance cannot be created.
pub async fn new(urls: UrlBundle) -> ChorusResult<Instance> {
pub async fn new(urls: UrlBundle, limited: bool) -> ChorusResult<Instance> {
let limits_information;
if limited {
let limits_configuration =
Some(ChorusRequest::get_limits_config(&urls.api).await?.rate);
let limits = Some(ChorusRequest::limits_config_to_hashmap(
limits_configuration.as_ref().unwrap(),
));
limits_information = Some(LimitsInformation {
ratelimits: limits.unwrap(),
configuration: limits_configuration.unwrap(),
});
} else {
limits_information = None;
}
let mut instance = Instance {
urls: urls.clone(),
// Will be overwritten in the next step
instance_info: GeneralConfiguration::default(),
limits: Limits::check_limits(urls.api).await,
limits_information,
client: Client::new(),
};
instance.instance_info = match instance.general_configuration_schema().await {
Ok(schema) => schema,
Err(e) => {
return Err(ChorusLibError::CantGetInfoError {
error: e.to_string(),
});
log::warn!("Could not get instance configuration schema: {}", e);
GeneralConfiguration::default()
}
};
Ok(instance)
}
pub(crate) fn clone_limits_if_some(&self) -> Option<HashMap<LimitType, Limit>> {
if self.limits_information.is_some() {
return Some(self.limits_information.as_ref().unwrap().ratelimits.clone());
}
None
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -59,30 +88,11 @@ impl fmt::Display for Token {
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct Username {
pub username: String,
}
impl Username {
/// Creates a new [`Username`].
/// # Arguments
/// * `username` - The username that will be used to create the [`Username`].
/// # Errors
/// * [`UsernameFormatError`] - If the username is not between 2 and 32 characters.
pub fn new(username: String) -> Result<Username, FieldFormatError> {
if username.len() < 2 || username.len() > 32 {
return Err(FieldFormatError::UsernameError);
}
Ok(Username { username })
}
}
#[derive(Debug)]
pub struct UserMeta {
pub belongs_to: Rc<RefCell<Instance>>,
pub token: String,
pub limits: Limits,
pub limits: Option<HashMap<LimitType, Limit>>,
pub settings: UserSettings,
pub object: User,
}
@ -99,7 +109,7 @@ impl UserMeta {
pub fn new(
belongs_to: Rc<RefCell<Instance>>,
token: String,
limits: Limits,
limits: Option<HashMap<LimitType, Limit>>,
settings: UserSettings,
object: User,
) -> UserMeta {
@ -111,4 +121,24 @@ impl UserMeta {
object,
}
}
/// Creates a new 'shell' of a user. The user does not exist as an object, and exists so that you have
/// a UserMeta object to make Rate Limited requests with. This is useful in scenarios like
/// registering or logging in to the Instance, where you do not yet have a User object, but still
/// need to make a RateLimited request.
pub(crate) fn shell(instance: Rc<RefCell<Instance>>, token: String) -> UserMeta {
let settings = UserSettings::default();
let object = User::default();
UserMeta {
belongs_to: instance.clone(),
token,
limits: instance
.borrow()
.limits_information
.as_ref()
.map(|info| info.ratelimits.clone()),
settings,
object,
}
}
}

View File

@ -10,7 +10,7 @@ pub mod gateway;
#[cfg(feature = "client")]
pub mod instance;
#[cfg(feature = "client")]
pub mod limit;
pub mod ratelimiter;
pub mod types;
#[cfg(feature = "client")]
pub mod voice;

View File

@ -1,304 +0,0 @@
use reqwest::{RequestBuilder, Response};
use crate::{
api::limits::{Limit, LimitType, Limits, LimitsMutRef},
errors::{ChorusLibError, ChorusResult},
instance::Instance,
};
#[derive(Debug)]
pub struct LimitedRequester;
impl LimitedRequester {
/// Checks if a request can be sent without hitting API rate limits and sends it, if true.
/// Will automatically update the rate limits of the LimitedRequester the request has been
/// sent with.
///
/// # Arguments
///
/// * `request`: A `RequestBuilder` that contains a request ready to be sent. Unfinished or
/// invalid requests will result in the method panicing.
/// * `limit_type`: Because this library does not yet implement a way to check for which rate
/// limit will be used when the request gets send, you will have to specify this manually using
/// a `LimitType` enum.
///
/// # Returns
///
/// * `Response`: The `Response` gotten from sending the request to the server. This will be
/// returned if the Request was built and send successfully. Is wrapped in an `Option`.
/// * `None`: `None` will be returned if the rate limit has been hit, and the request could
/// therefore not have been sent.
///
/// # Errors
///
/// This method will error if:
///
/// * The request does not return a success status code (200-299)
/// * The supplied `RequestBuilder` contains invalid or incomplete information
/// * There has been an error with processing (unwrapping) the `Response`
/// * The call to `update_limits` yielded errors. Read the methods' Errors section for more
/// information.
pub async fn send_request(
request: RequestBuilder,
limit_type: LimitType,
instance: &mut Instance,
user_rate_limits: &mut Limits,
) -> ChorusResult<Response> {
if LimitedRequester::can_send_request(limit_type, &instance.limits, user_rate_limits) {
let built_request = match request.build() {
Ok(request) => request,
Err(e) => {
return Err(ChorusLibError::RequestErrorError {
url: "".to_string(),
error: e.to_string(),
});
}
};
let result = instance.client.execute(built_request).await;
let response = match result {
Ok(is_response) => is_response,
Err(e) => {
return Err(ChorusLibError::ReceivedErrorCodeError {
error_code: e.to_string(),
});
}
};
LimitedRequester::update_limits(
&response,
limit_type,
&mut instance.limits,
user_rate_limits,
);
if !response.status().is_success() {
match response.status().as_u16() {
401 => Err(ChorusLibError::TokenExpired),
403 => Err(ChorusLibError::TokenExpired),
_ => Err(ChorusLibError::ReceivedErrorCodeError {
error_code: response.status().as_str().to_string(),
}),
}
} else {
Ok(response)
}
} else {
Err(ChorusLibError::RateLimited {
bucket: limit_type.to_string(),
})
}
}
fn update_limit_entry(entry: &mut Limit, reset: u64, remaining: u64, limit: u64) {
if reset != entry.reset {
entry.reset = reset;
entry.remaining = limit;
entry.limit = limit;
} else {
entry.remaining = remaining;
entry.limit = limit;
}
}
fn can_send_request(
limit_type: LimitType,
instance_rate_limits: &Limits,
user_rate_limits: &Limits,
) -> bool {
// Check if all of the limits in this vec have at least one remaining request
let rate_limits = Limits::combine(instance_rate_limits, user_rate_limits);
let constant_limits: Vec<&LimitType> = [
&LimitType::Error,
&LimitType::Global,
&LimitType::Ip,
&limit_type,
]
.to_vec();
for limit in constant_limits.iter() {
match rate_limits.to_hash_map().get(limit) {
Some(limit) => {
if limit.remaining == 0 {
return false;
}
// AbsoluteRegister and AuthRegister can cancel each other out.
if limit.bucket == LimitType::AbsoluteRegister
&& rate_limits
.to_hash_map()
.get(&LimitType::AuthRegister)
.unwrap()
.remaining
== 0
{
return false;
}
if limit.bucket == LimitType::AuthRegister
&& rate_limits
.to_hash_map()
.get(&LimitType::AbsoluteRegister)
.unwrap()
.remaining
== 0
{
return false;
}
}
None => return false,
}
}
true
}
fn update_limits(
response: &Response,
limit_type: LimitType,
instance_rate_limits: &mut Limits,
user_rate_limits: &mut Limits,
) {
let mut rate_limits = LimitsMutRef::combine_mut_ref(instance_rate_limits, user_rate_limits);
let remaining = match response.headers().get("X-RateLimit-Remaining") {
Some(remaining) => remaining.to_str().unwrap().parse::<u64>().unwrap(),
None => rate_limits.get_limit_mut_ref(&limit_type).remaining - 1,
};
let limit = match response.headers().get("X-RateLimit-Limit") {
Some(limit) => limit.to_str().unwrap().parse::<u64>().unwrap(),
None => rate_limits.get_limit_mut_ref(&limit_type).limit,
};
let reset = match response.headers().get("X-RateLimit-Reset") {
Some(reset) => reset.to_str().unwrap().parse::<u64>().unwrap(),
None => rate_limits.get_limit_mut_ref(&limit_type).reset,
};
let status = response.status();
let status_str = status.as_str();
if status_str.starts_with('4') {
rate_limits
.get_limit_mut_ref(&LimitType::Error)
.add_remaining(-1);
}
rate_limits
.get_limit_mut_ref(&LimitType::Global)
.add_remaining(-1);
rate_limits
.get_limit_mut_ref(&LimitType::Ip)
.add_remaining(-1);
match limit_type {
LimitType::Error => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Error);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::Global => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Global);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::Ip => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Ip);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::AuthLogin => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthLogin);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::AbsoluteRegister => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteRegister);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens.
rate_limits
.get_limit_mut_ref(&LimitType::AuthRegister)
.remaining -= 1;
}
LimitType::AuthRegister => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::AuthRegister);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
// AbsoluteRegister and AuthRegister both need to be updated, if a Register event
// happens.
rate_limits
.get_limit_mut_ref(&LimitType::AbsoluteRegister)
.remaining -= 1;
}
LimitType::AbsoluteMessage => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::AbsoluteMessage);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::Channel => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Channel);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::Guild => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Guild);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
LimitType::Webhook => {
let entry = rate_limits.get_limit_mut_ref(&LimitType::Webhook);
LimitedRequester::update_limit_entry(entry, reset, remaining, limit);
}
}
}
}
#[cfg(test)]
mod rate_limit {
use serde_json::from_str;
use crate::{api::limits::Config, UrlBundle};
use super::*;
#[tokio::test]
async fn run_into_limit() {
let urls = UrlBundle::new(
String::from("http://localhost:3001/api/"),
String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"),
);
let mut request: Option<ChorusResult<Response>> = None;
let mut instance = Instance::new(urls.clone()).await.unwrap();
let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await;
for _ in 0..=50 {
let request_path = urls.api.clone() + "/some/random/nonexisting/path";
let request_builder = instance.client.get(request_path);
request = Some(
LimitedRequester::send_request(
request_builder,
LimitType::Channel,
&mut instance,
&mut user_rate_limits,
)
.await,
);
}
assert!(matches!(request, Some(Err(_))));
}
#[tokio::test]
async fn test_send_request() {
let urls = UrlBundle::new(
String::from("http://localhost:3001/api/"),
String::from("wss://localhost:3001/"),
String::from("http://localhost:3001/cdn"),
);
let mut instance = Instance::new(urls.clone()).await.unwrap();
let mut user_rate_limits = Limits::check_limits(urls.api.clone()).await;
let _requester = LimitedRequester;
let request_path = urls.api.clone() + "/policies/instance/limits";
let request_builder = instance.client.get(request_path);
let request = LimitedRequester::send_request(
request_builder,
LimitType::Channel,
&mut instance,
&mut user_rate_limits,
)
.await;
let result = match request {
Ok(result) => result,
Err(_) => panic!("Request failed"),
};
let _config: Config = from_str(result.text().await.unwrap().as_str()).unwrap();
}
}

462
src/ratelimiter.rs Normal file
View File

@ -0,0 +1,462 @@
use std::collections::HashMap;
use log;
use reqwest::{Client, RequestBuilder, Response};
use serde::Deserialize;
use serde_json::from_str;
use crate::{
api::{Limit, LimitType},
errors::{ChorusError, ChorusResult},
instance::UserMeta,
types::{types::subconfigs::limits::rates::RateLimits, LimitsConfiguration},
};
/// Chorus' request struct. This struct is used to send rate-limited requests to the Spacebar server.
/// See <https://discord.com/developers/docs/topics/rate-limits#rate-limits> for more information.
pub struct ChorusRequest {
pub request: RequestBuilder,
pub limit_type: LimitType,
}
impl ChorusRequest {
/// Sends a [`ChorusRequest`]. Checks if the user is rate limited, and if not, sends the request.
/// If the user is not rate limited and the instance has rate limits enabled, it will update the
/// rate limits.
#[allow(clippy::await_holding_refcell_ref)]
pub(crate) async fn send_request(self, user: &mut UserMeta) -> ChorusResult<Response> {
if !ChorusRequest::can_send_request(user, &self.limit_type) {
log::info!("Rate limit hit. Bucket: {:?}", self.limit_type);
return Err(ChorusError::RateLimited {
bucket: format!("{:?}", self.limit_type),
});
}
let belongs_to = user.belongs_to.borrow();
let result = match belongs_to
.client
.execute(self.request.build().unwrap())
.await
{
Ok(result) => result,
Err(error) => {
log::warn!("Request failed: {:?}", error);
return Err(ChorusError::RequestFailed {
url: error.url().unwrap().to_string(),
error,
});
}
};
drop(belongs_to);
if !result.status().is_success() {
if result.status().as_u16() == 429 {
log::warn!("Rate limit hit unexpectedly. Bucket: {:?}. Setting the instances' remaining global limit to 0 to have cooldown.", self.limit_type);
user.belongs_to
.borrow_mut()
.limits_information
.as_mut()
.unwrap()
.ratelimits
.get_mut(&LimitType::Global)
.unwrap()
.remaining = 0;
return Err(ChorusError::RateLimited {
bucket: format!("{:?}", self.limit_type),
});
}
log::warn!("Request failed: {:?}", result);
return Err(ChorusRequest::interpret_error(result).await);
}
ChorusRequest::update_rate_limits(user, &self.limit_type, !result.status().is_success());
Ok(result)
}
fn can_send_request(user: &mut UserMeta, limit_type: &LimitType) -> bool {
log::trace!("Checking if user or instance is rate-limited...");
let mut belongs_to = user.belongs_to.borrow_mut();
if belongs_to.limits_information.is_none() {
log::trace!("Instance indicates no rate limits are configured. Continuing.");
return true;
}
let instance_dictated_limits = [
&LimitType::AuthLogin,
&LimitType::AuthRegister,
&LimitType::Global,
&LimitType::Ip,
];
let limits = match instance_dictated_limits.contains(&limit_type) {
true => {
log::trace!(
"Limit type {:?} is dictated by the instance. Continuing.",
limit_type
);
belongs_to
.limits_information
.as_mut()
.unwrap()
.ratelimits
.clone()
}
false => {
log::trace!(
"Limit type {:?} is dictated by the user. Continuing.",
limit_type
);
ChorusRequest::ensure_limit_in_map(
&belongs_to
.limits_information
.as_ref()
.unwrap()
.configuration,
user.limits.as_mut().unwrap(),
limit_type,
);
user.limits.as_mut().unwrap().clone()
}
};
let global = belongs_to
.limits_information
.as_ref()
.unwrap()
.ratelimits
.get(&LimitType::Global)
.unwrap();
let ip = belongs_to
.limits_information
.as_ref()
.unwrap()
.ratelimits
.get(&LimitType::Ip)
.unwrap();
let limit_type_limit = limits.get(limit_type).unwrap();
global.remaining > 0 && ip.remaining > 0 && limit_type_limit.remaining > 0
}
fn ensure_limit_in_map(
rate_limits_config: &RateLimits,
map: &mut HashMap<LimitType, Limit>,
limit_type: &LimitType,
) {
log::trace!("Ensuring limit type {:?} is in the map.", limit_type);
let time: u64 = chrono::Utc::now().timestamp() as u64;
match limit_type {
LimitType::Channel(snowflake) => {
if map.get(&LimitType::Channel(*snowflake)).is_some() {
log::trace!(
"Limit type {:?} is already in the map. Returning.",
limit_type
);
return;
}
log::trace!("Limit type {:?} is not in the map. Adding it.", limit_type);
let channel_limit = &rate_limits_config.routes.channel;
map.insert(
LimitType::Channel(*snowflake),
Limit {
bucket: LimitType::Channel(*snowflake),
limit: channel_limit.count,
remaining: channel_limit.count,
reset: channel_limit.window + time,
window: channel_limit.window,
},
);
}
LimitType::Guild(snowflake) => {
if map.get(&LimitType::Guild(*snowflake)).is_some() {
return;
}
let guild_limit = &rate_limits_config.routes.guild;
map.insert(
LimitType::Guild(*snowflake),
Limit {
bucket: LimitType::Guild(*snowflake),
limit: guild_limit.count,
remaining: guild_limit.count,
reset: guild_limit.window + time,
window: guild_limit.window,
},
);
}
LimitType::Webhook(snowflake) => {
if map.get(&LimitType::Webhook(*snowflake)).is_some() {
return;
}
let webhook_limit = &rate_limits_config.routes.webhook;
map.insert(
LimitType::Webhook(*snowflake),
Limit {
bucket: LimitType::Webhook(*snowflake),
limit: webhook_limit.count,
remaining: webhook_limit.count,
reset: webhook_limit.window + time,
window: webhook_limit.window,
},
);
}
other_limit => {
if map.get(other_limit).is_some() {
return;
}
let limits_map = ChorusRequest::limits_config_to_hashmap(rate_limits_config);
map.insert(
*other_limit,
Limit {
bucket: *other_limit,
limit: limits_map.get(other_limit).as_ref().unwrap().limit,
remaining: limits_map.get(other_limit).as_ref().unwrap().remaining,
reset: limits_map.get(other_limit).as_ref().unwrap().reset,
window: limits_map.get(other_limit).as_ref().unwrap().window,
},
);
}
}
}
async fn interpret_error(response: reqwest::Response) -> ChorusError {
match response.status().as_u16() {
401..=403 | 407 => ChorusError::NoPermission,
404 => ChorusError::NotFound {
error: response.text().await.unwrap(),
},
405 | 408 | 409 => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap() },
411..=421 | 426 | 428 | 431 => ChorusError::InvalidArguments {
error: response.text().await.unwrap(),
},
429 => panic!("Illegal state: Rate limit exception should have been caught before this function call."),
451 => ChorusError::NoResponse,
500..=599 => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap() },
_ => ChorusError::ReceivedErrorCode { error_code: response.status().as_u16(), error: response.text().await.unwrap()},
}
}
/// Updates the rate limits of the user. The following steps are performed:
/// 1. If the current unix timestamp is greater than the reset timestamp, the reset timestamp is
/// set to the current unix timestamp + the rate limit window. The remaining rate limit is
/// reset to the rate limit limit.
/// 2. The remaining rate limit is decreased by 1.
fn update_rate_limits(user: &mut UserMeta, limit_type: &LimitType, response_was_err: bool) {
let instance_dictated_limits = [
&LimitType::AuthLogin,
&LimitType::AuthRegister,
&LimitType::Global,
&LimitType::Ip,
];
// modify this to store something to look up the value with later, instead of storing a reference to the actual data itself.
let mut relevant_limits = Vec::new();
if instance_dictated_limits.contains(&limit_type) {
relevant_limits.push((LimitOrigin::Instance, *limit_type));
} else {
relevant_limits.push((LimitOrigin::User, *limit_type));
}
relevant_limits.push((LimitOrigin::Instance, LimitType::Global));
relevant_limits.push((LimitOrigin::Instance, LimitType::Ip));
if response_was_err {
relevant_limits.push((LimitOrigin::User, LimitType::Error));
}
let time: u64 = chrono::Utc::now().timestamp() as u64;
for relevant_limit in relevant_limits.iter() {
let mut belongs_to = user.belongs_to.borrow_mut();
let limit = match relevant_limit.0 {
LimitOrigin::Instance => {
log::trace!(
"Updating instance rate limit. Bucket: {:?}",
relevant_limit.1
);
belongs_to
.limits_information
.as_mut()
.unwrap()
.ratelimits
.get_mut(&relevant_limit.1)
.unwrap()
}
LimitOrigin::User => {
log::trace!("Updating user rate limit. Bucket: {:?}", relevant_limit.1);
user.limits
.as_mut()
.unwrap()
.get_mut(&relevant_limit.1)
.unwrap()
}
};
if time > limit.reset {
// Spacebar does not yet return rate limit information in its response headers. We
// therefore have to guess the next rate limit window. This is not ideal. Oh well!
log::trace!("Rate limit replenished. Bucket: {:?}", limit.bucket);
limit.reset += limit.window;
limit.remaining = limit.limit;
}
limit.remaining -= 1;
}
}
pub(crate) async fn get_limits_config(url_api: &str) -> ChorusResult<LimitsConfiguration> {
let request = Client::new()
.get(format!("{}/policies/instance/limits/", url_api))
.send()
.await;
let request = match request {
Ok(request) => request,
Err(e) => {
return Err(ChorusError::RequestFailed {
url: url_api.to_string(),
error: e,
})
}
};
let limits_configuration = match request.status().as_u16() {
200 => from_str::<LimitsConfiguration>(&request.text().await.unwrap()).unwrap(),
429 => {
return Err(ChorusError::RateLimited {
bucket: format!("{:?}", LimitType::Ip),
})
}
404 => return Err(ChorusError::NotFound { error: "Route \"/policies/instance/limits/\" not found. Are you perhaps trying to request the Limits configuration from an unsupported server?".to_string() }),
400..=u16::MAX => {
return Err(ChorusError::ReceivedErrorCode { error_code: request.status().as_u16(), error: request.text().await.unwrap() })
}
_ => {
return Err(ChorusError::InvalidResponse {
error: request.text().await.unwrap(),
})
}
};
Ok(limits_configuration)
}
pub(crate) fn limits_config_to_hashmap(
limits_configuration: &RateLimits,
) -> HashMap<LimitType, Limit> {
let config = limits_configuration.clone();
let routes = config.routes;
let mut map: HashMap<LimitType, Limit> = HashMap::new();
let time: u64 = chrono::Utc::now().timestamp() as u64;
map.insert(
LimitType::AuthLogin,
Limit {
bucket: LimitType::AuthLogin,
limit: routes.auth.login.count,
remaining: routes.auth.login.count,
reset: routes.auth.login.window + time,
window: routes.auth.login.window,
},
);
map.insert(
LimitType::AuthRegister,
Limit {
bucket: LimitType::AuthRegister,
limit: routes.auth.register.count,
remaining: routes.auth.register.count,
reset: routes.auth.register.window + time,
window: routes.auth.register.window,
},
);
map.insert(
LimitType::ChannelBaseline,
Limit {
bucket: LimitType::ChannelBaseline,
limit: routes.channel.count,
remaining: routes.channel.count,
reset: routes.channel.window + time,
window: routes.channel.window,
},
);
map.insert(
LimitType::Error,
Limit {
bucket: LimitType::Error,
limit: config.error.count,
remaining: config.error.count,
reset: config.error.window + time,
window: config.error.window,
},
);
map.insert(
LimitType::Global,
Limit {
bucket: LimitType::Global,
limit: config.global.count,
remaining: config.global.count,
reset: config.global.window + time,
window: config.global.window,
},
);
map.insert(
LimitType::Ip,
Limit {
bucket: LimitType::Ip,
limit: config.ip.count,
remaining: config.ip.count,
reset: config.ip.window + time,
window: config.ip.window,
},
);
map.insert(
LimitType::GuildBaseline,
Limit {
bucket: LimitType::GuildBaseline,
limit: routes.guild.count,
remaining: routes.guild.count,
reset: routes.guild.window + time,
window: routes.guild.window,
},
);
map.insert(
LimitType::WebhookBaseline,
Limit {
bucket: LimitType::WebhookBaseline,
limit: routes.webhook.count,
remaining: routes.webhook.count,
reset: routes.webhook.window + time,
window: routes.webhook.window,
},
);
map
}
/// Sends a [`ChorusRequest`] and returns a [`ChorusResult`] that contains nothing if the request
/// was successful, or a [`ChorusError`] if the request failed.
pub(crate) async fn handle_request_as_result(self, user: &mut UserMeta) -> ChorusResult<()> {
match self.send_request(user).await {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
}
/// Sends a [`ChorusRequest`] and returns a [`ChorusResult`] that contains a [`T`] if the request
/// was successful, or a [`ChorusError`] if the request failed.
pub(crate) async fn deserialize_response<T: for<'a> Deserialize<'a>>(
self,
user: &mut UserMeta,
) -> ChorusResult<T> {
let response = self.send_request(user).await?;
let response_text = match response.text().await {
Ok(string) => string,
Err(e) => {
return Err(ChorusError::InvalidResponse {
error: format!(
"Error while trying to process the HTTP response into a String: {}",
e
),
});
}
};
let object = match from_str::<T>(&response_text) {
Ok(object) => object,
Err(e) => {
return Err(ChorusError::InvalidResponse {
error: format!(
"Error while trying to deserialize the JSON response into T: {}",
e
),
})
}
};
Ok(object)
}
}
enum LimitOrigin {
Instance,
User,
}

View File

@ -18,10 +18,8 @@ pub struct GeneralConfiguration {
impl Default for GeneralConfiguration {
fn default() -> Self {
Self {
instance_name: String::from("Spacebar Instance"),
instance_description: Some(String::from(
"This is a Spacebar instance made in the pre-release days",
)),
instance_name: String::from("Spacebar-compatible Instance"),
instance_description: Some(String::from("This is a spacebar-compatible instance.")),
front_page: None,
tos_page: None,
correspondence_email: None,

View File

@ -1,7 +1,12 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::types::config::types::subconfigs::limits::ratelimits::{
route::RouteRateLimit, RateLimitOptions,
use crate::{
api::LimitType,
types::config::types::subconfigs::limits::ratelimits::{
route::RouteRateLimit, RateLimitOptions,
},
};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -39,3 +44,18 @@ impl Default for RateLimits {
}
}
}
impl RateLimits {
pub fn to_hash_map(&self) -> HashMap<LimitType, RateLimitOptions> {
let mut map = HashMap::new();
map.insert(LimitType::AuthLogin, self.routes.auth.login.clone());
map.insert(LimitType::AuthRegister, self.routes.auth.register.clone());
map.insert(LimitType::ChannelBaseline, self.routes.channel.clone());
map.insert(LimitType::Error, self.error.clone());
map.insert(LimitType::Global, self.global.clone());
map.insert(LimitType::Ip, self.ip.clone());
map.insert(LimitType::WebhookBaseline, self.routes.webhook.clone());
map.insert(LimitType::GuildBaseline, self.routes.guild.clone());
map
}
}

View File

@ -1,9 +1,8 @@
use crate::types::utils::Snowflake;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_aux::prelude::deserialize_option_number_from_string;
use crate::types::utils::Snowflake;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "sqlx", derive(sqlx::Type))]
pub struct UserData {
@ -16,7 +15,6 @@ impl User {
PublicUser::from(self)
}
}
#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))]
pub struct User {
@ -89,6 +87,7 @@ impl From<User> for PublicUser {
}
}
#[allow(dead_code)] // FIXME: Remove this when we actually use this
const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32;
bitflags::bitflags! {

View File

@ -1,122 +1,8 @@
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::errors::FieldFormatError;
/**
A struct that represents a well-formed email address.
*/
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthEmail {
pub email: String,
}
impl AuthEmail {
/**
Returns a new [`Result<AuthEmail, FieldFormatError>`].
## Arguments
The email address you want to validate.
## Errors
You will receive a [`FieldFormatError`], if:
- The email address is not in a valid format.
*/
pub fn new(email: String) -> Result<AuthEmail, FieldFormatError> {
let regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
if !regex.is_match(email.as_str()) {
return Err(FieldFormatError::EmailError);
}
Ok(AuthEmail { email })
}
}
/**
A struct that represents a well-formed username.
## Arguments
Please use new() to create a new instance of this struct.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters.
*/
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthUsername {
pub username: String,
}
impl AuthUsername {
/**
Returns a new [`Result<AuthUsername, FieldFormatError>`].
## Arguments
The username you want to validate.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters.
*/
pub fn new(username: String) -> Result<AuthUsername, FieldFormatError> {
if username.len() < 2 || username.len() > 32 {
Err(FieldFormatError::UsernameError)
} else {
Ok(AuthUsername { username })
}
}
}
/**
A struct that represents a well-formed password.
## Arguments
Please use new() to create a new instance of this struct.
## Errors
You will receive a [`FieldFormatError`], if:
- The password is not between 1 and 72 characters.
*/
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct AuthPassword {
pub password: String,
}
impl AuthPassword {
/**
Returns a new [`Result<AuthPassword, FieldFormatError>`].
## Arguments
The password you want to validate.
## Errors
You will receive a [`FieldFormatError`], if:
- The password is not between 1 and 72 characters.
*/
pub fn new(password: String) -> Result<AuthPassword, FieldFormatError> {
if password.is_empty() || password.len() > 72 {
Err(FieldFormatError::PasswordError)
} else {
Ok(AuthPassword { password })
}
}
}
/**
A struct that represents a well-formed register request.
## Arguments
Please use new() to create a new instance of this struct.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters.
- The password is not between 1 and 72 characters.
*/
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct RegisterSchema {
username: String,
password: Option<String>,
consent: bool,
email: Option<String>,
fingerprint: Option<String>,
invite: Option<String>,
date_of_birth: Option<String>,
gift_code_sku_id: Option<String>,
captcha_key: Option<String>,
promotional_email_opt_in: Option<bool>,
}
pub struct RegisterSchemaOptions {
pub username: String,
pub password: Option<String>,
pub consent: bool,
@ -129,123 +15,21 @@ pub struct RegisterSchemaOptions {
pub promotional_email_opt_in: Option<bool>,
}
impl RegisterSchema {
pub fn builder(username: impl Into<String>, consent: bool) -> RegisterSchemaOptions {
RegisterSchemaOptions {
username: username.into(),
password: None,
consent,
email: None,
fingerprint: None,
invite: None,
date_of_birth: None,
gift_code_sku_id: None,
captcha_key: None,
promotional_email_opt_in: None,
}
}
}
impl RegisterSchemaOptions {
/**
Create a new [`RegisterSchema`].
## Arguments
All but "String::username" and "bool::consent" are optional.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is less than 2 or more than 32 characters in length
- You supply a `password` which is less than 1 or more than 72 characters in length.
These constraints have been defined [in the Spacebar-API](https://docs.spacebar.chat/routes/)
*/
pub fn build(self) -> Result<RegisterSchema, FieldFormatError> {
let username = AuthUsername::new(self.username)?.username;
let email = if let Some(email) = self.email {
Some(AuthEmail::new(email)?.email)
} else {
None
};
let password = if let Some(password) = self.password {
Some(AuthPassword::new(password)?.password)
} else {
None
};
if !self.consent {
return Err(FieldFormatError::ConsentError);
}
Ok(RegisterSchema {
username,
password,
consent: self.consent,
email,
fingerprint: self.fingerprint,
invite: self.invite,
date_of_birth: self.date_of_birth,
gift_code_sku_id: self.gift_code_sku_id,
captcha_key: self.captcha_key,
promotional_email_opt_in: self.promotional_email_opt_in,
})
}
}
/**
A struct that represents a well-formed login request.
## Arguments
Please use new() to create a new instance of this struct.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is not between 2 and 32 characters.
- The password is not between 1 and 72 characters.
*/
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct LoginSchema {
/// For Discord, usernames must be between 2 and 32 characters,
/// but other servers may have different limits.
pub login: String,
pub password: Option<String>,
/// For Discord, must be between 1 and 72 characters,
/// but other servers may have different limits.
pub password: String,
pub undelete: Option<bool>,
pub captcha_key: Option<String>,
pub login_source: Option<String>,
pub gift_code_sku_id: Option<String>,
}
impl LoginSchema {
/**
Returns a new [`Result<LoginSchema, FieldFormatError>`].
## Arguments
login: The username you want to login with.
password: The password you want to login with.
undelete: Honestly no idea what this is for.
captcha_key: The captcha key you want to login with.
login_source: The login source.
gift_code_sku_id: The gift code sku id.
## Errors
You will receive a [`FieldFormatError`], if:
- The username is less than 2 or more than 32 characters in length
*/
pub fn new(
login: String,
password: Option<String>,
undelete: Option<bool>,
captcha_key: Option<String>,
login_source: Option<String>,
gift_code_sku_id: Option<String>,
) -> Result<LoginSchema, FieldFormatError> {
Ok(LoginSchema {
login,
password,
undelete,
captcha_key,
login_source,
gift_code_sku_id,
})
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TotpSchema {

View File

@ -15,76 +15,3 @@ mod message;
mod relationship;
mod role;
mod user;
#[cfg(test)]
mod schemas_tests {
use crate::errors::FieldFormatError;
use super::*;
#[test]
fn password_too_short() {
assert_eq!(
AuthPassword::new("".to_string()),
Err(FieldFormatError::PasswordError)
);
}
#[test]
fn password_too_long() {
let mut long_pw = String::new();
for _ in 0..73 {
long_pw += "a";
}
assert_eq!(
AuthPassword::new(long_pw),
Err(FieldFormatError::PasswordError)
);
}
#[test]
fn username_too_short() {
assert_eq!(
AuthUsername::new("T".to_string()),
Err(FieldFormatError::UsernameError)
);
}
#[test]
fn username_too_long() {
let mut long_un = String::new();
for _ in 0..33 {
long_un += "a";
}
assert_eq!(
AuthUsername::new(long_un),
Err(FieldFormatError::UsernameError)
);
}
#[test]
fn consent_false() {
assert_eq!(
RegisterSchema::builder("Test", false).build(),
Err(FieldFormatError::ConsentError)
);
}
#[test]
fn invalid_email() {
assert_eq!(
AuthEmail::new("p@p.p".to_string()),
Err(FieldFormatError::EmailError)
)
}
#[test]
fn valid_email() {
let reg = RegisterSchemaOptions {
email: Some("me@mail.de".to_string()),
..RegisterSchema::builder("Testy", true)
}
.build();
assert_ne!(reg, Err(FieldFormatError::EmailError));
}
}

View File

@ -73,6 +73,7 @@ impl Rights {
}
}
#[allow(dead_code)] // FIXME: Remove this when we use this
fn all_rights() -> Rights {
Rights::OPERATOR
| Rights::MANAGE_APPLICATIONS

View File

@ -12,7 +12,7 @@ const EPOCH: i64 = 1420070400000;
/// Unique identifier including a timestamp.
/// See https://discord.com/developers/docs/reference#snowflakes
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "sqlx", derive(Type))]
#[cfg_attr(feature = "sqlx", sqlx(transparent))]
pub struct Snowflake(u64);

View File

@ -1,16 +1,16 @@
use chorus::types::{RegisterSchema, RegisterSchemaOptions};
use chorus::types::RegisterSchema;
mod common;
#[tokio::test]
async fn test_registration() {
let mut bundle = common::setup().await;
let reg = RegisterSchemaOptions {
let reg = RegisterSchema {
username: "Hiiii".into(),
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("Hiiii", true)
}
.build()
.unwrap();
consent: true,
..Default::default()
};
bundle.instance.register_account(&reg).await.unwrap();
common::teardown(bundle).await;
}

View File

@ -1,9 +1,8 @@
use chorus::{
errors::ChorusResult,
instance::{Instance, UserMeta},
types::{
Channel, ChannelCreateSchema, Guild, GuildCreateSchema, RegisterSchema,
RegisterSchemaOptions, RoleCreateModifySchema, RoleObject,
RoleCreateModifySchema, RoleObject,
},
UrlBundle,
};
@ -25,14 +24,14 @@ pub async fn setup() -> TestBundle {
"ws://localhost:3001".to_string(),
"http://localhost:3001".to_string(),
);
let mut instance = Instance::new(urls.clone()).await.unwrap();
let mut instance = Instance::new(urls.clone(), true).await.unwrap();
// Requires the existance of the below user.
let reg = RegisterSchemaOptions {
let reg = RegisterSchema {
username: "integrationtestuser".into(),
consent: true,
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("integrationtestuser", true)
}
.build()
.unwrap();
..Default::default()
};
let guild_create_schema = GuildCreateSchema {
name: Some("Test-Guild!".to_string()),
region: None,

View File

@ -1,15 +1,15 @@
use chorus::types::{self, RegisterSchema, RegisterSchemaOptions, Relationship, RelationshipType};
use chorus::types::{self, RegisterSchema, Relationship, RelationshipType};
mod common;
#[tokio::test]
async fn test_get_mutual_relationships() {
let register_schema = RegisterSchemaOptions {
let register_schema = RegisterSchema {
username: "integrationtestuser2".to_string(),
consent: true,
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("integrationtestuser2", true)
}
.build()
.unwrap();
..Default::default()
};
let mut bundle = common::setup().await;
let belongs_to = &mut bundle.instance;
@ -19,7 +19,7 @@ async fn test_get_mutual_relationships() {
username: user.object.username.clone(),
discriminator: Some(user.object.discriminator.clone()),
};
other_user.send_friend_request(friend_request_schema).await;
let _ = other_user.send_friend_request(friend_request_schema).await;
let relationships = user
.get_mutual_relationships(other_user.object.id)
.await
@ -30,12 +30,12 @@ async fn test_get_mutual_relationships() {
#[tokio::test]
async fn test_get_relationships() {
let register_schema = RegisterSchemaOptions {
let register_schema = RegisterSchema {
username: "integrationtestuser2".to_string(),
consent: true,
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("integrationtestuser2", true)
}
.build()
.unwrap();
..Default::default()
};
let mut bundle = common::setup().await;
let belongs_to = &mut bundle.instance;
@ -45,7 +45,10 @@ async fn test_get_relationships() {
username: user.object.username.clone(),
discriminator: Some(user.object.discriminator.clone()),
};
other_user.send_friend_request(friend_request_schema).await;
other_user
.send_friend_request(friend_request_schema)
.await
.unwrap();
let relationships = user.get_relationships().await.unwrap();
assert_eq!(relationships.get(0).unwrap().id, other_user.object.id);
common::teardown(bundle).await
@ -53,18 +56,18 @@ async fn test_get_relationships() {
#[tokio::test]
async fn test_modify_relationship_friends() {
let register_schema = RegisterSchemaOptions {
let register_schema = RegisterSchema {
username: "integrationtestuser2".to_string(),
consent: true,
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("integrationtestuser2", true)
}
.build()
.unwrap();
..Default::default()
};
let mut bundle = common::setup().await;
let belongs_to = &mut bundle.instance;
let user = &mut bundle.user;
let mut other_user = belongs_to.register_account(&register_schema).await.unwrap();
other_user
let _ = other_user
.modify_user_relationship(user.object.id, types::RelationshipType::Friends)
.await;
let relationships = user.get_relationships().await.unwrap();
@ -79,7 +82,8 @@ async fn test_modify_relationship_friends() {
relationships.get(0).unwrap().relationship_type,
RelationshipType::Outgoing
);
user.modify_user_relationship(other_user.object.id, RelationshipType::Friends)
let _ = user
.modify_user_relationship(other_user.object.id, RelationshipType::Friends)
.await;
assert_eq!(
other_user
@ -91,7 +95,7 @@ async fn test_modify_relationship_friends() {
.relationship_type,
RelationshipType::Friends
);
user.remove_relationship(other_user.object.id).await;
let _ = user.remove_relationship(other_user.object.id).await;
assert_eq!(
other_user.get_relationships().await.unwrap(),
Vec::<Relationship>::new()
@ -101,18 +105,18 @@ async fn test_modify_relationship_friends() {
#[tokio::test]
async fn test_modify_relationship_block() {
let register_schema = RegisterSchemaOptions {
let register_schema = RegisterSchema {
username: "integrationtestuser2".to_string(),
consent: true,
date_of_birth: Some("2000-01-01".to_string()),
..RegisterSchema::builder("integrationtestuser2", true)
}
.build()
.unwrap();
..Default::default()
};
let mut bundle = common::setup().await;
let belongs_to = &mut bundle.instance;
let user = &mut bundle.user;
let mut other_user = belongs_to.register_account(&register_schema).await.unwrap();
other_user
let _ = other_user
.modify_user_relationship(user.object.id, types::RelationshipType::Blocked)
.await;
let relationships = user.get_relationships().await.unwrap();
@ -123,7 +127,7 @@ async fn test_modify_relationship_block() {
relationships.get(0).unwrap().relationship_type,
RelationshipType::Blocked
);
other_user.remove_relationship(user.object.id).await;
let _ = other_user.remove_relationship(user.object.id).await;
assert_eq!(
other_user.get_relationships().await.unwrap(),
Vec::<Relationship>::new()