-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Introduce user-defined signature #10439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
ff441ea
8b6016b
36abe3e
0988efe
a64b813
5bbd2a0
6f0a90b
5cc047b
a515aad
17e6ec1
5eaacc8
a606581
40d5444
ca7f942
5cb0d0f
68fdf52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,15 +91,10 @@ pub enum TypeSignature { | |
| /// # Examples | ||
| /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` | ||
| Variadic(Vec<DataType>), | ||
| /// One or more arguments of an arbitrary but equal type. | ||
| /// DataFusion attempts to coerce all argument types to match the first argument's type | ||
| /// | ||
| /// # Examples | ||
| /// Given types in signature should be coercible to the same final type. | ||
| /// A function such as `make_array` is `VariadicEqual`. | ||
| /// | ||
| /// `make_array(i32, i64) -> make_array(i64, i64)` | ||
| VariadicEqual, | ||
| /// One or more arguments of an arbitrary type and coerced with user-defined coercion rules. | ||
|
||
| VariadicCoercion, | ||
| /// Fixed number of arguments of an arbitrary type and coerced with user-defined coercion rules. | ||
| UniformCoercion(usize), | ||
| /// One or more arguments with arbitrary types | ||
| VariadicAny, | ||
| /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. | ||
|
|
@@ -190,9 +185,15 @@ impl TypeSignature { | |
| .collect::<Vec<&str>>() | ||
| .join(", ")] | ||
| } | ||
| TypeSignature::VariadicEqual => { | ||
| TypeSignature::VariadicCoercion => { | ||
| vec!["CoercibleT, .., CoercibleT".to_string()] | ||
| } | ||
| TypeSignature::UniformCoercion(arg_count) => { | ||
| vec![std::iter::repeat("CoercibleT") | ||
| .take(*arg_count) | ||
| .collect::<Vec<&str>>() | ||
| .join(", ")] | ||
| } | ||
| TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], | ||
| TypeSignature::OneOf(sigs) => { | ||
| sigs.iter().flat_map(|s| s.to_string_repr()).collect() | ||
|
|
@@ -255,10 +256,17 @@ impl Signature { | |
| volatility, | ||
| } | ||
| } | ||
| /// An arbitrary number of arguments of the same type. | ||
| pub fn variadic_equal(volatility: Volatility) -> Self { | ||
| /// An arbitrary number of arguments with user-defined coercion rules. | ||
| pub fn variadic_coercion(volatility: Volatility) -> Self { | ||
| Self { | ||
| type_signature: TypeSignature::VariadicCoercion, | ||
| volatility, | ||
| } | ||
| } | ||
| /// Fixed number of arguments with user-defined coercion rules. | ||
| pub fn uniform_coercion(num: usize, volatility: Volatility) -> Self { | ||
| Self { | ||
| type_signature: TypeSignature::VariadicEqual, | ||
| type_signature: TypeSignature::UniformCoercion(num), | ||
| volatility, | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,7 @@ use std::sync::Arc; | |||||
| use crate::signature::{ | ||||||
| ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, | ||||||
| }; | ||||||
| use crate::{Signature, TypeSignature}; | ||||||
| use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; | ||||||
| use arrow::{ | ||||||
| compute::can_cast_types, | ||||||
| datatypes::{DataType, TimeUnit}, | ||||||
|
|
@@ -30,13 +30,112 @@ use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result} | |||||
|
|
||||||
| use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; | ||||||
|
|
||||||
| /// Performs type coercion for scalar function arguments. | ||||||
| /// | ||||||
| /// Returns the data types to which each argument must be coerced to | ||||||
| /// match `signature`. | ||||||
| /// | ||||||
| /// For more details on coercion in general, please see the | ||||||
| /// [`type_coercion`](crate::type_coercion) module. | ||||||
| pub fn data_types_with_scalar_udf( | ||||||
| current_types: &[DataType], | ||||||
| func: &ScalarUDF, | ||||||
| ) -> Result<Vec<DataType>> { | ||||||
| let signature = func.signature(); | ||||||
|
|
||||||
| if current_types.is_empty() { | ||||||
| if signature.type_signature.supports_zero_argument() { | ||||||
| return Ok(vec![]); | ||||||
| } else { | ||||||
| return plan_err!( | ||||||
| "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.", | ||||||
| &signature.type_signature | ||||||
| ); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| let valid_types = | ||||||
| get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; | ||||||
|
|
||||||
| if valid_types | ||||||
| .iter() | ||||||
| .any(|data_type| data_type == current_types) | ||||||
| { | ||||||
| return Ok(current_types.to_vec()); | ||||||
| } | ||||||
|
|
||||||
| // Try and coerce the argument types to match the signature, returning the | ||||||
| // coerced types from the first matching signature. | ||||||
| for valid_types in valid_types { | ||||||
| if let Some(types) = maybe_data_types(&valid_types, current_types) { | ||||||
| return Ok(types); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // none possible -> Error | ||||||
| plan_err!( | ||||||
| "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.", | ||||||
| current_types, | ||||||
| &signature.type_signature | ||||||
| ) | ||||||
| } | ||||||
|
|
||||||
| pub fn data_types_with_aggregate_udf( | ||||||
| current_types: &[DataType], | ||||||
| func: &AggregateUDF, | ||||||
| ) -> Result<Vec<DataType>> { | ||||||
| let signature = func.signature(); | ||||||
|
|
||||||
| if current_types.is_empty() { | ||||||
| if signature.type_signature.supports_zero_argument() { | ||||||
| return Ok(vec![]); | ||||||
| } else { | ||||||
| return plan_err!( | ||||||
| "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", | ||||||
| current_types, | ||||||
| &signature.type_signature | ||||||
| ); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| let valid_types = get_valid_types_with_aggregate_udf( | ||||||
| &signature.type_signature, | ||||||
| current_types, | ||||||
| func, | ||||||
| )?; | ||||||
| if valid_types | ||||||
| .iter() | ||||||
| .any(|data_type| data_type == current_types) | ||||||
| { | ||||||
| return Ok(current_types.to_vec()); | ||||||
| } | ||||||
|
|
||||||
| // Try and coerce the argument types to match the signature, returning the | ||||||
| // coerced types from the first matching signature. | ||||||
| for valid_types in valid_types { | ||||||
| if let Some(types) = maybe_data_types(&valid_types, current_types) { | ||||||
| return Ok(types); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // none possible -> Error | ||||||
| plan_err!( | ||||||
| "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", | ||||||
| current_types, | ||||||
| &signature.type_signature | ||||||
| ) | ||||||
| } | ||||||
|
|
||||||
| /// Performs type coercion for function arguments. | ||||||
| /// | ||||||
| /// Returns the data types to which each argument must be coerced to | ||||||
| /// match `signature`. | ||||||
| /// | ||||||
| /// For more details on coercion in general, please see the | ||||||
| /// [`type_coercion`](crate::type_coercion) module. | ||||||
| /// | ||||||
| /// This function will be replaced with [data_types_with_scalar_udf], | ||||||
|
||||||
| /// [data_types_with_aggregate_udf], and data_types_with_window_udf gradually. | ||||||
| pub fn data_types( | ||||||
| current_types: &[DataType], | ||||||
| signature: &Signature, | ||||||
|
|
@@ -46,7 +145,7 @@ pub fn data_types( | |||||
| return Ok(vec![]); | ||||||
| } else { | ||||||
| return plan_err!( | ||||||
| "Coercion from {:?} to the signature {:?} failed.", | ||||||
| "[data_types] Coercion from {:?} to the signature {:?} failed.", | ||||||
| current_types, | ||||||
| &signature.type_signature | ||||||
| ); | ||||||
|
|
@@ -72,12 +171,74 @@ pub fn data_types( | |||||
|
|
||||||
| // none possible -> Error | ||||||
| plan_err!( | ||||||
| "Coercion from {:?} to the signature {:?} failed.", | ||||||
| "[data_types] Coercion from {:?} to the signature {:?} failed.", | ||||||
| current_types, | ||||||
| &signature.type_signature | ||||||
| ) | ||||||
| } | ||||||
|
|
||||||
| fn get_valid_types_with_scalar_udf( | ||||||
| signature: &TypeSignature, | ||||||
| current_types: &[DataType], | ||||||
| func: &ScalarUDF, | ||||||
| ) -> Result<Vec<Vec<DataType>>> { | ||||||
| let valid_types = match signature { | ||||||
| TypeSignature::VariadicCoercion => { | ||||||
| vec![func.coerce_types(current_types)?] | ||||||
| } | ||||||
| TypeSignature::UniformCoercion(num) => { | ||||||
| if *num != current_types.len() { | ||||||
| return plan_err!( | ||||||
| "The function expected {} arguments but received {}", | ||||||
| num, | ||||||
| current_types.len() | ||||||
| ); | ||||||
| } | ||||||
| vec![func.coerce_types(current_types)?] | ||||||
| } | ||||||
| TypeSignature::OneOf(signatures) => signatures | ||||||
| .iter() | ||||||
| .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) | ||||||
| .flatten() | ||||||
| .collect::<Vec<_>>(), | ||||||
| _ => get_valid_types(signature, current_types)?, | ||||||
| }; | ||||||
|
|
||||||
| Ok(valid_types) | ||||||
| } | ||||||
|
|
||||||
| fn get_valid_types_with_aggregate_udf( | ||||||
| signature: &TypeSignature, | ||||||
| current_types: &[DataType], | ||||||
| func: &AggregateUDF, | ||||||
| ) -> Result<Vec<Vec<DataType>>> { | ||||||
| let valid_types = match signature { | ||||||
| TypeSignature::VariadicCoercion => { | ||||||
| vec![func.coerce_types(current_types)?] | ||||||
| } | ||||||
| TypeSignature::UniformCoercion(num) => { | ||||||
| if *num != current_types.len() { | ||||||
| return plan_err!( | ||||||
| "The function expected {} arguments but received {}", | ||||||
| num, | ||||||
| current_types.len() | ||||||
| ); | ||||||
| } | ||||||
| vec![func.coerce_types(current_types)?] | ||||||
| } | ||||||
| TypeSignature::OneOf(signatures) => signatures | ||||||
| .iter() | ||||||
| .filter_map(|t| { | ||||||
| get_valid_types_with_aggregate_udf(t, current_types, func).ok() | ||||||
| }) | ||||||
| .flatten() | ||||||
| .collect::<Vec<_>>(), | ||||||
| _ => get_valid_types(signature, current_types)?, | ||||||
| }; | ||||||
|
|
||||||
| Ok(valid_types) | ||||||
| } | ||||||
|
|
||||||
| /// Returns a Vec of all possible valid argument types for the given signature. | ||||||
| fn get_valid_types( | ||||||
| signature: &TypeSignature, | ||||||
|
|
@@ -184,32 +345,14 @@ fn get_valid_types( | |||||
| .iter() | ||||||
| .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) | ||||||
| .collect(), | ||||||
| TypeSignature::VariadicEqual => { | ||||||
| let new_type = current_types.iter().skip(1).try_fold( | ||||||
| current_types.first().unwrap().clone(), | ||||||
| |acc, x| { | ||||||
| // The coerced types found by `comparison_coercion` are not guaranteed to be | ||||||
| // coercible for the arguments. `comparison_coercion` returns more loose | ||||||
| // types that can be coerced to both `acc` and `x` for comparison purpose. | ||||||
| // See `maybe_data_types` for the actual coercion. | ||||||
| let coerced_type = comparison_coercion(&acc, x); | ||||||
| if let Some(coerced_type) = coerced_type { | ||||||
| Ok(coerced_type) | ||||||
| } else { | ||||||
| internal_err!("Coercion from {acc:?} to {x:?} failed.") | ||||||
| } | ||||||
| }, | ||||||
| ); | ||||||
|
|
||||||
| match new_type { | ||||||
| Ok(new_type) => vec![vec![new_type; current_types.len()]], | ||||||
| Err(e) => return Err(e), | ||||||
| } | ||||||
| TypeSignature::VariadicCoercion | TypeSignature::UniformCoercion(_) => { | ||||||
| return internal_err!( | ||||||
| "Coercion signature is handled in function-specific get_valid_types." | ||||||
|
||||||
| "Coercion signature is handled in function-specific get_valid_types." | |
| "Coercion signature should be handled by in function-specific coerce_types." |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -213,6 +213,11 @@ impl ScalarUDF { | |||||||||||||||||||||||||||||||||||||||||
| pub fn short_circuits(&self) -> bool { | ||||||||||||||||||||||||||||||||||||||||||
| self.inner.short_circuits() | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /// See [`ScalarUDFImpl::coerce_types`] for more details. | ||||||||||||||||||||||||||||||||||||||||||
| pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||||||||||||||||||||||||||||||||||||||||||
| self.inner.coerce_types(arg_types) | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| impl<F> From<F> for ScalarUDF | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -420,6 +425,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |||||||||||||||||||||||||||||||||||||||||
| fn short_circuits(&self) -> bool { | ||||||||||||||||||||||||||||||||||||||||||
| false | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /// Coerce the types of the arguments to the types that the function expects | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| /// Coerce the types of the arguments to the types that the function expects | |
| /// Coerce arguments of a function call to types that the function can evaluate. | |
| /// | |
| /// This function is only called if [`Self::signature`] returns [`TypeSignature::UserDefined`]. Most | |
| /// UDFs should return one of the other variants of `TypeSignature` which handle common | |
| /// cases | |
| /// | |
| /// See the [type coercion module](datafusion_expr::type_coercion) | |
| /// documentation for more details on type coercion | |
| /// | |
| /// For example, if your function requires a floating point arguments, but the user calls | |
| /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` | |
| /// to ensure the argument was cast to `1::double` | |
| /// | |
| /// # Parameters | |
| /// * `arg_types`: The argument types of the arguments this function with | |
| /// | |
| /// # Return value | |
| /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call | |
| /// arguments to these specific types. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep error for debugging