From cab4cb1ce63dc79b127ac60540246c50e791da5f Mon Sep 17 00:00:00 2001 From: Quat3rnion Date: Tue, 4 Jun 2024 23:07:04 -0400 Subject: [PATCH] Write custom serialize/deserialize impl's for InviteFlags --- src/types/errors.rs | 3 +++ src/types/schema/channel.rs | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/types/errors.rs b/src/types/errors.rs index f0a488c..f417aef 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -21,6 +21,9 @@ pub enum Error { #[error(transparent)] Guild(#[from] GuildError), + + #[error("Invalid flags value: {0}")] + InvalidFlags(u64) } #[derive(Debug, PartialEq, Eq, thiserror::Error)] diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index 68abc79..da49aec 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -2,10 +2,14 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::error::Error as StdError; +use std::num::ParseIntError; use bitflags::bitflags; -use serde::{Deserialize, Serialize}; +use bitflags::parser::ParseHex; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::de::{DeserializeOwned, Visitor}; -use crate::types::{ChannelType, DefaultReaction}; +use crate::types::{ChannelType, DefaultReaction, Error}; use crate::types::{entities::PermissionOverwrite, Snowflake}; #[derive(Debug, Deserialize, Serialize, Default, PartialEq, PartialOrd)] @@ -131,12 +135,39 @@ impl Default for CreateChannelInviteSchema { } bitflags! { - #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct InviteFlags: u64 { const GUEST = 1 << 0; } } +impl Serialize for InviteFlags { + fn serialize(&self, serializer: S) -> Result { + self.bits().to_string().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for InviteFlags { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + struct FlagsVisitor; + + impl<'de> Visitor<'de> for FlagsVisitor + { + type Value = InviteFlags; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a raw u64 value of flags") + } + + fn visit_u64(self, v: u64) -> Result where E: serde::de::Error { + InviteFlags::from_bits(v).ok_or(serde::de::Error::custom(Error::InvalidFlags(v))) + } + } + + deserializer.deserialize_u64(FlagsVisitor) + } +} + #[cfg(feature = "sqlx")] impl sqlx::Type for InviteFlags { fn type_info() -> sqlx::mysql::MySqlTypeInfo {