From bec0269e70ed124926596e9564d11159abef6bdc Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Mon, 17 Jun 2024 15:22:58 -0400 Subject: [PATCH] Add partial emoji and custom reaction types, refine SQLx mapping --- .../config/types/register_configuration.rs | 2 + src/types/entities/channel.rs | 4 ++ src/types/entities/emoji.rs | 17 +++++++- src/types/entities/guild.rs | 4 +- src/types/entities/invite.rs | 10 +++-- src/types/entities/message.rs | 35 +++++++++++++++-- src/types/schema/channel.rs | 8 ++-- src/types/schema/message.rs | 28 ++++++++----- src/types/utils/rights.rs | 39 ++++++++----------- src/types/utils/snowflake.rs | 23 ++++++++++- 10 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/types/config/types/register_configuration.rs b/src/types/config/types/register_configuration.rs index 19cedfb..4afc40c 100644 --- a/src/types/config/types/register_configuration.rs +++ b/src/types/config/types/register_configuration.rs @@ -3,6 +3,7 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. use serde::{Deserialize, Serialize}; +use serde_aux::prelude::deserialize_number_from_string; use crate::types::{config::types::subconfigs::register::{ DateOfBirthConfiguration, PasswordConfiguration, RegistrationEmailConfiguration, @@ -22,6 +23,7 @@ pub struct RegisterConfiguration { pub allow_multiple_accounts: bool, pub block_proxies: bool, pub incrementing_discriminators: bool, + #[serde(deserialize_with = "deserialize_number_from_string")] pub default_rights: Rights, } diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 3d6b4ca..782b6a6 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -64,7 +64,9 @@ pub struct Channel { pub managed: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub message_count: Option, pub name: Option, pub nsfw: Option, @@ -75,6 +77,7 @@ pub struct Channel { #[cfg(not(feature = "sqlx"))] #[cfg_attr(feature = "client", observe_option_vec)] pub permission_overwrites: Option>>, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub permissions: Option, pub position: Option, pub rate_limit_per_user: Option, @@ -85,6 +88,7 @@ pub struct Channel { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread_metadata: Option, pub topic: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub total_message_sent: Option, pub user_limit: Option, pub video_quality_mode: Option, diff --git a/src/types/entities/emoji.rs b/src/types/entities/emoji.rs index 1b68f0d..23956b5 100644 --- a/src/types/entities/emoji.rs +++ b/src/types/entities/emoji.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::types::Shared; +use crate::types::{PartialEmoji, Shared}; use crate::types::entities::User; use crate::types::Snowflake; @@ -66,3 +66,18 @@ impl PartialEq for Emoji { || self.available != other.available) } } + +impl From for Emoji { + fn from(value: PartialEmoji) -> Self { + Self { + id: value.id.unwrap_or_default(), // TODO: this should be handled differently + name: Some(value.name), + roles: None, + user: None, + require_colons: Some(value.animated), + managed: None, + animated: Some(value.animated), + available: None, + } + } +} \ No newline at end of file diff --git a/src/types/entities/guild.rs b/src/types/entities/guild.rs index b8fd0a4..e3c8a7e 100644 --- a/src/types/entities/guild.rs +++ b/src/types/entities/guild.rs @@ -59,7 +59,7 @@ pub struct Guild { pub emojis: Vec>, pub explicit_content_filter: Option, //#[cfg_attr(feature = "sqlx", sqlx(try_from = "String"))] - pub features: Option, + pub features: GuildFeaturesList, pub icon: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub icon_hash: Option, @@ -111,7 +111,7 @@ pub struct Guild { #[cfg_attr(feature = "client", observe_option_vec)] pub webhooks: Option>>, #[cfg(feature = "sqlx")] - pub welcome_screen: Option>, + pub welcome_screen: sqlx::types::Json>, #[cfg(not(feature = "sqlx"))] pub welcome_screen: Option, pub widget_channel_id: Option, diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index caec03f..bb1d400 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -16,7 +16,9 @@ use super::{Application, Channel, GuildMember, NSFWLevel, User}; #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct Invite { + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_member_count: Option, + #[cfg_attr(feature = "sqlx", sqlx(skip))] pub approximate_presence_count: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub channel: Option, @@ -45,7 +47,7 @@ pub struct Invite { #[cfg_attr(feature = "sqlx", sqlx(skip))] pub target_user: Option, pub temporary: Option, - pub uses: Option, + pub uses: Option, } /// The guild an invite is for. @@ -77,16 +79,16 @@ impl From for InviteGuild { icon: value.icon, splash: value.splash, verification_level: value.verification_level.unwrap_or_default(), - features: value.features.unwrap_or_default(), + features: value.features, vanity_url_code: value.vanity_url_code, description: value.description, banner: value.banner, premium_subscription_count: value.premium_subscription_count, nsfw_deprecated: None, nsfw_level: value.nsfw_level.unwrap_or_default(), - welcome_screen: value.welcome_screen.map(|obj| { + welcome_screen: value.welcome_screen.0.map(|obj| { #[cfg(feature = "sqlx")] - let res = obj.0; + let res = obj; #[cfg(not(feature = "sqlx"))] let res = obj; res diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index 697be0d..e31b7be 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -75,13 +75,14 @@ pub struct Message { pub interaction: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub thread: Option, - #[cfg_attr(feature = "sqlx", sqlx(skip))] + #[cfg(feature = "sqlx")] + pub components: Option>>, + #[cfg(not(feature = "sqlx"))] pub components: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub sticker_items: Option>, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub stickers: Option>, - pub position: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] pub role_subscription_data: Option, } @@ -116,7 +117,7 @@ impl PartialEq for Message { && self.thread == other.thread && self.components == other.components && self.sticker_items == other.sticker_items - && self.position == other.position + // && self.position == other.position && self.role_subscription_data == other.role_subscription_data } } @@ -125,12 +126,22 @@ impl PartialEq for Message { /// # Reference /// See pub struct MessageReference { + #[serde(rename = "type")] + pub reference_type: MessageReferenceType, pub message_id: Snowflake, pub channel_id: Snowflake, pub guild_id: Option, pub fail_if_not_exists: Option, } +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq, Ord, PartialOrd)] +pub enum MessageReferenceType { + /// A standard reference used by replies and system messages + Default = 0, + /// A reference used to point to a message at a point in time + Forward = 1, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct MessageInteraction { pub id: Snowflake, @@ -404,4 +415,22 @@ bitflags! { /// This message has a forwarded message snapshot attached const HAS_SNAPSHOT = 1 << 14; } +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct PartialEmoji { + #[serde(default)] + pub id: Option, + pub name: String, + #[serde(default)] + pub animated: bool +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, PartialOrd)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[cfg_attr(feature = "sqlx", derive(sqlx::Type))] +#[repr(u8)] +pub enum ReactionType { + Normal = 0, + Burst = 1, // The dreaded super reactions } \ No newline at end of file diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index fc2eea6..bcebe5a 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -59,7 +59,7 @@ pub struct GetChannelMessagesSchema { /// Between 1 and 100, defaults to 50. pub limit: Option, #[serde(flatten)] - pub anchor: ChannelMessagesAnchor, + pub anchor: Option, } #[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)] @@ -74,21 +74,21 @@ impl GetChannelMessagesSchema { pub fn before(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Before(anchor), + anchor: Some(ChannelMessagesAnchor::Before(anchor)), } } pub fn around(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::Around(anchor), + anchor: Some(ChannelMessagesAnchor::Around(anchor)), } } pub fn after(anchor: Snowflake) -> Self { Self { limit: None, - anchor: ChannelMessagesAnchor::After(anchor), + anchor: Some(ChannelMessagesAnchor::After(anchor)), } } diff --git a/src/types/schema/message.rs b/src/types/schema/message.rs index a7c1eeb..4a7f2ab 100644 --- a/src/types/schema/message.rs +++ b/src/types/schema/message.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::types::entities::{ AllowedMention, Component, Embed, MessageReference, PartialDiscordFileAttachment, }; -use crate::types::{Attachment, MessageFlags, MessageType, Snowflake}; +use crate::types::{Attachment, MessageFlags, MessageType, ReactionType, Snowflake}; #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] @@ -118,13 +118,21 @@ pub struct MessageAck { #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] pub struct MessageModifySchema { - content: Option, - embeds: Option>, - embed: Option, - allowed_mentions: Option, - components: Option>, - flags: Option, - files: Option>, - payload_json: Option, - attachments: Option>, + pub content: Option, + pub embeds: Option>, + pub embed: Option, + pub allowed_mentions: Option, + pub components: Option>, + pub flags: Option, + pub files: Option>, + pub payload_json: Option, + pub attachments: Option>, } + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] +pub struct ReactionQuerySchema { + pub after: Option, + pub limit: Option, + #[serde(rename = "type")] + pub reaction_type: Option +} \ No newline at end of file diff --git a/src/types/utils/rights.rs b/src/types/utils/rights.rs index 63978da..ec1cd9c 100644 --- a/src/types/utils/rights.rs +++ b/src/types/utils/rights.rs @@ -2,11 +2,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::num::ParseIntError; +use std::str::FromStr; use bitflags::bitflags; -use serde::{Deserialize, Serialize}; - -#[cfg(feature = "sqlx")] -use sqlx::{{Decode, Encode, MySql}, database::{HasArguments, HasValueRef}, encode::IsNull, error::BoxDynError, mysql::MySqlValueRef}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::types::UserFlags; bitflags! { /// Rights are instance-wide, per-user permissions for everything you may perform on the instance, @@ -18,7 +18,7 @@ bitflags! { /// /// # Reference /// See - #[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] + #[derive(Debug, Clone, Copy, Eq, PartialEq, chorus_macros::SqlxBitFlags)] pub struct Rights: u64 { /// All rights const OPERATOR = 1 << 0; @@ -132,33 +132,28 @@ bitflags! { } } -#[cfg(feature = "sqlx")] -impl sqlx::Type for Rights { - fn type_info() -> ::TypeInfo { - u64::type_info() - } +impl FromStr for Rights { + type Err = ParseIntError; - fn compatible(ty: &::TypeInfo) -> bool { - u64::compatible(ty) + fn from_str(s: &str) -> Result { + s.parse::().map(Rights::from_bits).map(|f| f.unwrap_or(Rights::empty())) } } -#[cfg(feature = "sqlx")] -impl<'q> Encode<'q, MySql> for Rights { - fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { - >::encode_by_ref(&self.0.0, buf) +impl Serialize for Rights { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.bits().to_string()) } } -#[cfg(feature = "sqlx")] -impl<'r> Decode<'r, MySql> for Rights { - fn decode(value: >::ValueRef) -> Result { - let raw = >::decode(value)?; - Ok(Rights::from_bits(raw).unwrap()) +impl<'de> Deserialize<'de> for Rights { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?.parse::().map_err(serde::de::Error::custom)?; + + Ok(Rights::from_bits(s).unwrap()) } } - impl Rights { pub fn any(&self, permission: Rights, check_operator: bool) -> bool { (check_operator && self.contains(Rights::OPERATOR)) || self.contains(permission) diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index 1582085..341cc50 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -8,6 +8,9 @@ use std::{ }; use chrono::{DateTime, TimeZone, Utc}; +use sqlx::{MySql, TypeInfo}; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; #[cfg(feature = "sqlx")] use sqlx::Type; @@ -19,8 +22,6 @@ const EPOCH: i64 = 1420070400000; /// # Reference /// See #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "sqlx", derive(Type))] -#[cfg_attr(feature = "sqlx", sqlx(transparent))] pub struct Snowflake(pub u64); impl Snowflake { @@ -102,6 +103,24 @@ impl<'de> serde::Deserialize<'de> for Snowflake { } } +impl sqlx::Type for Snowflake { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl<'q> sqlx::Encode<'q, sqlx::MySql> for Snowflake { + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.0.to_string(), buf) + } +} + +impl<'d> sqlx::Decode<'d, sqlx::MySql> for Snowflake { + fn decode(value: >::ValueRef) -> Result { + >::decode(value).map(|s| s.parse::().map(Snowflake).unwrap()) + } +} + #[cfg(test)] mod test { use chrono::{DateTime, Utc};