diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 161a442df7d7..817eafee4087 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -270,7 +270,7 @@ async fn scalar_udf_zero_params() -> Result<()> { let get_100_udf = Simple0ArgsScalarUDF { name: "get_100".to_string(), - signature: Signature::exact(vec![], Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), return_type: DataType::Int32, }; @@ -1121,11 +1121,7 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { #[tokio::test] async fn test_valid_zero_argument_signatures() { - let signatures = vec![ - Signature::exact(vec![], Volatility::Immutable), - Signature::any(0, Volatility::Immutable), - Signature::nullary(Volatility::Immutable), - ]; + let signatures = vec![Signature::nullary(Volatility::Immutable)]; for signature in signatures { let ctx = SessionContext::new(); let udf = ScalarFunctionWrapper { @@ -1161,6 +1157,8 @@ async fn test_invalid_zero_argument_signatures() { Signature::uniform(0, vec![], Volatility::Immutable), Signature::coercible(vec![], Volatility::Immutable), Signature::comparable(0, Volatility::Immutable), + Signature::any(0, Volatility::Immutable), + Signature::exact(vec![], Volatility::Immutable), ]; for signature in signatures { let ctx = SessionContext::new(); diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 20fb1e43c6dd..4a1f5cc0ad14 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -342,8 +342,6 @@ impl TypeSignature { /// Check whether 0 input argument is valid for given `TypeSignature` pub fn supports_zero_argument(&self) -> bool { match &self { - TypeSignature::Exact(vec) => vec.is_empty(), - TypeSignature::Any(0) => true, TypeSignature::Nullary => true, TypeSignature::OneOf(types) => types .iter() diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7d2906e1731b..6b43632c4045 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -55,7 +55,7 @@ pub fn data_types_with_scalar_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!("{} does not support zero arguments.", func.name()); + return plan_err!("{} does not support zero arguments. Please add TypeSignature::Nullary to your function's signature", func.name()); } } @@ -88,21 +88,19 @@ pub fn data_types_with_aggregate_udf( current_types: &[DataType], func: &AggregateUDF, ) -> Result> { - let signature = func.signature(); + let type_signature = &func.signature().type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!("{} does not support zero arguments.", func.name()); } } - let valid_types = get_valid_types_with_aggregate_udf( - &signature.type_signature, - current_types, - func, - )?; + let valid_types = + get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + if valid_types .iter() .any(|data_type| data_type == current_types) @@ -110,12 +108,7 @@ pub fn data_types_with_aggregate_udf( return Ok(current_types.to_vec()); } - try_coerce_types( - func.name(), - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for window function arguments. @@ -129,10 +122,10 @@ pub fn data_types_with_window_udf( current_types: &[DataType], func: &WindowUDF, ) -> Result> { - let signature = func.signature(); + let type_signature = &func.signature().type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!("{} does not support zero arguments.", func.name()); @@ -140,7 +133,8 @@ pub fn data_types_with_window_udf( } let valid_types = - get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, current_types, func)?; + if valid_types .iter() .any(|data_type| data_type == current_types) @@ -148,12 +142,7 @@ pub fn data_types_with_window_udf( return Ok(current_types.to_vec()); } - try_coerce_types( - func.name(), - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for function arguments. @@ -168,18 +157,20 @@ pub fn data_types( current_types: &[DataType], signature: &Signature, ) -> Result> { + let type_signature = &signature.type_signature; + if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!( - "signature {:?} does not support zero arguments.", - &signature.type_signature + "{} does not support zero arguments.", + function_name.as_ref() ); } } - let valid_types = get_valid_types(&signature.type_signature, current_types)?; + let valid_types = get_valid_types(type_signature, current_types)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -187,12 +178,7 @@ pub fn data_types( return Ok(current_types.to_vec()); } - try_coerce_types( - function_name, - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(function_name, valid_types, current_types, type_signature) } fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { @@ -335,6 +321,7 @@ fn get_valid_types_with_window_udf( } /// Returns a Vec of all possible valid argument types for the given signature. +/// Empty argument is checked by the caller so no need to re-check here. fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], @@ -441,12 +428,6 @@ fn get_valid_types( } fn function_length_check(length: usize, expected_length: usize) -> Result<()> { - if length < 1 { - return plan_err!( - "The signature expected at least one argument but received {expected_length}" - ); - } - if length != expected_length { return plan_err!( "The signature expected {length} arguments but received {expected_length}" @@ -645,27 +626,16 @@ fn get_valid_types( vec![new_types] } - TypeSignature::Uniform(number, valid_types) => { - if *number == 0 { - return plan_err!("The function expected at least one argument"); - } - - valid_types - .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) - .collect() - } + TypeSignature::Uniform(number, valid_types) => valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect(), TypeSignature::UserDefined => { return internal_err!( "User-defined signature should be handled by function-specific coerce_types." ) } TypeSignature::VariadicAny => { - if current_types.is_empty() { - return plan_err!( - "The function expected at least one argument but received 0" - ); - } vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], @@ -716,28 +686,13 @@ fn get_valid_types( } }, TypeSignature::Nullary => { - if !current_types.is_empty() { - return plan_err!( - "The function expected zero argument but received {}", - current_types.len() - ); - } - vec![vec![]] + return plan_err!( + "Nullary expects zero arguments, but received {}", + current_types.len() + ); } TypeSignature::Any(number) => { - if current_types.is_empty() { - return plan_err!( - "The function expected at least one argument but received 0" - ); - } - - if current_types.len() != *number { - return plan_err!( - "The function expected {} arguments but received {}", - number, - current_types.len() - ); - } + function_length_check(current_types.len(), *number)?; vec![(0..*number).map(|i| current_types[i].clone()).collect()] } TypeSignature::OneOf(types) => types diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 6048a70bd8c5..5f0b24232215 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -42,7 +42,7 @@ impl Default for UuidFunc { impl UuidFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } }