Skip to content

Commit

Permalink
Refactor scalar UDF signatures to use Signature::nullary for zero-a…
Browse files Browse the repository at this point in the history
…rgument functions

- Updated the `Simple0ArgsScalarUDF` to utilize `Signature::nullary` instead of `Signature::exact`.
- Modified tests to reflect the change in signature handling for zero-argument functions.
- Enhanced error messages in type coercion functions to clarify the requirement for nullary signatures.
- Cleaned up redundant checks and improved code readability in the type coercion logic.

This change improves consistency in how zero-argument functions are defined and validated across the codebase.
  • Loading branch information
jayzhan211 committed Dec 22, 2024
1 parent 8b5e1e9 commit a9d2f81
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ async fn scalar_udf_zero_params() -> Result<()> {

let get_100_udf = Simple0ArgsScalarUDF {
name: "get_100".to_string(),
signature: Signature::exact(vec![], Volatility::Immutable),
signature: Signature::nullary(Volatility::Immutable),
return_type: DataType::Int32,
};

Expand Down Expand Up @@ -1121,11 +1121,7 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {

#[tokio::test]
async fn test_valid_zero_argument_signatures() {
let signatures = vec![
Signature::exact(vec![], Volatility::Immutable),
Signature::any(0, Volatility::Immutable),
Signature::nullary(Volatility::Immutable),
];
let signatures = vec![Signature::nullary(Volatility::Immutable)];
for signature in signatures {
let ctx = SessionContext::new();
let udf = ScalarFunctionWrapper {
Expand Down Expand Up @@ -1161,6 +1157,8 @@ async fn test_invalid_zero_argument_signatures() {
Signature::uniform(0, vec![], Volatility::Immutable),
Signature::coercible(vec![], Volatility::Immutable),
Signature::comparable(0, Volatility::Immutable),
Signature::any(0, Volatility::Immutable),
Signature::exact(vec![], Volatility::Immutable),
];
for signature in signatures {
let ctx = SessionContext::new();
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ impl TypeSignature {
/// Check whether 0 input argument is valid for given `TypeSignature`
pub fn supports_zero_argument(&self) -> bool {
match &self {
TypeSignature::Exact(vec) => vec.is_empty(),
TypeSignature::Any(0) => true,
TypeSignature::Nullary => true,
TypeSignature::OneOf(types) => types
.iter()
Expand Down
103 changes: 29 additions & 74 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub fn data_types_with_scalar_udf(
if signature.type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
return plan_err!("{} does not support zero arguments. Please add TypeSignature::Nullary to your function's signature", func.name());
}
}

Expand Down Expand Up @@ -88,34 +88,27 @@ pub fn data_types_with_aggregate_udf(
current_types: &[DataType],
func: &AggregateUDF,
) -> Result<Vec<DataType>> {
let signature = func.signature();
let type_signature = &func.signature().type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
}

let valid_types = get_valid_types_with_aggregate_udf(
&signature.type_signature,
current_types,
func,
)?;
let valid_types =
get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
func.name(),
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(func.name(), valid_types, current_types, type_signature)
}

/// Performs type coercion for window function arguments.
Expand All @@ -129,31 +122,27 @@ pub fn data_types_with_window_udf(
current_types: &[DataType],
func: &WindowUDF,
) -> Result<Vec<DataType>> {
let signature = func.signature();
let type_signature = &func.signature().type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
}

let valid_types =
get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?;
get_valid_types_with_window_udf(type_signature, current_types, func)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
func.name(),
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(func.name(), valid_types, current_types, type_signature)
}

/// Performs type coercion for function arguments.
Expand All @@ -168,31 +157,28 @@ pub fn data_types(
current_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!(
"signature {:?} does not support zero arguments.",
&signature.type_signature
"{} does not support zero arguments.",
function_name.as_ref()
);
}
}

let valid_types = get_valid_types(&signature.type_signature, current_types)?;
let valid_types = get_valid_types(type_signature, current_types)?;
if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
function_name,
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(function_name, valid_types, current_types, type_signature)
}

fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
Expand Down Expand Up @@ -335,6 +321,7 @@ fn get_valid_types_with_window_udf(
}

/// Returns a Vec of all possible valid argument types for the given signature.
/// Empty argument is checked by the caller so no need to re-check here.
fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
Expand Down Expand Up @@ -441,12 +428,6 @@ fn get_valid_types(
}

fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
if length < 1 {
return plan_err!(
"The signature expected at least one argument but received {expected_length}"
);
}

if length != expected_length {
return plan_err!(
"The signature expected {length} arguments but received {expected_length}"
Expand Down Expand Up @@ -645,27 +626,16 @@ fn get_valid_types(

vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => {
if *number == 0 {
return plan_err!("The function expected at least one argument");
}

valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect()
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::UserDefined => {
return internal_err!(
"User-defined signature should be handled by function-specific coerce_types."
)
}
TypeSignature::VariadicAny => {
if current_types.is_empty() {
return plan_err!(
"The function expected at least one argument but received 0"
);
}
vec![current_types.to_vec()]
}
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
Expand Down Expand Up @@ -716,28 +686,13 @@ fn get_valid_types(
}
},
TypeSignature::Nullary => {
if !current_types.is_empty() {
return plan_err!(
"The function expected zero argument but received {}",
current_types.len()
);
}
vec![vec![]]
return plan_err!(
"Nullary expects zero arguments, but received {}",
current_types.len()
);
}
TypeSignature::Any(number) => {
if current_types.is_empty() {
return plan_err!(
"The function expected at least one argument but received 0"
);
}

if current_types.len() != *number {
return plan_err!(
"The function expected {} arguments but received {}",
number,
current_types.len()
);
}
function_length_check(current_types.len(), *number)?;
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
TypeSignature::OneOf(types) => types
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/uuid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Default for UuidFunc {
impl UuidFunc {
pub fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
signature: Signature::nullary(Volatility::Volatile),
}
}
}
Expand Down

0 comments on commit a9d2f81

Please sign in to comment.