Skip to content

Commit

Permalink
Fix argument parsing broken by serenity 0.12 port (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev authored Nov 30, 2023
1 parent 1ccd19a commit bcd8958
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 120 deletions.
14 changes: 1 addition & 13 deletions macros/src/command/slash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,7 @@ pub fn generate_slash_action(inv: &Invocation) -> Result<proc_macro2::TokenStrea
let ( #( #param_identifiers, )* ) = ::poise::parse_slash_args!(
ctx.serenity_context, ctx.interaction, ctx.args =>
#( (#param_names: #param_types), )*
).await.map_err(|error| match error {
poise::SlashArgError::CommandStructureMismatch { description, .. } => {
poise::FrameworkError::new_command_structure_mismatch(ctx, description)
},
poise::SlashArgError::Parse { error, input, .. } => {
poise::FrameworkError::new_argument_parse(
ctx.into(),
Some(input),
error,
)
},
poise::SlashArgError::__NonExhaustive => unreachable!(),
})?;
).await.map_err(|error| error.to_framework_error(ctx))?;

if !ctx.framework.options.manual_cooldowns {
ctx.command.cooldowns.lock().unwrap().start_cooldown(ctx.cooldown_context());
Expand Down
54 changes: 50 additions & 4 deletions src/slash_argument/slash_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,73 @@ pub enum SlashArgError {
/// Original input string
input: String,
},
/// The argument passed by the user is invalid in this context. E.g. a Member parameter in DMs
#[non_exhaustive]
Invalid(
/// Human readable description of the error
&'static str,
),
/// HTTP error occured while retrieving the model type from Discord
Http(serenity::Error),
#[doc(hidden)]
__NonExhaustive,
}

/// Support functions for macro which can't create #[non_exhaustive] enum variants
#[doc(hidden)]
impl SlashArgError {
pub fn new_command_structure_mismatch(description: &'static str) -> Self {
Self::CommandStructureMismatch { description }
}

pub fn to_framework_error<U, E>(
self,
ctx: crate::ApplicationContext<'_, U, E>,
) -> crate::FrameworkError<'_, U, E> {
match self {
Self::CommandStructureMismatch { description } => {
crate::FrameworkError::CommandStructureMismatch { ctx, description }
}
Self::Parse { error, input } => crate::FrameworkError::ArgumentParse {
ctx: ctx.into(),
error,
input: Some(input),
},
Self::Invalid(description) => crate::FrameworkError::ArgumentParse {
ctx: ctx.into(),
error: description.into(),
input: None,
},
Self::Http(error) => crate::FrameworkError::ArgumentParse {
ctx: ctx.into(),
error: error.into(),
input: None,
},
Self::__NonExhaustive => unreachable!(),
}
}
}

