diff --git a/Cargo.lock b/Cargo.lock index 34d1e4c..067676b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -265,9 +265,7 @@ dependencies = [ [[package]] name = "chorus-macros" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de4221700bc486c6e6bc261fdea478936d33067a06325895f5d2a8cde5917272" +version = "0.4.0" dependencies = [ "async-trait", "quote", diff --git a/Cargo.toml b/Cargo.toml index 60d511a..5e93a7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ thiserror = "1.0.56" jsonwebtoken = "8.3.0" log = "0.4.20" async-trait = "0.1.77" -chorus-macros = "0.3.0" +chorus-macros = { path = "./chorus-macros", version = "0" } # Note: version here is used when releasing. This will use the latest release. Make sure to republish the crate when code in macros is changed! sqlx = { version = "0.7.3", features = [ "mysql", "sqlite", diff --git a/chorus-macros/Cargo.lock b/chorus-macros/Cargo.lock index 9f14619..a3eedd4 100644 --- a/chorus-macros/Cargo.lock +++ b/chorus-macros/Cargo.lock @@ -15,7 +15,7 @@ dependencies = [ [[package]] name = "chorus-macros" -version = "0.2.1" +version = "0.4.0" dependencies = [ "async-trait", "quote", diff --git a/chorus-macros/Cargo.toml b/chorus-macros/Cargo.toml index df5bc7f..c81cc90 100644 --- a/chorus-macros/Cargo.toml +++ b/chorus-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chorus-macros" -version = "0.3.0" +version = "0.4.0" edition = "2021" license = "MPL-2.0" description = "Macros for the chorus crate." diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index 436fd68..4102039 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -155,3 +155,68 @@ pub fn composite_derive(input: TokenStream) -> TokenStream { _ => panic!("Composite derive macro only supports structs"), } } + + +#[proc_macro_derive(SqlxBitFlags)] +pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + + quote!{ + #[cfg(feature = "sqlx")] + impl sqlx::Type for #name { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + u64::type_info() + } + } + + #[cfg(feature = "sqlx")] + impl<'q> sqlx::Encode<'q, sqlx::MySql> for #name { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + u64::encode_by_ref(&self.bits(), buf) + } + } + + #[cfg(feature = "sqlx")] + impl<'q> sqlx::Decode<'q, sqlx::MySql> for #name { + fn decode(value: >::ValueRef) -> Result { + u64::decode(value).map(|d| #name::from_bits(d).unwrap()) + } + } + } + .into() +} + +#[proc_macro_derive(SerdeBitFlags)] +pub fn serde_bitflag_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + + quote! { + impl std::str::FromStr for #name { + type Err = std::num::ParseIntError; + + fn from_str(s: &str) -> Result<#name, Self::Err> { + s.parse::().map(#name::from_bits).map(|f| f.unwrap_or(#name::empty())) + } + } + + impl serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.bits().to_string()) + } + } + + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result<#name, D::Error> where D: serde::de::Deserializer<'de> + Sized { + // let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + let s = crate::types::serde::string_or_u64(deserializer)?; + + Ok(Self::from_bits(s).unwrap()) + } + } + } + .into() +} \ No newline at end of file diff --git a/src/api/users/users.rs b/src/api/users/users.rs index 6101713..4f6ef57 100644 --- a/src/api/users/users.rs +++ b/src/api/users/users.rs @@ -32,7 +32,7 @@ impl ChorusUser { /// # Notes /// This function is a wrapper around [`User::get_settings`]. pub async fn get_settings(&mut self) -> ChorusResult { - User::get_settings(self).await + User::get_settings(self).await } /// Modifies the current user's representation. (See [`User`]) @@ -40,12 +40,18 @@ impl ChorusUser { /// # Reference /// See pub async fn modify(&mut self, modify_schema: UserModifySchema) -> ChorusResult { - if modify_schema.new_password.is_some() + + // See , note 1 + let requires_current_password = modify_schema.username.is_some() + || modify_schema.discriminator.is_some() || modify_schema.email.is_some() - || modify_schema.code.is_some() - { + || modify_schema.date_of_birth.is_some() + || modify_schema.new_password.is_some(); + + if requires_current_password && modify_schema.current_password.is_none() { return Err(ChorusError::PasswordRequired); } + let request = Client::new() .patch(format!( "{}/users/@me", @@ -132,4 +138,3 @@ impl User { } } } - diff --git a/src/types/config/types/register_configuration.rs b/src/types/config/types/register_configuration.rs index 19cedfb..4afc40c 100644 --- a/src/types/config/types/register_configuration.rs +++ b/src/types/config/types/register_configuration.rs @@ -3,6 +3,7 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. use serde::{Deserialize, Serialize}; +use serde_aux::prelude::deserialize_number_from_string; use crate::types::{config::types::subconfigs::register::{ DateOfBirthConfiguration, PasswordConfiguration, RegistrationEmailConfiguration, @@ -22,6 +23,7 @@ pub struct RegisterConfiguration { pub allow_multiple_accounts: bool, pub block_proxies: bool, pub incrementing_discriminators: bool, + #[serde(deserialize_with = "deserialize_number_from_string")] pub default_rights: Rights, } diff --git a/src/types/entities/application.rs b/src/types/entities/application.rs index c772788..ac4cb97 100644 --- a/src/types/entities/application.rs +++ b/src/types/entities/application.rs @@ -31,7 +31,7 @@ pub struct Application { pub verify_key: String, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub owner: Shared, - pub flags: u64, + pub flags: ApplicationFlags, #[cfg(feature = "sqlx")] pub redirect_uris: Option>>, #[cfg(not(feature = "sqlx"))] @@ -73,7 +73,7 @@ impl Default for Application { bot_require_code_grant: false, verify_key: "".to_string(), owner: Default::default(), - flags: 0, + flags: ApplicationFlags::empty(), redirect_uris: None, rpc_application_state: 0, store_application_state: 1, @@ -93,12 +93,6 @@ impl Default for Application { } } -impl Application { - pub fn flags(&self) -> ApplicationFlags { - ApplicationFlags::from_bits(self.flags.to_owned()).unwrap() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] /// # Reference /// See @@ -108,7 +102,8 @@ pub struct InstallParams { } bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See pub struct ApplicationFlags: u64 { diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 3d6b4ca..c1219ad 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -8,8 +8,8 @@ use serde_aux::prelude::deserialize_string_from_number; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::fmt::Debug; -use crate::types::Shared; use crate::types::{ + PermissionFlags, Shared, entities::{GuildMember, User}, utils::Snowflake, }; @@ -64,7 +64,9 @@ pub struct Channel { pub managed: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub message_count: Option, pub name: Option, pub nsfw: Option, @@ -75,6 +77,7 @@ pub struct Channel { #[cfg(not(feature = "sqlx"))] #[cfg_attr(feature = "client", observe_option_vec)] pub permission_overwrites: Option>>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub permissions: Option, pub position: Option, pub rate_limit_per_user: Option, @@ -85,6 +88,7 @@ pub struct Channel { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread_metadata: Option, pub topic: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub total_message_sent: Option, pub user_limit: Option, pub video_quality_mode: Option, @@ -144,14 +148,20 @@ pub struct Tag { pub struct PermissionOverwrite { pub id: Snowflake, #[serde(rename = "type")] - #[serde(deserialize_with = "deserialize_string_from_number")] - pub overwrite_type: String, + pub overwrite_type: PermissionOverwriteType, #[serde(default)] - #[serde(deserialize_with = "deserialize_string_from_number")] - pub allow: String, + pub allow: PermissionFlags, #[serde(default)] - #[serde(deserialize_with = "deserialize_string_from_number")] - pub deny: String, + pub deny: PermissionFlags, +} + + +#[derive(Debug, Serialize_repr, Deserialize_repr, Clone, PartialEq, Eq, PartialOrd)] +#[repr(u8)] +/// # Reference +pub enum PermissionOverwriteType { + Role = 0, + Member = 1, } #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] @@ -256,3 +266,12 @@ pub enum ChannelType { // TODO: Couldn't find reference Unhandled = 255, } + + +/// # Reference +/// See +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct FollowedChannel { + pub channel_id: Snowflake, + pub webhook_id: Snowflake +} \ No newline at end of file diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 1b68f0d..23956b5 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::types::Shared; +use crate::types::{PartialEmoji, Shared}; use crate::types::entities::User; use crate::types::Snowflake; @@ -66,3 +66,18 @@ impl PartialEq for Emoji { || self.available != other.available) } } + +impl From for Emoji { + fn from(value: PartialEmoji) -> Self { + Self { + id: value.id.unwrap_or_default(), // TODO: this should be handled differently + name: Some(value.name), + roles: None, + user: None, + require_colons: Some(value.animated), + managed: None, + animated: Some(value.animated), + available: None, + } + } +} \ No newline at end of file diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index 66b9335..866b172 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -59,7 +59,8 @@ pub struct Guild { pub emojis: Vec>, pub explicit_content_filter: Option, //#[cfg_attr(feature = "sqlx", sqlx(try_from = "String"))] - pub features: Option, + #[serde(default)] + pub features: GuildFeaturesList, pub icon: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub icon_hash: Option, @@ -99,7 +100,7 @@ pub struct Guild { pub splash: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, - pub system_channel_flags: Option, + pub system_channel_flags: Option, pub system_channel_id: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub vanity_url_code: Option, @@ -111,7 +112,7 @@ pub struct Guild { #[cfg_attr(feature = "client", observe_option_vec)] pub webhooks: Option>>, #[cfg(feature = "sqlx")] - pub welcome_screen: Option>, + pub welcome_screen: sqlx::types::Json>, #[cfg(not(feature = "sqlx"))] pub welcome_screen: Option, pub widget_channel_id: Option, @@ -422,7 +423,8 @@ pub enum PremiumTier { } bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See pub struct SystemChannelFlags: u64 { diff --git a/src/types/entities/guild_member.rs b/src/types/entities/guild_member.rs index df07583..522824c 100644 --- a/src/types/entities/guild_member.rs +++ b/src/types/entities/guild_member.rs @@ -5,7 +5,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use crate::types::Shared; +use crate::types::{GuildMemberFlags, Shared}; use crate::types::{entities::PublicUser, Snowflake}; #[derive(Debug, Deserialize, Default, Serialize, Clone)] @@ -25,7 +25,7 @@ pub struct GuildMember { pub premium_since: Option>, pub deaf: bool, pub mute: bool, - pub flags: Option, + pub flags: Option, pub pending: Option, pub permissions: Option, pub communication_disabled_until: Option>, diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index caec03f..388cc32 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -16,7 +16,9 @@ use super::{Application, Channel, GuildMember, NSFWLevel, User}; #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Invite { + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_presence_count: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub channel: Option, @@ -45,7 +47,7 @@ pub struct Invite { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub target_user: Option, pub temporary: Option, - pub uses: Option, + pub uses: Option, } /// The guild an invite is for. @@ -77,20 +79,17 @@ impl From for InviteGuild { icon: value.icon, splash: value.splash, verification_level: value.verification_level.unwrap_or_default(), - features: value.features.unwrap_or_default(), + features: value.features, vanity_url_code: value.vanity_url_code, description: value.description, banner: value.banner, premium_subscription_count: value.premium_subscription_count, nsfw_deprecated: None, nsfw_level: value.nsfw_level.unwrap_or_default(), - welcome_screen: value.welcome_screen.map(|obj| { - #[cfg(feature = "sqlx")] - let res = obj.0; - #[cfg(not(feature = "sqlx"))] - let res = obj; - res - }), + #[cfg(feature = "sqlx")] + welcome_screen: value.welcome_screen.0, + #[cfg(not(feature = "sqlx"))] + welcome_screen: value.welcome_screen, } } } diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 4b57e06..34a7b9b 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -2,8 +2,10 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use bitflags::bitflags; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; use crate::types::{ Shared, @@ -39,7 +41,7 @@ pub struct Message { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub attachments: Option>, #[cfg(feature = "sqlx")] - pub embeds: Vec>, + pub embeds: sqlx::types::Json>, #[cfg(not(feature = "sqlx"))] pub embeds: Option>, #[cfg(feature = "sqlx")] @@ -50,7 +52,7 @@ pub struct Message { pub pinned: bool, pub webhook_id: Option, #[serde(rename = "type")] - pub message_type: i32, + pub message_type: MessageType, #[cfg(feature = "sqlx")] pub activity: Option>, #[cfg(not(feature = "sqlx"))] @@ -62,14 +64,22 @@ pub struct Message { pub message_reference: Option>, #[cfg(not(feature = "sqlx"))] pub message_reference: Option, - pub flags: Option, + pub flags: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub referenced_message: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub interaction: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread: Option, + #[cfg(feature = "sqlx")] + pub components: Option>>, + #[cfg(not(feature = "sqlx"))] pub components: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub sticker_items: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, - pub position: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub role_subscription_data: Option, } @@ -103,7 +113,7 @@ impl PartialEq for Message { && self.thread == other.thread && self.components == other.components && self.sticker_items == other.sticker_items - && self.position == other.position + // && self.position == other.position && self.role_subscription_data == other.role_subscription_data } } @@ -112,12 +122,22 @@ impl PartialEq for Message { /// # Reference /// See pub struct MessageReference { + #[serde(rename = "type")] + pub reference_type: MessageReferenceType, pub message_id: Snowflake, pub channel_id: Snowflake, pub guild_id: Option, pub fail_if_not_exists: Option, } +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq, Ord, PartialOrd)] +pub enum MessageReferenceType { + /// A standard reference used by replies and system messages + Default = 0, + /// A reference used to point to a message at a point in time + Forward = 1, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct MessageInteraction { pub id: Snowflake, @@ -227,10 +247,15 @@ pub struct EmbedField { pub struct Reaction { pub count: u32, pub burst_count: u32, + #[serde(default)] pub me: bool, + #[serde(default)] pub burst_me: bool, pub burst_colors: Vec, pub emoji: Emoji, + #[cfg(feature = "sqlx")] + #[serde(skip)] + pub user_ids: Vec } #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Eq, PartialOrd, Ord)] @@ -253,3 +278,155 @@ pub struct MessageActivity { pub activity_type: i64, pub party_id: Option, } + +#[derive(Debug, Default, PartialEq, Clone, Copy, Serialize_repr, Deserialize_repr, Eq, PartialOrd, Ord)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u8)] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +/// # Reference +/// See +pub enum MessageType { + /// A default message + #[default] + Default = 0, + /// A message sent when a user is added to a group DM or thread + RecipientAdd = 1, + /// A message sent when a user is removed from a group DM or thread + RecipientRemove = 2, + /// A message sent when a user creates a call in a private channel + Call = 3, + /// A message sent when a group DM or thread's name is changed + ChannelNameChange = 4, + /// A message sent when a group DM's icon is changed + ChannelIconChange = 5, + /// A message sent when a message is pinned in a channel + ChannelPinnedMessage = 6, + /// A message sent when a user joins a guild + GuildMemberJoin = 7, + /// A message sent when a user subscribes to (boosts) a guild + UserPremiumGuildSubscription = 8, + /// A message sent when a user subscribes to (boosts) a guild to tier 1 + UserPremiumGuildSubscriptionTier1 = 9, + /// A message sent when a user subscribes to (boosts) a guild to tier 2 + UserPremiumGuildSubscriptionTier2 = 10, + /// A message sent when a user subscribes to (boosts) a guild to tier 3 + UserPremiumGuildSubscriptionTier3 = 11, + /// A message sent when a news channel is followed + ChannelFollowAdd = 12, + /// A message sent when a user starts streaming in a guild (deprecated) + #[deprecated] + GuildStream = 13, + /// A message sent when a guild is disqualified from discovery + GuildDiscoveryDisqualified = 14, + /// A message sent when a guild requalifies for discovery + GuildDiscoveryRequalified = 15, + /// A message sent when a guild has failed discovery requirements for a week + GuildDiscoveryGracePeriodInitial = 16, + /// A message sent when a guild has failed discovery requirements for 3 weeks + GuildDiscoveryGracePeriodFinal = 17, + /// A message sent when a thread is created + ThreadCreated = 18, + /// A message sent when a user replies to a message + Reply = 19, + /// A message sent when a user uses a slash command + #[serde(rename = "CHAT_INPUT_COMMAND")] + ApplicationCommand = 20, + /// A message sent when a thread starter message is added to a thread + ThreadStarterMessage = 21, + /// A message sent to remind users to invite friends to a guild + GuildInviteReminder = 22, + /// A message sent when a user uses a context menu command + ContextMenuCommand = 23, + /// A message sent when auto moderation takes an action + AutoModerationAction = 24, + /// A message sent when a user purchases or renews a role subscription + RoleSubscriptionPurchase = 25, + /// A message sent when a user is upsold to a premium interaction + InteractionPremiumUpsell = 26, + /// A message sent when a stage channel starts + StageStart = 27, + /// A message sent when a stage channel ends + StageEnd = 28, + /// A message sent when a user starts speaking in a stage channel + StageSpeaker = 29, + /// A message sent when a user raises their hand in a stage channel + StageRaiseHand = 30, + /// A message sent when a stage channel's topic is changed + StageTopic = 31, + /// A message sent when a user purchases an application premium subscription + GuildApplicationPremiumSubscription = 32, + /// A message sent when a user adds an application to group DM + PrivateChannelIntegrationAdded = 33, + /// A message sent when a user removed an application from a group DM + PrivateChannelIntegrationRemoved = 34, + /// A message sent when a user gifts a premium (Nitro) referral + PremiumReferral = 35, + /// A message sent when a user enabled lockdown for the guild + GuildIncidentAlertModeEnabled = 36, + /// A message sent when a user disables lockdown for the guild + GuildIncidentAlertModeDisabled = 37, + /// A message sent when a user reports a raid for the guild + GuildIncidentReportRaid = 38, + /// A message sent when a user reports a false alarm for the guild + GuildIncidentReportFalseAlarm = 39, + /// A message sent when no one sends a message in the current channel for 1 hour + GuildDeadchatRevivePrompt = 40, + /// A message sent when a user buys another user a gift + CustomGift = 41, + GuildGamingStatsPrompt = 42, + /// A message sent when a user purchases a guild product + PurchaseNotification = 44 +} + +bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] + /// # Reference + /// See + pub struct MessageFlags: u64 { + /// This message has been published to subscribed channels (via Channel Following) + const CROSSPOSTED = 1 << 0; + /// This message originated from a message in another channel (via Channel Following) + const IS_CROSSPOST = 1 << 1; + /// Embeds will not be included when serializing this message + const SUPPRESS_EMBEDS = 1 << 2; + /// The source message for this crosspost has been deleted (via Channel Following) + const SOURCE_MESSAGE_DELETED = 1 << 3; + /// This message came from the urgent message system + const URGENT = 1 << 4; + /// This message has an associated thread, with the same ID as the message + const HAS_THREAD = 1 << 5; + /// This message is only visible to the user who invoked the interaction + const EPHEMERAL = 1 << 6; + /// This message is an interaction response and the bot is "thinking" + const LOADING = 1 << 7; + /// Some roles were not mentioned and added to the thread + const FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8; + /// This message contains a link that impersonates Discord + const SHOULD_SHOW_LINK_NOT_DISCORD_WARNING = 1 << 10; + /// This message will not trigger push and desktop notifications + const SUPPRESS_NOTIFICATIONS = 1 << 12; + /// This message's audio attachments are rendered as voice messages + const VOICE_MESSAGE = 1 << 13; + /// This message has a forwarded message snapshot attached + const HAS_SNAPSHOT = 1 << 14; + } +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct PartialEmoji { + #[serde(default)] + pub id: Option, + pub name: String, + #[serde(default)] + pub animated: bool +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, PartialOrd)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +#[repr(u8)] +pub enum ReactionType { + Normal = 0, + Burst = 1, // The dreaded super reactions +} \ No newline at end of file diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 1b5e91e..7e804f5 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -71,7 +71,8 @@ pub struct RoleTags { } bitflags! { - #[derive(Debug, Default, Clone, Hash, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// Permissions limit what users of certain roles can do on a Guild to Guild basis. /// /// # Reference: diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 66fbab8..70669d0 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -4,9 +4,11 @@ use crate::types::utils::Snowflake; use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_aux::prelude::deserialize_option_number_from_string; use std::fmt::Debug; +use std::num::ParseIntError; +use std::str::FromStr; #[cfg(feature = "client")] use crate::gateway::Updateable; @@ -54,7 +56,7 @@ pub struct User { /// So we need to account for that #[serde(default)] #[serde(deserialize_with = "deserialize_option_number_from_string")] - pub flags: Option, + pub flags: Option, pub premium_since: Option>, pub premium_type: Option, pub pronouns: Option, @@ -111,8 +113,8 @@ impl From for PublicUser { const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32; bitflags::bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] - #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct UserFlags: u64 { const DISCORD_EMPLOYEE = 1 << 0; const PARTNERED_SERVER_OWNER = 1 << 1; diff --git a/src/types/entities/webhook.rs b/src/types/entities/webhook.rs index 3fcc43b..aea9f41 100644 --- a/src/types/entities/webhook.rs +++ b/src/types/entities/webhook.rs @@ -32,13 +32,13 @@ use crate::types::{ pub struct Webhook { pub id: Snowflake, #[serde(rename = "type")] - pub webhook_type: i32, + pub webhook_type: WebhookType, pub name: String, pub avatar: String, pub token: String, - pub guild_id: Snowflake, + pub guild_id: Snowflake, pub channel_id: Snowflake, - pub application_id: Snowflake, + pub application_id: Option, #[serde(skip_serializing_if = "Option::is_none")] #[cfg_attr(feature = "sqlx", sqlx(skip))] pub user: Option>, @@ -48,3 +48,13 @@ pub struct Webhook { #[serde(skip_serializing_if = "Option::is_none")] pub url: Option, } + +#[derive(Serialize, Deserialize, Debug, Default, Clone, Copy)] +#[repr(u8)] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +pub enum WebhookType { + #[default] + Incoming = 1, + ChannelFollower = 2, + Application = 3, +} \ No newline at end of file diff --git a/src/types/events/voice_gateway/speaking.rs b/src/types/events/voice_gateway/speaking.rs index 64cf2e7..7a505fd 100644 --- a/src/types/events/voice_gateway/speaking.rs +++ b/src/types/events/voice_gateway/speaking.rs @@ -32,8 +32,8 @@ bitflags! { /// Bitflags of speaking types; /// /// See - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Serialize, Deserialize)] - pub struct SpeakingBitflags: u8 { + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, chorus_macros::SerdeBitFlags)] + pub struct SpeakingBitflags: u64 { /// Whether we'll be transmitting normal voice audio const MICROPHONE = 1 << 0; /// Whether we'll be transmitting context audio for video, no speaking indicator diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index 260b10e..33df3ff 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -59,7 +59,7 @@ pub struct GetChannelMessagesSchema { /// Between 1 and 100, defaults to 50. pub limit: Option, #[serde(flatten)] - pub anchor: ChannelMessagesAnchor, + pub anchor: Option, } #[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)] @@ -74,21 +74,21 @@ impl GetChannelMessagesSchema { pub fn before(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Before(anchor), + anchor: Some(ChannelMessagesAnchor::Before(anchor)), } } pub fn around(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Around(anchor), + anchor: Some(ChannelMessagesAnchor::Around(anchor)), } } pub fn after(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::After(anchor), + anchor: Some(ChannelMessagesAnchor::After(anchor)), } } @@ -131,63 +131,14 @@ impl Default for CreateChannelInviteSchema { } bitflags! { - #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct InviteFlags: u64 { const GUEST = 1 << 0; const VIEWED = 1 << 1; } } -impl Serialize for InviteFlags { - fn serialize(&self, serializer: S) -> Result { - self.bits().to_string().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for InviteFlags { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - struct FlagsVisitor; - - impl<'de> Visitor<'de> for FlagsVisitor - { - type Value = InviteFlags; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("a raw u64 value of flags") - } - - fn visit_u64(self, v: u64) -> Result { - InviteFlags::from_bits(v).ok_or(serde::de::Error::custom(Error::InvalidFlags(v))) - } - } - - deserializer.deserialize_u64(FlagsVisitor) - } -} - -#[cfg(feature = "sqlx")] -impl sqlx::Type for InviteFlags { - fn type_info() -> sqlx::mysql::MySqlTypeInfo { - u64::type_info() - } -} - -#[cfg(feature = "sqlx")] -impl<'q> sqlx::Encode<'q, sqlx::MySql> for InviteFlags { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - u64::encode_by_ref(&self.0.0, buf) - } -} - -#[cfg(feature = "sqlx")] -impl<'r> sqlx::Decode<'r, sqlx::MySql> for InviteFlags { - fn decode(value: >::ValueRef) -> Result { - let raw = u64::decode(value)?; - - Ok(Self::from_bits(raw).unwrap()) - } -} - #[derive(Debug, Deserialize, Serialize, Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] @@ -226,3 +177,15 @@ pub struct ModifyChannelPositionsSchema { pub lock_permissions: Option, pub parent_id: Option, } + +/// See +#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialOrd, Ord, PartialEq, Eq)] +pub struct AddFollowingChannelSchema { + pub webhook_channel_id: Snowflake, +} + +#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialOrd, Ord, PartialEq, Eq)] +pub struct CreateWebhookSchema { + pub name: String, + pub avatar: Option, +} \ No newline at end of file diff --git a/src/types/schema/guild.rs b/src/types/schema/guild.rs index 2e29ce0..d820cd8 100644 --- a/src/types/schema/guild.rs +++ b/src/types/schema/guild.rs @@ -132,6 +132,7 @@ pub struct ModifyGuildMemberSchema { bitflags! { #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// Represents the flags of a Guild Member. /// /// # Reference: diff --git a/src/types/schema/message.rs b/src/types/schema/message.rs index 7551b6b..4a7f2ab 100644 --- a/src/types/schema/message.rs +++ b/src/types/schema/message.rs @@ -7,13 +7,13 @@ use serde::{Deserialize, Serialize}; use crate::types::entities::{ AllowedMention, Component, Embed, MessageReference, PartialDiscordFileAttachment, }; -use crate::types::{Attachment, Snowflake}; +use crate::types::{Attachment, MessageFlags, MessageType, ReactionType, Snowflake}; #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub struct MessageSendSchema { #[serde(rename = "type")] - pub message_type: Option, + pub message_type: Option, pub content: Option, pub nonce: Option, pub tts: Option, @@ -118,13 +118,21 @@ pub struct MessageAck { #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] pub struct MessageModifySchema { - content: Option, - embeds: Option>, - embed: Option, - allowed_mentions: Option, - components: Option>, - flags: Option, - files: Option>, - payload_json: Option, - attachments: Option>, + pub content: Option, + pub embeds: Option>, + pub embed: Option, + pub allowed_mentions: Option, + pub components: Option>, + pub flags: Option, + pub files: Option>, + pub payload_json: Option, + pub attachments: Option>, } + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] +pub struct ReactionQuerySchema { + pub after: Option, + pub limit: Option, + #[serde(rename = "type")] + pub reaction_type: Option +} \ No newline at end of file diff --git a/src/types/schema/user.rs b/src/types/schema/user.rs index 7d21754..e2600a4 100644 --- a/src/types/schema/user.rs +++ b/src/types/schema/user.rs @@ -4,24 +4,91 @@ use std::collections::HashMap; +use chrono::NaiveDate; use serde::{Deserialize, Serialize}; use crate::types::Snowflake; -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] #[serde(rename_all = "snake_case")] /// A schema used to modify a user. +/// +/// See pub struct UserModifySchema { + /// The user's new username (2-32 characters) + /// + /// Requires that `current_password` is set. pub username: Option, + // TODO: Maybe add a special discriminator type? + /// Requires that `current_password` is set. + pub discriminator: Option, + /// The user's display name (1-32 characters) + /// + /// # Note + /// + /// This is not yet implemented on Spacebar + pub global_name: Option, + // TODO: Add a CDN data type pub avatar: Option, - pub bio: Option, - pub accent_color: Option, - pub banner: Option, - pub current_password: Option, - pub new_password: Option, - pub code: Option, + /// Note: This is not yet implemented on Spacebar + pub avatar_decoration_id: Option, + /// Note: This is not yet implemented on Spacebar + pub avatar_decoration_sku_id: Option, + /// The user's email address; if changing from a verified email, email_token must be provided + /// + /// Requires that `current_password` is set. + // TODO: Is ^ up to date? One would think this may not be the case, since email_token exists pub email: Option, - pub discriminator: Option, + /// The user's email token from their previous email, required if a new email is set. + /// + /// See and + /// for changing the user's email. + /// + /// # Note + /// + /// This is not yet implemented on Spacebar + pub email_token: Option, + /// The user's pronouns (max 40 characters) + /// + /// # Note + /// + /// This is not yet implemented on Spacebar + pub pronouns: Option, + /// The user's banner. + /// + /// Can only be changed for premium users + pub banner: Option, + /// The user's bio (max 190 characters) + pub bio: Option, + /// The user's accent color, as a hex integer + pub accent_color: Option, + /// The user's [UserFlags]. + /// + /// Only [UserFlags::PREMIUM_PROMO_DISMISSED], [UserFlags::HAS_UNREAD_URGENT_MESSAGES] + /// and DISABLE_PREMIUM can be set. + /// + /// # Note + /// + /// This is not yet implemented on Spacebar + pub flags: Option, + /// The user's date of birth, can only be set once + /// + /// Requires that `current_password` is set. + pub date_of_birth: Option, + /// The user's current password (if the account does not have a password, this sets it) + /// + /// Required for updating `username`, `discriminator`, `email`, `date_of_birth` and + /// `new_password` + #[serde(rename = "password")] + pub current_password: Option, + /// The user's new password (8-72 characters) + /// + /// Requires that `current_password` is set. + /// + /// Regenerates the user's token + pub new_password: Option, + /// Spacebar only field, potentially same as `email_token` + pub code: Option, } /// A schema used to create a private channel. @@ -33,7 +100,7 @@ pub struct UserModifySchema { /// /// # Reference: /// Read: -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct PrivateChannelCreateSchema { pub recipients: Option>, pub access_tokens: Option>, diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index 63978da..598f782 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -2,11 +2,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::num::ParseIntError; +use std::str::FromStr; use bitflags::bitflags; -use serde::{Deserialize, Serialize}; - -#[cfg(feature = "sqlx")] -use sqlx::{{Decode, Encode, MySql}, database::{HasArguments, HasValueRef}, encode::IsNull, error::BoxDynError, mysql::MySqlValueRef}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::types::UserFlags; bitflags! { /// Rights are instance-wide, per-user permissions for everything you may perform on the instance, @@ -18,7 +18,8 @@ bitflags! { /// /// # Reference /// See - #[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] + #[derive(Debug, Clone, Copy, Eq, PartialEq, chorus_macros::SerdeBitFlags)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct Rights: u64 { /// All rights const OPERATOR = 1 << 0; @@ -132,33 +133,6 @@ bitflags! { } } -#[cfg(feature = "sqlx")] -impl sqlx::Type for Rights { - fn type_info() -> ::TypeInfo { - u64::type_info() - } - - fn compatible(ty: &::TypeInfo) -> bool { - u64::compatible(ty) - } -} - -#[cfg(feature = "sqlx")] -impl<'q> Encode<'q, MySql> for Rights { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { - >::encode_by_ref(&self.0.0, buf) - } -} - -#[cfg(feature = "sqlx")] -impl<'r> Decode<'r, MySql> for Rights { - fn decode(value: >::ValueRef) -> Result { - let raw = >::decode(value)?; - Ok(Rights::from_bits(raw).unwrap()) - } -} - - impl Rights { pub fn any(&self, permission: Rights, check_operator: bool) -> bool { (check_operator && self.contains(Rights::OPERATOR)) || self.contains(permission) diff --git a/src/types/utils/serde.rs b/src/types/utils/serde.rs index 8584004..544a305 100644 --- a/src/types/utils/serde.rs +++ b/src/types/utils/serde.rs @@ -1,6 +1,7 @@ use core::fmt; use chrono::{LocalResult, NaiveDateTime}; -use serde::de; +use serde::{de, Deserialize, Deserializer}; +use serde::de::Error; #[doc(hidden)] #[derive(Debug)] @@ -259,4 +260,19 @@ pub(crate) fn serde_from(me: LocalResult, _ts: &V) -> Result } LocalResult::Single(val) => Ok(val), } +} + +#[derive(serde::Deserialize, serde::Serialize)] +#[serde(untagged)] +enum StringOrU64 { + String(String), + U64(u64), +} + +pub fn string_or_u64<'de, D>(d: D) -> Result +where D: Deserializer<'de> { + match StringOrU64::deserialize(d)? { + StringOrU64::String(s) => s.parse::().map_err(D::Error::custom), + StringOrU64::U64(u) => Ok(u) + } } \ No newline at end of file diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 1582085..a021f90 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -8,8 +8,6 @@ use std::{ }; use chrono::{DateTime, TimeZone, Utc}; -#[cfg(feature = "sqlx")] -use sqlx::Type; /// 2015-01-01 const EPOCH: i64 = 1420070400000; @@ -19,8 +17,6 @@ const EPOCH: i64 = 1420070400000; /// # Reference /// See #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "sqlx", derive(Type))] -#[cfg_attr(feature = "sqlx", sqlx(transparent))] pub struct Snowflake(pub u64); impl Snowflake { @@ -102,6 +98,27 @@ impl<'de> serde::Deserialize<'de> for Snowflake { } } +#[cfg(feature = "sqlx")] +impl sqlx::Type for Snowflake { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +#[cfg(feature = "sqlx")] +impl<'q> sqlx::Encode<'q, sqlx::MySql> for Snowflake { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.0.to_string(), buf) + } +} + +#[cfg(feature = "sqlx")] +impl<'d> sqlx::Decode<'d, sqlx::MySql> for Snowflake { + fn decode(value: >::ValueRef) -> Result { + >::decode(value).map(|s| s.parse::().map(Snowflake).unwrap()) + } +} + #[cfg(test)] mod test { use chrono::{DateTime, Utc}; diff --git a/tests/channels.rs b/tests/channels.rs index 14359d2..e00744a 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -2,10 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use chorus::types::{ - self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags, - PermissionOverwrite, PrivateChannelCreateSchema, RelationshipType, Snowflake, -}; +use chorus::types::{self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags, PermissionOverwrite, PermissionOverwriteType, PrivateChannelCreateSchema, RelationshipType, Snowflake}; mod common; @@ -69,16 +66,13 @@ async fn modify_channel() { .unwrap(); assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); - let permission_override = PermissionFlags::from_vec(Vec::from([ - PermissionFlags::MANAGE_CHANNELS, - PermissionFlags::MANAGE_MESSAGES, - ])); + let permission_override = PermissionFlags::MANAGE_CHANNELS | PermissionFlags::MANAGE_MESSAGES; let user_id: types::Snowflake = bundle.user.object.read().unwrap().id; let permission_override = PermissionOverwrite { id: user_id, - overwrite_type: "1".to_string(), + overwrite_type: PermissionOverwriteType::Member, allow: permission_override, - deny: "0".to_string(), + deny: PermissionFlags::empty(), }; let channel_id: Snowflake = bundle.channel.read().unwrap().id; Channel::modify_permissions( diff --git a/tests/types.rs b/tests/types.rs index 8895bb4..48b132c 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -920,7 +920,7 @@ mod entities { .clone() ); let flags = ApplicationFlags::from_bits(0).unwrap(); - assert!(application.flags() == flags); + assert!(application.flags == flags); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]