From 0468292ec646d8be99cbb4f97a47b23340c4a6de Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 02:08:21 -0400 Subject: [PATCH 01/22] Add type locks --- src/types/entities/message.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 4b57e06..6e493fc 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -63,13 +63,20 @@ pub struct Message { #[cfg(not(feature = "sqlx"))] pub message_reference: Option, pub flags: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub referenced_message: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub interaction: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub components: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub sticker_items: Option>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, pub position: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub role_subscription_data: Option, } From ed5365ade4e6a367d5ff8ade79e07382c2d5e6e7 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 02:08:41 -0400 Subject: [PATCH 02/22] Fix inverted type wrapping --- src/types/entities/message.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 6e493fc..9159c58 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -39,7 +39,7 @@ pub struct Message { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub attachments: Option>, #[cfg(feature = "sqlx")] - pub embeds: Vec>, + pub embeds: sqlx::types::Json>, #[cfg(not(feature = "sqlx"))] pub embeds: Option>, #[cfg(feature = "sqlx")] From c2035585ae41de2b3a59fddccaf82dd542c9c1c2 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 13:02:17 -0400 Subject: [PATCH 03/22] Update Cargo.lock --- chorus-macros/Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chorus-macros/Cargo.lock b/chorus-macros/Cargo.lock index 9f14619..64632bc 100644 --- a/chorus-macros/Cargo.lock +++ b/chorus-macros/Cargo.lock @@ -15,7 +15,7 @@ dependencies = [ [[package]] name = "chorus-macros" -version = "0.2.1" +version = "0.3.0" dependencies = [ "async-trait", "quote", From 656e5e31c084e53d5a5716d68ddb04bdf71397c9 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 13:03:44 -0400 Subject: [PATCH 04/22] Add SqlxBitFlags derive macro --- chorus-macros/src/lib.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index 436fd68..6034ff8 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -155,3 +155,34 @@ pub fn composite_derive(input: TokenStream) -> TokenStream { _ => panic!("Composite derive macro only supports structs"), } } + +#[proc_macro_derive(SqlxBitFlags)] +pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + + quote!{ + #[cfg(feature = "sqlx")] + impl sqlx::Type for #name { + fn type_info() -> sqlx::mysql::MySqlTypeInfo { + u64::type_info() + } + } + + #[cfg(feature = "sqlx")] + impl<'q> sqlx::Encode<'q, sqlx::MySql> for #name { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { + u64::encode_by_ref(&self.bits(), buf) + } + } + + #[cfg(feature = "sqlx")] + impl<'q> sqlx::Decode<'q, sqlx::MySql> for #name { + fn decode(value: >::ValueRef) -> Result { + u64::decode(value).map(|d| #name::from_bits(d).unwrap()) + } + } + } + .into() +} \ No newline at end of file From d89975819a08fd4304e9b25b8d7a10c32405531b Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 13:13:03 -0400 Subject: [PATCH 05/22] Utilize new macros and use distinct Flag types --- src/types/entities/application.rs | 11 +++------ src/types/entities/guild.rs | 3 ++- src/types/entities/guild_member.rs | 4 ++-- src/types/entities/message.rs | 38 +++++++++++++++++++++++++++++- src/types/entities/role.rs | 1 + src/types/entities/user.rs | 4 ++-- src/types/schema/channel.rs | 24 +------------------ src/types/schema/guild.rs | 1 + src/types/schema/message.rs | 4 ++-- 9 files changed, 51 insertions(+), 39 deletions(-) diff --git a/src/types/entities/application.rs b/src/types/entities/application.rs index c772788..f57a71f 100644 --- a/src/types/entities/application.rs +++ b/src/types/entities/application.rs @@ -31,7 +31,7 @@ pub struct Application { pub verify_key: String, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub owner: Shared, - pub flags: u64, + pub flags: ApplicationFlags, #[cfg(feature = "sqlx")] pub redirect_uris: Option>>, #[cfg(not(feature = "sqlx"))] @@ -73,7 +73,7 @@ impl Default for Application { bot_require_code_grant: false, verify_key: "".to_string(), owner: Default::default(), - flags: 0, + flags: ApplicationFlags::empty(), redirect_uris: None, rpc_application_state: 0, store_application_state: 1, @@ -93,12 +93,6 @@ impl Default for Application { } } -impl Application { - pub fn flags(&self) -> ApplicationFlags { - ApplicationFlags::from_bits(self.flags.to_owned()).unwrap() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] /// # Reference /// See @@ -109,6 +103,7 @@ pub struct InstallParams { bitflags! { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See pub struct ApplicationFlags: u64 { diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index 66b9335..b8fd0a4 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -99,7 +99,7 @@ pub struct Guild { pub splash: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, - pub system_channel_flags: Option, + pub system_channel_flags: Option, pub system_channel_id: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub vanity_url_code: Option, @@ -423,6 +423,7 @@ pub enum PremiumTier { bitflags! { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See pub struct SystemChannelFlags: u64 { diff --git a/src/types/entities/guild_member.rs b/src/types/entities/guild_member.rs index df07583..522824c 100644 --- a/src/types/entities/guild_member.rs +++ b/src/types/entities/guild_member.rs @@ -5,7 +5,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use crate::types::Shared; +use crate::types::{GuildMemberFlags, Shared}; use crate::types::{entities::PublicUser, Snowflake}; #[derive(Debug, Deserialize, Default, Serialize, Clone)] @@ -25,7 +25,7 @@ pub struct GuildMember { pub premium_since: Option>, pub deaf: bool, pub mute: bool, - pub flags: Option, + pub flags: Option, pub pending: Option, pub permissions: Option, pub communication_disabled_until: Option>, diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 9159c58..a6d91e7 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -2,6 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use bitflags::bitflags; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -62,7 +63,7 @@ pub struct Message { pub message_reference: Option>, #[cfg(not(feature = "sqlx"))] pub message_reference: Option, - pub flags: Option, + pub flags: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub referenced_message: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] @@ -260,3 +261,38 @@ pub struct MessageActivity { pub activity_type: i64, pub party_id: Option, } + +bitflags! { + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] + /// # Reference + /// See + pub struct MessageFlags: u64 { + /// This message has been published to subscribed channels (via Channel Following) + const CROSSPOSTED = 1 << 0; + /// This message originated from a message in another channel (via Channel Following) + const IS_CROSSPOST = 1 << 1; + /// Embeds will not be included when serializing this message + const SUPPRESS_EMBEDS = 1 << 2; + /// The source message for this crosspost has been deleted (via Channel Following) + const SOURCE_MESSAGE_DELETED = 1 << 3; + /// This message came from the urgent message system + const URGENT = 1 << 4; + /// This message has an associated thread, with the same ID as the message + const HAS_THREAD = 1 << 5; + /// This message is only visible to the user who invoked the interaction + const EPHEMERAL = 1 << 6; + /// This message is an interaction response and the bot is "thinking" + const LOADING = 1 << 7; + /// Some roles were not mentioned and added to the thread + const FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8; + /// This message contains a link that impersonates Discord + const SHOULD_SHOW_LINK_NOT_DISCORD_WARNING = 1 << 10; + /// This message will not trigger push and desktop notifications + const SUPPRESS_NOTIFICATIONS = 1 << 12; + /// This message's audio attachments are rendered as voice messages + const VOICE_MESSAGE = 1 << 13; + /// This message has a forwarded message snapshot attached + const HAS_SNAPSHOT = 1 << 14; + } +} \ No newline at end of file diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 1b5e91e..712c9ab 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -72,6 +72,7 @@ pub struct RoleTags { bitflags! { #[derive(Debug, Default, Clone, Hash, Serialize, Deserialize, PartialEq, Eq)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// Permissions limit what users of certain roles can do on a Guild to Guild basis. /// /// # Reference: diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 66fbab8..1577479 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -54,7 +54,7 @@ pub struct User { /// So we need to account for that #[serde(default)] #[serde(deserialize_with = "deserialize_option_number_from_string")] - pub flags: Option, + pub flags: Option, pub premium_since: Option>, pub premium_type: Option, pub pronouns: Option, @@ -112,7 +112,7 @@ const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32; bitflags::bitflags! { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] - #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct UserFlags: u64 { const DISCORD_EMPLOYEE = 1 << 0; const PARTNERED_SERVER_OWNER = 1 << 1; diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index 260b10e..fc2eea6 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -132,6 +132,7 @@ impl Default for CreateChannelInviteSchema { bitflags! { #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct InviteFlags: u64 { const GUEST = 1 << 0; const VIEWED = 1 << 1; @@ -165,29 +166,6 @@ impl<'de> Deserialize<'de> for InviteFlags { } } -#[cfg(feature = "sqlx")] -impl sqlx::Type for InviteFlags { - fn type_info() -> sqlx::mysql::MySqlTypeInfo { - u64::type_info() - } -} - -#[cfg(feature = "sqlx")] -impl<'q> sqlx::Encode<'q, sqlx::MySql> for InviteFlags { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { - u64::encode_by_ref(&self.0.0, buf) - } -} - -#[cfg(feature = "sqlx")] -impl<'r> sqlx::Decode<'r, sqlx::MySql> for InviteFlags { - fn decode(value: >::ValueRef) -> Result { - let raw = u64::decode(value)?; - - Ok(Self::from_bits(raw).unwrap()) - } -} - #[derive(Debug, Deserialize, Serialize, Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] diff --git a/src/types/schema/guild.rs b/src/types/schema/guild.rs index 2e29ce0..d820cd8 100644 --- a/src/types/schema/guild.rs +++ b/src/types/schema/guild.rs @@ -132,6 +132,7 @@ pub struct ModifyGuildMemberSchema { bitflags! { #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// Represents the flags of a Guild Member. /// /// # Reference: diff --git a/src/types/schema/message.rs b/src/types/schema/message.rs index 7551b6b..d6b8cc7 100644 --- a/src/types/schema/message.rs +++ b/src/types/schema/message.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::types::entities::{ AllowedMention, Component, Embed, MessageReference, PartialDiscordFileAttachment, }; -use crate::types::{Attachment, Snowflake}; +use crate::types::{Attachment, MessageFlags, Snowflake}; #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] @@ -123,7 +123,7 @@ pub struct MessageModifySchema { embed: Option, allowed_mentions: Option, components: Option>, - flags: Option, + flags: Option, files: Option>, payload_json: Option, attachments: Option>, From d3e2ef947ce9f0856c125557d389a619e6bfefc6 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 13:14:25 -0400 Subject: [PATCH 06/22] Add distinct MessageType enum --- src/types/entities/message.rs | 111 +++++++++++++++++++++++++++++++++- src/types/schema/message.rs | 4 +- 2 files changed, 112 insertions(+), 3 deletions(-) diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index a6d91e7..697be0d 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -5,6 +5,11 @@ use bitflags::bitflags; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use sqlx::database::{HasArguments, HasValueRef}; +use sqlx::encode::IsNull; +use sqlx::error::BoxDynError; +use sqlx::{Encode, MySql}; use crate::types::{ Shared, @@ -51,7 +56,7 @@ pub struct Message { pub pinned: bool, pub webhook_id: Option, #[serde(rename = "type")] - pub message_type: i32, + pub message_type: MessageType, #[cfg(feature = "sqlx")] pub activity: Option>, #[cfg(not(feature = "sqlx"))] @@ -235,10 +240,15 @@ pub struct EmbedField { pub struct Reaction { pub count: u32, pub burst_count: u32, + #[serde(default)] pub me: bool, + #[serde(default)] pub burst_me: bool, pub burst_colors: Vec, pub emoji: Emoji, + #[cfg(feature = "sqlx")] + #[serde(skip)] + pub user_ids: Vec } #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Eq, PartialOrd, Ord)] @@ -262,6 +272,105 @@ pub struct MessageActivity { pub party_id: Option, } +#[derive(Debug, Default, PartialEq, Clone, Copy, Serialize_repr, Deserialize_repr, Eq, PartialOrd, Ord)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u8)] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +/// # Reference +/// See +pub enum MessageType { + /// A default message + #[default] + Default = 0, + /// A message sent when a user is added to a group DM or thread + RecipientAdd = 1, + /// A message sent when a user is removed from a group DM or thread + RecipientRemove = 2, + /// A message sent when a user creates a call in a private channel + Call = 3, + /// A message sent when a group DM or thread's name is changed + ChannelNameChange = 4, + /// A message sent when a group DM's icon is changed + ChannelIconChange = 5, + /// A message sent when a message is pinned in a channel + ChannelPinnedMessage = 6, + /// A message sent when a user joins a guild + GuildMemberJoin = 7, + /// A message sent when a user subscribes to (boosts) a guild + UserPremiumGuildSubscription = 8, + /// A message sent when a user subscribes to (boosts) a guild to tier 1 + UserPremiumGuildSubscriptionTier1 = 9, + /// A message sent when a user subscribes to (boosts) a guild to tier 2 + UserPremiumGuildSubscriptionTier2 = 10, + /// A message sent when a user subscribes to (boosts) a guild to tier 3 + UserPremiumGuildSubscriptionTier3 = 11, + /// A message sent when a news channel is followed + ChannelFollowAdd = 12, + /// A message sent when a user starts streaming in a guild (deprecated) + #[deprecated] + GuildStream = 13, + /// A message sent when a guild is disqualified from discovery + GuildDiscoveryDisqualified = 14, + /// A message sent when a guild requalifies for discovery + GuildDiscoveryRequalified = 15, + /// A message sent when a guild has failed discovery requirements for a week + GuildDiscoveryGracePeriodInitial = 16, + /// A message sent when a guild has failed discovery requirements for 3 weeks + GuildDiscoveryGracePeriodFinal = 17, + /// A message sent when a thread is created + ThreadCreated = 18, + /// A message sent when a user replies to a message + Reply = 19, + /// A message sent when a user uses a slash command + #[serde(rename = "CHAT_INPUT_COMMAND")] + ApplicationCommand = 20, + /// A message sent when a thread starter message is added to a thread + ThreadStarterMessage = 21, + /// A message sent to remind users to invite friends to a guild + GuildInviteReminder = 22, + /// A message sent when a user uses a context menu command + ContextMenuCommand = 23, + /// A message sent when auto moderation takes an action + AutoModerationAction = 24, + /// A message sent when a user purchases or renews a role subscription + RoleSubscriptionPurchase = 25, + /// A message sent when a user is upsold to a premium interaction + InteractionPremiumUpsell = 26, + /// A message sent when a stage channel starts + StageStart = 27, + /// A message sent when a stage channel ends + StageEnd = 28, + /// A message sent when a user starts speaking in a stage channel + StageSpeaker = 29, + /// A message sent when a user raises their hand in a stage channel + StageRaiseHand = 30, + /// A message sent when a stage channel's topic is changed + StageTopic = 31, + /// A message sent when a user purchases an application premium subscription + GuildApplicationPremiumSubscription = 32, + /// A message sent when a user adds an application to group DM + PrivateChannelIntegrationAdded = 33, + /// A message sent when a user removed an application from a group DM + PrivateChannelIntegrationRemoved = 34, + /// A message sent when a user gifts a premium (Nitro) referral + PremiumReferral = 35, + /// A message sent when a user enabled lockdown for the guild + GuildIncidentAlertModeEnabled = 36, + /// A message sent when a user disables lockdown for the guild + GuildIncidentAlertModeDisabled = 37, + /// A message sent when a user reports a raid for the guild + GuildIncidentReportRaid = 38, + /// A message sent when a user reports a false alarm for the guild + GuildIncidentReportFalseAlarm = 39, + /// A message sent when no one sends a message in the current channel for 1 hour + GuildDeadchatRevivePrompt = 40, + /// A message sent when a user buys another user a gift + CustomGift = 41, + GuildGamingStatsPrompt = 42, + /// A message sent when a user purchases a guild product + PurchaseNotification = 44 +} + bitflags! { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] diff --git a/src/types/schema/message.rs b/src/types/schema/message.rs index d6b8cc7..a7c1eeb 100644 --- a/src/types/schema/message.rs +++ b/src/types/schema/message.rs @@ -7,13 +7,13 @@ use serde::{Deserialize, Serialize}; use crate::types::entities::{ AllowedMention, Component, Embed, MessageReference, PartialDiscordFileAttachment, }; -use crate::types::{Attachment, MessageFlags, Snowflake}; +use crate::types::{Attachment, MessageFlags, MessageType, Snowflake}; #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub struct MessageSendSchema { #[serde(rename = "type")] - pub message_type: Option, + pub message_type: Option, pub content: Option, pub nonce: Option, pub tts: Option, From 9bd55b9b5f440035ecbdcdc4645ac9374b9ab180 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 14:08:24 -0400 Subject: [PATCH 07/22] Fix error in macro --- chorus-macros/src/lib.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index 6034ff8..9f56a53 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -163,23 +163,20 @@ pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; quote!{ - #[cfg(feature = "sqlx")] impl sqlx::Type for #name { fn type_info() -> sqlx::mysql::MySqlTypeInfo { u64::type_info() } } - #[cfg(feature = "sqlx")] impl<'q> sqlx::Encode<'q, sqlx::MySql> for #name { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { u64::encode_by_ref(&self.bits(), buf) } } - #[cfg(feature = "sqlx")] impl<'q> sqlx::Decode<'q, sqlx::MySql> for #name { - fn decode(value: >::ValueRef) -> Result { + fn decode(value: >::ValueRef) -> Result { u64::decode(value).map(|d| #name::from_bits(d).unwrap()) } } From 590a6d6828b0b31a9a2efc0b6b55b0250280e82c Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Fri, 7 Jun 2024 14:08:47 -0400 Subject: [PATCH 08/22] Make UserFlags deserialize from string --- src/types/entities/user.rs | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 1577479..ddbdc22 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -4,9 +4,11 @@ use crate::types::utils::Snowflake; use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_aux::prelude::deserialize_option_number_from_string; use std::fmt::Debug; +use std::num::ParseIntError; +use std::str::FromStr; #[cfg(feature = "client")] use crate::gateway::Updateable; @@ -111,7 +113,7 @@ impl From for PublicUser { const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32; bitflags::bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct UserFlags: u64 { const DISCORD_EMPLOYEE = 1 << 0; @@ -137,6 +139,28 @@ bitflags::bitflags! { } } +impl FromStr for UserFlags { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + s.parse::().map(UserFlags::from_bits).map(|f| f.unwrap_or(UserFlags::empty())) + } +} + +impl Serialize for UserFlags { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.bits().to_string()) + } +} + +impl<'de> Deserialize<'de> for UserFlags { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + + Ok(UserFlags::from_bits(s).unwrap()) + } +} + #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct UserProfileMetadata { pub guild_id: Option, From bec0269e70ed124926596e9564d11159abef6bdc Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Mon, 17 Jun 2024 15:22:58 -0400 Subject: [PATCH 09/22] Add partial emoji and custom reaction types, refine SQLx mapping --- .../config/types/register_configuration.rs | 2 + src/types/entities/channel.rs | 4 ++ src/types/entities/emoji.rs | 17 +++++++- src/types/entities/guild.rs | 4 +- src/types/entities/invite.rs | 10 +++-- src/types/entities/message.rs | 35 +++++++++++++++-- src/types/schema/channel.rs | 8 ++-- src/types/schema/message.rs | 28 ++++++++----- src/types/utils/rights.rs | 39 ++++++++----------- src/types/utils/snowflake.rs | 23 ++++++++++- 10 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/types/config/types/register_configuration.rs b/src/types/config/types/register_configuration.rs index 19cedfb..4afc40c 100644 --- a/src/types/config/types/register_configuration.rs +++ b/src/types/config/types/register_configuration.rs @@ -3,6 +3,7 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. use serde::{Deserialize, Serialize}; +use serde_aux::prelude::deserialize_number_from_string; use crate::types::{config::types::subconfigs::register::{ DateOfBirthConfiguration, PasswordConfiguration, RegistrationEmailConfiguration, @@ -22,6 +23,7 @@ pub struct RegisterConfiguration { pub allow_multiple_accounts: bool, pub block_proxies: bool, pub incrementing_discriminators: bool, + #[serde(deserialize_with = "deserialize_number_from_string")] pub default_rights: Rights, } diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 3d6b4ca..782b6a6 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -64,7 +64,9 @@ pub struct Channel { pub managed: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub message_count: Option, pub name: Option, pub nsfw: Option, @@ -75,6 +77,7 @@ pub struct Channel { #[cfg(not(feature = "sqlx"))] #[cfg_attr(feature = "client", observe_option_vec)] pub permission_overwrites: Option>>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub permissions: Option, pub position: Option, pub rate_limit_per_user: Option, @@ -85,6 +88,7 @@ pub struct Channel { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread_metadata: Option, pub topic: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub total_message_sent: Option, pub user_limit: Option, pub video_quality_mode: Option, diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 1b68f0d..23956b5 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::types::Shared; +use crate::types::{PartialEmoji, Shared}; use crate::types::entities::User; use crate::types::Snowflake; @@ -66,3 +66,18 @@ impl PartialEq for Emoji { || self.available != other.available) } } + +impl From for Emoji { + fn from(value: PartialEmoji) -> Self { + Self { + id: value.id.unwrap_or_default(), // TODO: this should be handled differently + name: Some(value.name), + roles: None, + user: None, + require_colons: Some(value.animated), + managed: None, + animated: Some(value.animated), + available: None, + } + } +} \ No newline at end of file diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index b8fd0a4..e3c8a7e 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -59,7 +59,7 @@ pub struct Guild { pub emojis: Vec>, pub explicit_content_filter: Option, //#[cfg_attr(feature = "sqlx", sqlx(try_from = "String"))] - pub features: Option, + pub features: GuildFeaturesList, pub icon: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub icon_hash: Option, @@ -111,7 +111,7 @@ pub struct Guild { #[cfg_attr(feature = "client", observe_option_vec)] pub webhooks: Option>>, #[cfg(feature = "sqlx")] - pub welcome_screen: Option>, + pub welcome_screen: sqlx::types::Json>, #[cfg(not(feature = "sqlx"))] pub welcome_screen: Option, pub widget_channel_id: Option, diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index caec03f..bb1d400 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -16,7 +16,9 @@ use super::{Application, Channel, GuildMember, NSFWLevel, User}; #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Invite { + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_presence_count: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub channel: Option, @@ -45,7 +47,7 @@ pub struct Invite { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub target_user: Option, pub temporary: Option, - pub uses: Option, + pub uses: Option, } /// The guild an invite is for. @@ -77,16 +79,16 @@ impl From for InviteGuild { icon: value.icon, splash: value.splash, verification_level: value.verification_level.unwrap_or_default(), - features: value.features.unwrap_or_default(), + features: value.features, vanity_url_code: value.vanity_url_code, description: value.description, banner: value.banner, premium_subscription_count: value.premium_subscription_count, nsfw_deprecated: None, nsfw_level: value.nsfw_level.unwrap_or_default(), - welcome_screen: value.welcome_screen.map(|obj| { + welcome_screen: value.welcome_screen.0.map(|obj| { #[cfg(feature = "sqlx")] - let res = obj.0; + let res = obj; #[cfg(not(feature = "sqlx"))] let res = obj; res diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 697be0d..e31b7be 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -75,13 +75,14 @@ pub struct Message { pub interaction: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread: Option, - #[cfg_attr(feature = "sqlx", sqlx(skip))] + #[cfg(feature = "sqlx")] + pub components: Option>>, + #[cfg(not(feature = "sqlx"))] pub components: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub sticker_items: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, - pub position: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub role_subscription_data: Option, } @@ -116,7 +117,7 @@ impl PartialEq for Message { && self.thread == other.thread && self.components == other.components && self.sticker_items == other.sticker_items - && self.position == other.position + // && self.position == other.position && self.role_subscription_data == other.role_subscription_data } } @@ -125,12 +126,22 @@ impl PartialEq for Message { /// # Reference /// See pub struct MessageReference { + #[serde(rename = "type")] + pub reference_type: MessageReferenceType, pub message_id: Snowflake, pub channel_id: Snowflake, pub guild_id: Option, pub fail_if_not_exists: Option, } +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq, Ord, PartialOrd)] +pub enum MessageReferenceType { + /// A standard reference used by replies and system messages + Default = 0, + /// A reference used to point to a message at a point in time + Forward = 1, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct MessageInteraction { pub id: Snowflake, @@ -404,4 +415,22 @@ bitflags! { /// This message has a forwarded message snapshot attached const HAS_SNAPSHOT = 1 << 14; } +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct PartialEmoji { + #[serde(default)] + pub id: Option, + pub name: String, + #[serde(default)] + pub animated: bool +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, PartialOrd)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +#[repr(u8)] +pub enum ReactionType { + Normal = 0, + Burst = 1, // The dreaded super reactions } \ No newline at end of file diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index fc2eea6..bcebe5a 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -59,7 +59,7 @@ pub struct GetChannelMessagesSchema { /// Between 1 and 100, defaults to 50. pub limit: Option, #[serde(flatten)] - pub anchor: ChannelMessagesAnchor, + pub anchor: Option, } #[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)] @@ -74,21 +74,21 @@ impl GetChannelMessagesSchema { pub fn before(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Before(anchor), + anchor: Some(ChannelMessagesAnchor::Before(anchor)), } } pub fn around(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Around(anchor), + anchor: Some(ChannelMessagesAnchor::Around(anchor)), } } pub fn after(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::After(anchor), + anchor: Some(ChannelMessagesAnchor::After(anchor)), } } diff --git a/src/types/schema/message.rs b/src/types/schema/message.rs index a7c1eeb..4a7f2ab 100644 --- a/src/types/schema/message.rs +++ b/src/types/schema/message.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::types::entities::{ AllowedMention, Component, Embed, MessageReference, PartialDiscordFileAttachment, }; -use crate::types::{Attachment, MessageFlags, MessageType, Snowflake}; +use crate::types::{Attachment, MessageFlags, MessageType, ReactionType, Snowflake}; #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] @@ -118,13 +118,21 @@ pub struct MessageAck { #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] pub struct MessageModifySchema { - content: Option, - embeds: Option>, - embed: Option, - allowed_mentions: Option, - components: Option>, - flags: Option, - files: Option>, - payload_json: Option, - attachments: Option>, + pub content: Option, + pub embeds: Option>, + pub embed: Option, + pub allowed_mentions: Option, + pub components: Option>, + pub flags: Option, + pub files: Option>, + pub payload_json: Option, + pub attachments: Option>, } + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] +pub struct ReactionQuerySchema { + pub after: Option, + pub limit: Option, + #[serde(rename = "type")] + pub reaction_type: Option +} \ No newline at end of file diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index 63978da..ec1cd9c 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -2,11 +2,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::num::ParseIntError; +use std::str::FromStr; use bitflags::bitflags; -use serde::{Deserialize, Serialize}; - -#[cfg(feature = "sqlx")] -use sqlx::{{Decode, Encode, MySql}, database::{HasArguments, HasValueRef}, encode::IsNull, error::BoxDynError, mysql::MySqlValueRef}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::types::UserFlags; bitflags! { /// Rights are instance-wide, per-user permissions for everything you may perform on the instance, @@ -18,7 +18,7 @@ bitflags! { /// /// # Reference /// See - #[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] + #[derive(Debug, Clone, Copy, Eq, PartialEq, chorus_macros::SqlxBitFlags)] pub struct Rights: u64 { /// All rights const OPERATOR = 1 << 0; @@ -132,33 +132,28 @@ bitflags! { } } -#[cfg(feature = "sqlx")] -impl sqlx::Type for Rights { - fn type_info() -> ::TypeInfo { - u64::type_info() - } +impl FromStr for Rights { + type Err = ParseIntError; - fn compatible(ty: &::TypeInfo) -> bool { - u64::compatible(ty) + fn from_str(s: &str) -> Result { + s.parse::().map(Rights::from_bits).map(|f| f.unwrap_or(Rights::empty())) } } -#[cfg(feature = "sqlx")] -impl<'q> Encode<'q, MySql> for Rights { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { - >::encode_by_ref(&self.0.0, buf) +impl Serialize for Rights { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.bits().to_string()) } } -#[cfg(feature = "sqlx")] -impl<'r> Decode<'r, MySql> for Rights { - fn decode(value: >::ValueRef) -> Result { - let raw = >::decode(value)?; - Ok(Rights::from_bits(raw).unwrap()) +impl<'de> Deserialize<'de> for Rights { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + + Ok(Rights::from_bits(s).unwrap()) } } - impl Rights { pub fn any(&self, permission: Rights, check_operator: bool) -> bool { (check_operator && self.contains(Rights::OPERATOR)) || self.contains(permission) diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 1582085..341cc50 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -8,6 +8,9 @@ use std::{ }; use chrono::{DateTime, TimeZone, Utc}; +use sqlx::{MySql, TypeInfo}; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; #[cfg(feature = "sqlx")] use sqlx::Type; @@ -19,8 +22,6 @@ const EPOCH: i64 = 1420070400000; /// # Reference /// See #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "sqlx", derive(Type))] -#[cfg_attr(feature = "sqlx", sqlx(transparent))] pub struct Snowflake(pub u64); impl Snowflake { @@ -102,6 +103,24 @@ impl<'de> serde::Deserialize<'de> for Snowflake { } } +impl sqlx::Type for Snowflake { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl<'q> sqlx::Encode<'q, sqlx::MySql> for Snowflake { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.0.to_string(), buf) + } +} + +impl<'d> sqlx::Decode<'d, sqlx::MySql> for Snowflake { + fn decode(value: >::ValueRef) -> Result { + >::decode(value).map(|s| s.parse::().map(Snowflake).unwrap()) + } +} + #[cfg(test)] mod test { use chrono::{DateTime, Utc}; From d8560093f573adb73204e48f87536a1336dc3eee Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 04:48:50 -0400 Subject: [PATCH 10/22] Use chorus_macros from path, since it's there anyway --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a3b0b36..9526281 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ thiserror = "1.0.56" jsonwebtoken = "8.3.0" log = "0.4.20" async-trait = "0.1.77" -chorus-macros = "0.3.0" +chorus-macros = { path = "./chorus-macros" } sqlx = { version = "0.7.3", features = [ "mysql", "sqlite", From 6261cebfdfd57b9d34491e93c679679caa02fb05 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:03:55 -0400 Subject: [PATCH 11/22] Fix test --- tests/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/types.rs b/tests/types.rs index 8895bb4..4c727cb 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -920,7 +920,7 @@ mod entities { .clone() ); let flags = ApplicationFlags::from_bits(0).unwrap(); - assert!(application.flags() == flags); + assert!(application == flags); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] From dfffcf431344f2db3830e3e71753a762862ed9a7 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:08:36 -0400 Subject: [PATCH 12/22] Add forgotten feature locks --- src/types/utils/snowflake.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 341cc50..87a5586 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -103,18 +103,21 @@ impl<'de> serde::Deserialize<'de> for Snowflake { } } +#[cfg(feature = "sqlx")] impl sqlx::Type for Snowflake { fn type_info() -> ::TypeInfo { >::type_info() } } +#[cfg(feature = "sqlx")] impl<'q> sqlx::Encode<'q, sqlx::MySql> for Snowflake { fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { >::encode_by_ref(&self.0.to_string(), buf) } } +#[cfg(feature = "sqlx")] impl<'d> sqlx::Decode<'d, sqlx::MySql> for Snowflake { fn decode(value: >::ValueRef) -> Result { >::decode(value).map(|s| s.parse::().map(Snowflake).unwrap()) From 3810170400877144eadf55a0b92f40e6f88756c7 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:14:28 -0400 Subject: [PATCH 13/22] Remove unused imports, feature locks in macro --- src/types/entities/message.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index e31b7be..d776337 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -6,10 +6,6 @@ use bitflags::bitflags; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; -use sqlx::database::{HasArguments, HasValueRef}; -use sqlx::encode::IsNull; -use sqlx::error::BoxDynError; -use sqlx::{Encode, MySql}; use crate::types::{ Shared, From 06b4dc3abdbc24d21b1738202fc716d0acd17216 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:19:51 -0400 Subject: [PATCH 14/22] forgot a file :( --- src/types/utils/snowflake.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 87a5586..a021f90 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -8,11 +8,6 @@ use std::{ }; use chrono::{DateTime, TimeZone, Utc}; -use sqlx::{MySql, TypeInfo}; -use sqlx::database::HasArguments; -use sqlx::encode::IsNull; -#[cfg(feature = "sqlx")] -use sqlx::Type; /// 2015-01-01 const EPOCH: i64 = 1420070400000; From 77004abcfa55e604dd2274ef65564370923e5778 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:24:42 -0400 Subject: [PATCH 15/22] Feature lock the macro --- src/types/utils/rights.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index ec1cd9c..e7de4fc 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -18,7 +18,8 @@ bitflags! { /// /// # Reference /// See - #[derive(Debug, Clone, Copy, Eq, PartialEq, chorus_macros::SqlxBitFlags)] + #[derive(Debug, Clone, Copy, Eq, PartialEq)] + #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct Rights: u64 { /// All rights const OPERATOR = 1 << 0; From abc4608be5be5fd5c213e924e406eb3774b90b36 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:32:27 -0400 Subject: [PATCH 16/22] Dirty hack --- src/types/entities/invite.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index bb1d400..239bdb2 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -86,13 +86,7 @@ impl From for InviteGuild { premium_subscription_count: value.premium_subscription_count, nsfw_deprecated: None, nsfw_level: value.nsfw_level.unwrap_or_default(), - welcome_screen: value.welcome_screen.0.map(|obj| { - #[cfg(feature = "sqlx")] - let res = obj; - #[cfg(not(feature = "sqlx"))] - let res = obj; - res - }), + welcome_screen: (*value.welcome_screen).to_owned(), // Dirty hack to get around feature locks making different types } } } From c27bc8d5758403a55205f82f062b4c8868c06241 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:34:03 -0400 Subject: [PATCH 17/22] Fix test I feel silly. --- tests/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/types.rs b/tests/types.rs index 4c727cb..48b132c 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -920,7 +920,7 @@ mod entities { .clone() ); let flags = ApplicationFlags::from_bits(0).unwrap(); - assert!(application == flags); + assert!(application.flags == flags); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] From d7de1d2fb7edd699a136ba01a9e494ff149e820b Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 05:46:27 -0400 Subject: [PATCH 18/22] Fix compilation for real, no dirty hack --- src/types/entities/invite.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index 239bdb2..388cc32 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -86,7 +86,10 @@ impl From for InviteGuild { premium_subscription_count: value.premium_subscription_count, nsfw_deprecated: None, nsfw_level: value.nsfw_level.unwrap_or_default(), - welcome_screen: (*value.welcome_screen).to_owned(), // Dirty hack to get around feature locks making different types + #[cfg(feature = "sqlx")] + welcome_screen: value.welcome_screen.0, + #[cfg(not(feature = "sqlx"))] + welcome_screen: value.welcome_screen, } } } From c9a36ce7254390542bbe8a1cc62090e03ad48f30 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 08:46:07 -0400 Subject: [PATCH 19/22] Maybe fix tests, make UserFlags able to be deserialized from String or u64 --- src/types/entities/user.rs | 5 +++-- src/types/utils/serde.rs | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index ddbdc22..d547be3 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -155,8 +155,9 @@ impl Serialize for UserFlags { impl<'de> Deserialize<'de> for UserFlags { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; - + // let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + let s = crate::types::serde::string_or_u64(deserializer)?; + Ok(UserFlags::from_bits(s).unwrap()) } } diff --git a/src/types/utils/serde.rs b/src/types/utils/serde.rs index 8584004..544a305 100644 --- a/src/types/utils/serde.rs +++ b/src/types/utils/serde.rs @@ -1,6 +1,7 @@ use core::fmt; use chrono::{LocalResult, NaiveDateTime}; -use serde::de; +use serde::{de, Deserialize, Deserializer}; +use serde::de::Error; #[doc(hidden)] #[derive(Debug)] @@ -259,4 +260,19 @@ pub(crate) fn serde_from(me: LocalResult, _ts: &V) -> Result } LocalResult::Single(val) => Ok(val), } +} + +#[derive(serde::Deserialize, serde::Serialize)] +#[serde(untagged)] +enum StringOrU64 { + String(String), + U64(u64), +} + +pub fn string_or_u64<'de, D>(d: D) -> Result +where D: Deserializer<'de> { + match StringOrU64::deserialize(d)? { + StringOrU64::String(s) => s.parse::().map_err(D::Error::custom), + StringOrU64::U64(u) => Ok(u) + } } \ No newline at end of file From e553d10f25d2e9febe8b8427883167fa3fa1b54f Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 10:04:00 -0400 Subject: [PATCH 20/22] Add new SerdeBitFlags derive macro, to help reduce repetitive code --- chorus-macros/src/lib.rs | 37 ++++++++++++++++++++++ src/types/entities/application.rs | 2 +- src/types/entities/guild.rs | 2 +- src/types/entities/message.rs | 2 +- src/types/entities/role.rs | 2 +- src/types/entities/user.rs | 25 +-------------- src/types/events/voice_gateway/speaking.rs | 2 +- src/types/schema/channel.rs | 29 +---------------- src/types/utils/rights.rs | 24 +------------- 9 files changed, 45 insertions(+), 80 deletions(-) diff --git a/chorus-macros/src/lib.rs b/chorus-macros/src/lib.rs index 9f56a53..4102039 100644 --- a/chorus-macros/src/lib.rs +++ b/chorus-macros/src/lib.rs @@ -156,6 +156,7 @@ pub fn composite_derive(input: TokenStream) -> TokenStream { } } + #[proc_macro_derive(SqlxBitFlags)] pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); @@ -163,18 +164,21 @@ pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; quote!{ + #[cfg(feature = "sqlx")] impl sqlx::Type for #name { fn type_info() -> sqlx::mysql::MySqlTypeInfo { u64::type_info() } } + #[cfg(feature = "sqlx")] impl<'q> sqlx::Encode<'q, sqlx::MySql> for #name { fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { u64::encode_by_ref(&self.bits(), buf) } } + #[cfg(feature = "sqlx")] impl<'q> sqlx::Decode<'q, sqlx::MySql> for #name { fn decode(value: >::ValueRef) -> Result { u64::decode(value).map(|d| #name::from_bits(d).unwrap()) @@ -182,4 +186,37 @@ pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream { } } .into() +} + +#[proc_macro_derive(SerdeBitFlags)] +pub fn serde_bitflag_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + + quote! { + impl std::str::FromStr for #name { + type Err = std::num::ParseIntError; + + fn from_str(s: &str) -> Result<#name, Self::Err> { + s.parse::().map(#name::from_bits).map(|f| f.unwrap_or(#name::empty())) + } + } + + impl serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.bits().to_string()) + } + } + + impl<'de> serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result<#name, D::Error> where D: serde::de::Deserializer<'de> + Sized { + // let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + let s = crate::types::serde::string_or_u64(deserializer)?; + + Ok(Self::from_bits(s).unwrap()) + } + } + } + .into() } \ No newline at end of file diff --git a/src/types/entities/application.rs b/src/types/entities/application.rs index f57a71f..ac4cb97 100644 --- a/src/types/entities/application.rs +++ b/src/types/entities/application.rs @@ -102,7 +102,7 @@ pub struct InstallParams { } bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index e3c8a7e..db207fc 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -422,7 +422,7 @@ pub enum PremiumTier { } bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index d776337..34a7b9b 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -379,7 +379,7 @@ pub enum MessageType { } bitflags! { - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// # Reference /// See diff --git a/src/types/entities/role.rs b/src/types/entities/role.rs index 712c9ab..8a2c513 100644 --- a/src/types/entities/role.rs +++ b/src/types/entities/role.rs @@ -71,7 +71,7 @@ pub struct RoleTags { } bitflags! { - #[derive(Debug, Default, Clone, Hash, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] /// Permissions limit what users of certain roles can do on a Guild to Guild basis. /// diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index d547be3..70669d0 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -113,7 +113,7 @@ impl From for PublicUser { const CUSTOM_USER_FLAG_OFFSET: u64 = 1 << 32; bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct UserFlags: u64 { const DISCORD_EMPLOYEE = 1 << 0; @@ -139,29 +139,6 @@ bitflags::bitflags! { } } -impl FromStr for UserFlags { - type Err = ParseIntError; - - fn from_str(s: &str) -> Result { - s.parse::().map(UserFlags::from_bits).map(|f| f.unwrap_or(UserFlags::empty())) - } -} - -impl Serialize for UserFlags { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.bits().to_string()) - } -} - -impl<'de> Deserialize<'de> for UserFlags { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - // let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; - let s = crate::types::serde::string_or_u64(deserializer)?; - - Ok(UserFlags::from_bits(s).unwrap()) - } -} - #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct UserProfileMetadata { pub guild_id: Option, diff --git a/src/types/events/voice_gateway/speaking.rs b/src/types/events/voice_gateway/speaking.rs index 64cf2e7..fea29ad 100644 --- a/src/types/events/voice_gateway/speaking.rs +++ b/src/types/events/voice_gateway/speaking.rs @@ -32,7 +32,7 @@ bitflags! { /// Bitflags of speaking types; /// /// See - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Serialize, Deserialize)] + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, chorus_macros::SerdeBitFlags)] pub struct SpeakingBitflags: u8 { /// Whether we'll be transmitting normal voice audio const MICROPHONE = 1 << 0; diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index bcebe5a..9ae64c6 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -131,7 +131,7 @@ impl Default for CreateChannelInviteSchema { } bitflags! { - #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct InviteFlags: u64 { const GUEST = 1 << 0; @@ -139,33 +139,6 @@ bitflags! { } } -impl Serialize for InviteFlags { - fn serialize(&self, serializer: S) -> Result { - self.bits().to_string().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for InviteFlags { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - struct FlagsVisitor; - - impl<'de> Visitor<'de> for FlagsVisitor - { - type Value = InviteFlags; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("a raw u64 value of flags") - } - - fn visit_u64(self, v: u64) -> Result { - InviteFlags::from_bits(v).ok_or(serde::de::Error::custom(Error::InvalidFlags(v))) - } - } - - deserializer.deserialize_u64(FlagsVisitor) - } -} - #[derive(Debug, Deserialize, Serialize, Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index e7de4fc..598f782 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -18,7 +18,7 @@ bitflags! { /// /// # Reference /// See - #[derive(Debug, Clone, Copy, Eq, PartialEq)] + #[derive(Debug, Clone, Copy, Eq, PartialEq, chorus_macros::SerdeBitFlags)] #[cfg_attr(feature = "sqlx", derive(chorus_macros::SqlxBitFlags))] pub struct Rights: u64 { /// All rights @@ -133,28 +133,6 @@ bitflags! { } } -impl FromStr for Rights { - type Err = ParseIntError; - - fn from_str(s: &str) -> Result { - s.parse::().map(Rights::from_bits).map(|f| f.unwrap_or(Rights::empty())) - } -} - -impl Serialize for Rights { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.bits().to_string()) - } -} - -impl<'de> Deserialize<'de> for Rights { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; - - Ok(Rights::from_bits(s).unwrap()) - } -} - impl Rights { pub fn any(&self, permission: Rights, check_operator: bool) -> bool { (check_operator && self.contains(Rights::OPERATOR)) || self.contains(permission) From a8b7d9dfb33d1955481497d70ace39ace68c3a76 Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 18 Jun 2024 10:07:56 -0400 Subject: [PATCH 21/22] u8 -> u64 --- src/types/events/voice_gateway/speaking.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/events/voice_gateway/speaking.rs b/src/types/events/voice_gateway/speaking.rs index fea29ad..7a505fd 100644 --- a/src/types/events/voice_gateway/speaking.rs +++ b/src/types/events/voice_gateway/speaking.rs @@ -33,7 +33,7 @@ bitflags! { /// /// See #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, chorus_macros::SerdeBitFlags)] - pub struct SpeakingBitflags: u8 { + pub struct SpeakingBitflags: u64 { /// Whether we'll be transmitting normal voice audio const MICROPHONE = 1 << 0; /// Whether we'll be transmitting context audio for video, no speaking indicator From 4430a7c4e6dfd4d7154d318c10d061ef97d6a116 Mon Sep 17 00:00:00 2001 From: kozabrada123 <59031733+kozabrada123@users.noreply.github.com> Date: Tue, 18 Jun 2024 17:18:20 +0200 Subject: [PATCH 22/22] Fix deserialization error w/ guild features list --- src/types/entities/guild.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index db207fc..866b172 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -59,6 +59,7 @@ pub struct Guild { pub emojis: Vec>, pub explicit_content_filter: Option, //#[cfg_attr(feature = "sqlx", sqlx(try_from = "String"))] + #[serde(default)] pub features: GuildFeaturesList, pub icon: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))]