Skip to content
7 changes: 4 additions & 3 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::expr::{
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types;
use crate::type_coercion::functions::data_types_with_scalar_udf;
use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
Expand Down Expand Up @@ -139,9 +139,10 @@ impl ExprSchemable for Expr {
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types(&arg_data_types, func.signature()).map_err(|_| {
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
"{}",
"{} and {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep error for debugging

err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
Expand Down
23 changes: 10 additions & 13 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ pub enum TypeSignature {
/// # Examples
/// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
Variadic(Vec<DataType>),
/// One or more arguments of an arbitrary but equal type.
/// DataFusion attempts to coerce all argument types to match the first argument's type
/// The acceptable signature and the coercion rules used to coerce arguments to
/// this signature are specific to this function. If this signature is specified,
/// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types.
///
/// # Examples
/// Given types in signature should be coercible to the same final type.
/// A function such as `make_array` is `VariadicEqual`.
///
/// `make_array(i32, i64) -> make_array(i64, i64)`
VariadicEqual,
/// [`ScalarUDFImpl::coerce_types`]: crate::udf::ScalarUDFImpl::coerce_types
UserDefined,
/// One or more arguments with arbitrary types
VariadicAny,
/// Fixed number of arguments of an arbitrary but equal type out of a list of valid types.
Expand Down Expand Up @@ -190,8 +187,8 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
TypeSignature::VariadicEqual => {
vec!["CoercibleT, .., CoercibleT".to_string()]
TypeSignature::UserDefined => {
vec!["UserDefined".to_string()]
}
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
Expand Down Expand Up @@ -255,10 +252,10 @@ impl Signature {
volatility,
}
}
/// An arbitrary number of arguments of the same type.
pub fn variadic_equal(volatility: Volatility) -> Self {
/// User-defined coercion rules for the function.
pub fn user_defined(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::VariadicEqual,
type_signature: TypeSignature::UserDefined,
volatility,
}
}
Expand Down
179 changes: 153 additions & 26 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,124 @@ use std::sync::Arc;
use crate::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};
use crate::{Signature, TypeSignature};
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err, Result,
};

use super::binary::{comparison_binary_numeric_coercion, comparison_coercion};

/// Resolves the argument types for a call to the scalar UDF `func`.
///
/// Given the types the arguments currently have, returns the data types to
/// which each argument must be coerced so the call matches the signature
/// reported by `func`.
///
/// Resolution order:
/// * With no arguments, an empty list is returned only when the type
///   signature reports support for zero arguments.
/// * If `current_types` already equals one of the candidate type lists, it
///   is returned unchanged.
/// * Otherwise the first candidate accepted by `maybe_data_types` is used.
///
/// # Errors
/// Returns a plan error when zero arguments are not supported or when no
/// candidate can be reached from `current_types`. Errors raised while
/// computing the candidates are propagated.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
pub fn data_types_with_scalar_udf(
    current_types: &[DataType],
    func: &ScalarUDF,
) -> Result<Vec<DataType>> {
    let signature = func.signature();

    // A call without arguments has nothing to coerce: accept or reject it outright.
    if current_types.is_empty() {
        return if signature.type_signature.supports_zero_argument() {
            Ok(vec![])
        } else {
            plan_err!(
                "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.",
                &signature.type_signature
            )
        };
    }

    let candidates =
        get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?;

    // Exact match: the arguments can be used as they are.
    if candidates
        .iter()
        .any(|candidate| candidate == current_types)
    {
        return Ok(current_types.to_vec());
    }

    // Otherwise take the coerced types from the first candidate that the
    // current types can be coerced to.
    match candidates
        .iter()
        .find_map(|candidate| maybe_data_types(candidate, current_types))
    {
        Some(coerced) => Ok(coerced),
        // none possible -> Error
        None => plan_err!(
            "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.",
            current_types,
            &signature.type_signature
        ),
    }
}

/// Resolves the argument types for a call to the aggregate UDF `func`.
///
/// Given the types the arguments currently have, returns the data types to
/// which each argument must be coerced so the call matches the signature
/// reported by `func`.
///
/// Resolution order:
/// * With no arguments, an empty list is returned only when the type
///   signature reports support for zero arguments.
/// * If `current_types` already equals one of the candidate type lists, it
///   is returned unchanged.
/// * Otherwise the first candidate accepted by `maybe_data_types` is used.
///
/// # Errors
/// Returns a plan error when zero arguments are not supported or when no
/// candidate can be reached from `current_types`. Errors raised while
/// computing the candidates are propagated.
pub fn data_types_with_aggregate_udf(
    current_types: &[DataType],
    func: &AggregateUDF,
) -> Result<Vec<DataType>> {
    let signature = func.signature();

    // A call without arguments has nothing to coerce: accept or reject it outright.
    if current_types.is_empty() {
        return if signature.type_signature.supports_zero_argument() {
            Ok(vec![])
        } else {
            plan_err!(
                "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.",
                current_types,
                &signature.type_signature
            )
        };
    }

    let candidates = get_valid_types_with_aggregate_udf(
        &signature.type_signature,
        current_types,
        func,
    )?;

    // Exact match: the arguments can be used as they are.
    if candidates
        .iter()
        .any(|candidate| candidate == current_types)
    {
        return Ok(current_types.to_vec());
    }

    // Otherwise take the coerced types from the first candidate that the
    // current types can be coerced to.
    match candidates
        .iter()
        .find_map(|candidate| maybe_data_types(candidate, current_types))
    {
        Some(coerced) => Ok(coerced),
        // none possible -> Error
        None => plan_err!(
            "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.",
            current_types,
            &signature.type_signature
        ),
    }
}

/// Performs type coercion for function arguments.
///
/// Returns the data types to which each argument must be coerced to
/// match `signature`.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
///
/// This function will be replaced with [data_types_with_scalar_udf],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure we want to replace this function over time -- I think having the basic simple Signatures that handle most common coercions makes sense to have in DataFusion core (even if it could be done purely in a udf) as it will make creating UDFs easier for users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with you

/// [data_types_with_aggregate_udf], and data_types_with_window_udf gradually.
pub fn data_types(
current_types: &[DataType],
signature: &Signature,
Expand All @@ -46,7 +147,7 @@ pub fn data_types(
return Ok(vec![]);
} else {
return plan_err!(
"Coercion from {:?} to the signature {:?} failed.",
"[data_types] Coercion from {:?} to the signature {:?} failed.",
current_types,
&signature.type_signature
);
Expand All @@ -72,12 +173,56 @@ pub fn data_types(

// none possible -> Error
plan_err!(
"Coercion from {:?} to the signature {:?} failed.",
"[data_types] Coercion from {:?} to the signature {:?} failed.",
current_types,
&signature.type_signature
)
}

/// Returns every candidate list of argument types for `signature`, consulting
/// the scalar UDF `func` where the signature requires it.
///
/// * `UserDefined`: the single candidate is whatever `func.coerce_types`
///   returns; a failure there is reported as an execution error.
/// * `OneOf`: each alternative is resolved recursively and the successful
///   results are concatenated. Alternatives that fail are skipped, so the
///   result may be empty.
/// * Anything else is forwarded to `get_valid_types`.
fn get_valid_types_with_scalar_udf(
    signature: &TypeSignature,
    current_types: &[DataType],
    func: &ScalarUDF,
) -> Result<Vec<Vec<DataType>>> {
    match signature {
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
            Ok(coerced) => Ok(vec![coerced]),
            Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
        },
        TypeSignature::OneOf(alternatives) => {
            let mut candidates = Vec::new();
            for alternative in alternatives {
                // A failing alternative is not fatal; only matching ones contribute.
                if let Ok(types) =
                    get_valid_types_with_scalar_udf(alternative, current_types, func)
                {
                    candidates.extend(types);
                }
            }
            Ok(candidates)
        }
        _ => get_valid_types(signature, current_types),
    }
}

/// Returns every candidate list of argument types for `signature`, consulting
/// the aggregate UDF `func` where the signature requires it.
///
/// * `UserDefined`: the single candidate is whatever `func.coerce_types`
///   returns; a failure there is reported as an execution error.
/// * `OneOf`: each alternative is resolved recursively and the successful
///   results are concatenated. Alternatives that fail are skipped, so the
///   result may be empty.
/// * Anything else is forwarded to `get_valid_types`.
fn get_valid_types_with_aggregate_udf(
    signature: &TypeSignature,
    current_types: &[DataType],
    func: &AggregateUDF,
) -> Result<Vec<Vec<DataType>>> {
    match signature {
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
            Ok(coerced) => Ok(vec![coerced]),
            Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
        },
        TypeSignature::OneOf(alternatives) => {
            let mut candidates = Vec::new();
            for alternative in alternatives {
                // A failing alternative is not fatal; only matching ones contribute.
                if let Ok(types) =
                    get_valid_types_with_aggregate_udf(alternative, current_types, func)
                {
                    candidates.extend(types);
                }
            }
            Ok(candidates)
        }
        _ => get_valid_types(signature, current_types),
    }
}

/// Returns a Vec of all possible valid argument types for the given signature.
fn get_valid_types(
signature: &TypeSignature,
Expand Down Expand Up @@ -184,32 +329,14 @@ fn get_valid_types(
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
);

match new_type {
Ok(new_type) => vec![vec![new_type; current_types.len()]],
Err(e) => return Err(e),
}
TypeSignature::UserDefined => {
return internal_err!(
"User-defined signature should be handled by function-specific coerce_types."
)
}
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
}

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArraySignature(ref function_signature) => match function_signature
{
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}

/// Placeholder for user-defined argument coercion on aggregate UDFs.
///
/// The argument types in `_args` are ignored: the body only evaluates
/// `not_impl_err!` with `self.name()`, so no coercion is performed here.
pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
}
}

impl<F> From<F> for AggregateUDF
Expand Down
29 changes: 29 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ impl ScalarUDF {
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}

/// Returns the types the arguments should be coerced to, given their current
/// types in `arg_types`, by forwarding the call to the wrapped implementation.
///
/// See [`ScalarUDFImpl::coerce_types`] for more details.
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
}

impl<F> From<F> for ScalarUDF
Expand Down Expand Up @@ -420,6 +425,29 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn short_circuits(&self) -> bool {
false
}

/// Coerce arguments of a function call to types that the function can evaluate.
///
/// This function is only called if [`ScalarUDFImpl::signature`] returns
/// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of the
/// other variants of `TypeSignature`, which handle common cases.
///
/// See the [type coercion module](crate::type_coercion)
/// documentation for more details on type coercion.
///
/// For example, if your function requires a floating point argument, but the user calls
/// it like `my_func(1::int)` (i.e. with `1` as an integer), `coerce_types` could return
/// `[DataType::Float64]` to ensure the argument is cast to `1::double`.
///
/// # Parameters
/// * `arg_types`: The data types of the arguments this function was called with
///
/// # Return value
/// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
/// arguments to these specific types.
///
/// The default implementation ignores `_arg_types` and evaluates `not_impl_err!`
/// with this function's name.
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}
}

/// ScalarUDF that adds an alias to the underlying function. It is better to
Expand All @@ -446,6 +474,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.inner.name()
}
Expand Down
Loading