impl std::fmt::Display for SlashArgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::CommandStructureMismatch { description } => {
write!(
f,
"Bot author did not register their commands correctly ({})",
description
"Bot author did not register their commands correctly ({description})",
)
}
Self::Parse { error, input } => {
write!(f, "Failed to parse `{}` as argument: {}", input, error)
write!(f, "Failed to parse `{input}` as argument: {error}")
}
Self::Invalid(description) => {
write!(f, "You can't use this parameter here: {description}",)
}
Self::Http(error) => {
write!(
f,
"Error occured while retrieving data from Discord: {error}",
)
}
Self::__NonExhaustive => unreachable!(),
}
Expand All @@ -56,8 +101,9 @@ impl std::fmt::Display for SlashArgError {
impl std::error::Error for SlashArgError {
fn cause(&self) -> Option<&dyn std::error::Error> {
match self {
Self::Http(error) => Some(error),
Self::Parse { error, input: _ } => Some(&**error),
Self::CommandStructureMismatch { description: _ } => None,
Self::Invalid { .. } | Self::CommandStructureMismatch { .. } => None,
Self::__NonExhaustive => unreachable!(),
}
}
Expand Down
162 changes: 59 additions & 103 deletions src/slash_argument/slash_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,16 @@ macro_rules! impl_for_integer {
_: &serenity::CommandInteraction,
value: &serenity::ResolvedValue<'_>,
) -> Result<$t, SlashArgError> {
let value = match value {
serenity::ResolvedValue::Integer(int) => *int,
_ => return Err(SlashArgError::CommandStructureMismatch { description: "expected integer" })
};

value
.try_into()
.map_err(|_| SlashArgError::CommandStructureMismatch { description: "received out of bounds integer" })
match *value {
serenity::ResolvedValue::Integer(x) => x
.try_into()
.map_err(|_| SlashArgError::CommandStructureMismatch {
description: "received out of bounds integer",
}),
_ => Err(SlashArgError::CommandStructureMismatch {
description: "expected integer",
}),
}
}

fn create(builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption {
Expand All @@ -160,90 +162,6 @@ macro_rules! impl_for_integer {
}
impl_for_integer!(i8 i16 i32 i64 isize u8 u16 u32 u64 usize);

/// Implements slash argument trait for float types
macro_rules! impl_for_float {
($($t:ty)*) => { $(
#[async_trait::async_trait]
impl SlashArgumentHack<$t> for &PhantomData<$t> {
async fn extract(
self,
_: &serenity::Context,
_: &serenity::CommandInteraction,
value: &serenity::ResolvedValue<'_>,
) -> Result<$t, SlashArgError> {
match value {
serenity::ResolvedValue::Number(float) => Ok(*float as $t),
_ => Err(SlashArgError::CommandStructureMismatch { description: "expected float" })
}
}

fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption {
builder.kind(serenity::CommandOptionType::Number)
}
}
)* };
}
impl_for_float!(f32 f64);

#[async_trait::async_trait]
impl SlashArgumentHack<bool> for &PhantomData<bool> {
async fn extract(
self,
_: &serenity::Context,
_: &serenity::CommandInteraction,
value: &serenity::ResolvedValue<'_>,
) -> Result<bool, SlashArgError> {
match value {
serenity::ResolvedValue::Boolean(val) => Ok(*val),
_ => Err(SlashArgError::CommandStructureMismatch {
description: "expected bool",
}),
}
}

fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption {
builder.kind(serenity::CommandOptionType::Boolean)
}
}

#[async_trait::async_trait]
impl SlashArgumentHack<serenity::Attachment> for &PhantomData<serenity::Attachment> {
async fn extract(
self,
_: &serenity::Context,
interaction: &serenity::CommandInteraction,
value: &serenity::ResolvedValue<'_>,
) -> Result<serenity::Attachment, SlashArgError> {
let attachment_id = match value {
serenity::ResolvedValue::String(val) => {
val.parse()
.map_err(|_| SlashArgError::CommandStructureMismatch {
description: "improper attachment id passed",
})?
}
_ => {
return Err(SlashArgError::CommandStructureMismatch {
description: "expected attachment id",
})
}
};

interaction
.data
.resolved
.attachments
.get(&attachment_id)
.cloned()
.ok_or(SlashArgError::CommandStructureMismatch {
description: "attachment id with no attachment",
})
}

fn create(self, builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption {
builder.kind(serenity::CommandOptionType::Attachment)
}
}

#[async_trait::async_trait]
impl<T: SlashArgument + Sync> SlashArgumentHack<T> for &PhantomData<T> {
async fn extract(
Expand All @@ -264,18 +182,22 @@ impl<T: SlashArgument + Sync> SlashArgumentHack<T> for &PhantomData<T> {
}
}

/// Implements `SlashArgumentHack` for a model type that is represented in interactions via an ID
/// Versatile macro to implement `SlashArgumentHack` for simple types
macro_rules! impl_slash_argument {
($type:ty, $slash_param_type:ident) => {
($type:ty, |$ctx:pat, $interaction:pat, $slash_param_type:ident ( $($arg:pat),* )| $extractor:expr) => {
#[async_trait::async_trait]
impl SlashArgument for $type {
async fn extract(
ctx: &serenity::Context,
interaction: &serenity::CommandInteraction,
$ctx: &serenity::Context,
$interaction: &serenity::CommandInteraction,
value: &serenity::ResolvedValue<'_>,
) -> Result<$type, SlashArgError> {
// We can parse IDs by falling back to the generic serenity::ArgumentConvert impl
PhantomData::<$type>.extract(ctx, interaction, value).await
match *value {
serenity::ResolvedValue::$slash_param_type( $($arg),* ) => Ok( $extractor ),
_ => Err(SlashArgError::CommandStructureMismatch {
description: concat!("expected ", stringify!($slash_param_type))
}),
}
}

fn create(builder: serenity::CreateCommandOption) -> serenity::CreateCommandOption {
Expand All @@ -284,8 +206,42 @@ macro_rules! impl_slash_argument {
}
};
}
impl_slash_argument!(serenity::Member, User);
impl_slash_argument!(serenity::User, User);
impl_slash_argument!(serenity::Channel, Channel);
impl_slash_argument!(serenity::GuildChannel, Channel);
impl_slash_argument!(serenity::Role, Role);

impl_slash_argument!(f32, |_, _, Number(x)| x as f32);
impl_slash_argument!(f64, |_, _, Number(x)| x);
impl_slash_argument!(bool, |_, _, Boolean(x)| x);
impl_slash_argument!(serenity::Attachment, |_, _, Attachment(att)| att.clone());
impl_slash_argument!(serenity::Member, |ctx, interaction, User(user, _)| {
interaction
.guild_id
.ok_or(SlashArgError::Invalid("cannot use member parameter in DMs"))?
.member(ctx, user.id)
.await
.map_err(SlashArgError::Http)?
});
impl_slash_argument!(serenity::PartialMember, |_, _, User(_, member)| {
member
.ok_or(SlashArgError::Invalid("cannot use member parameter in DMs"))?
.clone()
});
impl_slash_argument!(serenity::User, |_, _, User(user, _)| user.clone());
impl_slash_argument!(serenity::UserId, |_, _, User(user, _)| user.id);
impl_slash_argument!(serenity::Channel, |ctx, _, Channel(channel)| {
channel
.id
.to_channel(ctx)
.await
.map_err(SlashArgError::Http)?
});
impl_slash_argument!(serenity::ChannelId, |_, _, Channel(channel)| channel.id);
impl_slash_argument!(serenity::PartialChannel, |_, _, Channel(channel)| channel
.clone());
impl_slash_argument!(serenity::GuildChannel, |ctx, _, Channel(channel)| {
let channel_res = channel.id.to_channel(ctx).await;
let channel = channel_res.map_err(SlashArgError::Http)?.guild();
channel.ok_or(SlashArgError::Http(serenity::Error::Model(
serenity::ModelError::InvalidChannelType,
)))?
});
impl_slash_argument!(serenity::Role, |_, _, Role(role)| role.clone());
impl_slash_argument!(serenity::RoleId, |_, _, Role(role)| role.id);

0 comments on commit bcd8958

Please sign in to comment.