Skip to content

Commit 2d30334

Browse files
lgingerichalamb
andauthored
Use take_function_args in more places (#14525)
* refactor: apply take_function_args() in functions crate * fix: handle plural vs. singular grammar for "argument(s)" * fix: run cargo clippy and fix errors * style: apply cargo fmt * refactor: move func to datafusion_common and update imports * refactor: apply take_function_args * fix: update test output language * fix: simplify doc test for take_function_args --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 82461b7 commit 2d30334

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+298
-470
lines changed

datafusion/common/src/utils/mod.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub mod memory;
2222
pub mod proxy;
2323
pub mod string_utils;
2424

25-
use crate::error::{_internal_datafusion_err, _internal_err};
25+
use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
2626
use crate::{DataFusionError, Result, ScalarValue};
2727
use arrow::array::{
2828
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
@@ -905,6 +905,45 @@ pub fn get_available_parallelism() -> usize {
905905
.get()
906906
}
907907

908+
/// Converts a collection of function arguments into an fixed-size array of length N
909+
/// producing a reasonable error message in case of unexpected number of arguments.
910+
///
911+
/// # Example
912+
/// ```
913+
/// # use datafusion_common::Result;
914+
/// # use datafusion_common::utils::take_function_args;
915+
/// # use datafusion_common::ScalarValue;
916+
/// fn my_function(args: &[ScalarValue]) -> Result<()> {
917+
/// // function expects 2 args, so create a 2-element array
918+
/// let [arg1, arg2] = take_function_args("my_function", args)?;
919+
/// // ... do stuff..
920+
/// Ok(())
921+
/// }
922+
///
923+
/// // Calling the function with 1 argument produces an error:
924+
/// let args = vec![ScalarValue::Int32(Some(10))];
925+
/// let err = my_function(&args).unwrap_err();
926+
/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1");
927+
/// // Calling the function with 2 arguments works great
928+
/// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))];
929+
/// my_function(&args).unwrap();
930+
/// ```
931+
pub fn take_function_args<const N: usize, T>(
932+
function_name: &str,
933+
args: impl IntoIterator<Item = T>,
934+
) -> Result<[T; N]> {
935+
let args = args.into_iter().collect::<Vec<_>>();
936+
args.try_into().map_err(|v: Vec<T>| {
937+
_exec_datafusion_err!(
938+
"{} function requires {} {}, got {}",
939+
function_name,
940+
N,
941+
if N == 1 { "argument" } else { "arguments" },
942+
v.len()
943+
)
944+
})
945+
}
946+
908947
#[cfg(test)]
909948
mod tests {
910949
use super::*;

datafusion/expr/src/test/function_stub.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use arrow::datatypes::{
2525
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
2626
};
2727

28-
use datafusion_common::{exec_err, not_impl_err, Result};
28+
use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
2929

3030
use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
3131
use crate::Volatility::Immutable;
@@ -125,9 +125,7 @@ impl AggregateUDFImpl for Sum {
125125
}
126126

127127
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128-
if arg_types.len() != 1 {
129-
return exec_err!("SUM expects exactly one argument");
130-
}
128+
let [array] = take_function_args(self.name(), arg_types)?;
131129

132130
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
133131
// smallint, int, bigint, real, double precision, decimal, or interval.
@@ -147,7 +145,7 @@ impl AggregateUDFImpl for Sum {
147145
}
148146
}
149147

150-
Ok(vec![coerced_type(&arg_types[0])?])
148+
Ok(vec![coerced_type(array)?])
151149
}
152150

153151
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {

datafusion/functions-aggregate/src/average.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ use arrow::datatypes::{
2727
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
2828
Float64Type, UInt64Type,
2929
};
30-
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
30+
use datafusion_common::{
31+
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
32+
};
3133
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3234
use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
3335
use datafusion_expr::utils::format_state_name;
@@ -247,10 +249,8 @@ impl AggregateUDFImpl for Avg {
247249
}
248250

249251
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
250-
if arg_types.len() != 1 {
251-
return exec_err!("{} expects exactly one argument.", self.name());
252-
}
253-
coerce_avg_type(self.name(), arg_types)
252+
let [args] = take_function_args(self.name(), arg_types)?;
253+
coerce_avg_type(self.name(), std::slice::from_ref(args))
254254
}
255255

256256
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions-aggregate/src/sum.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ use arrow::datatypes::{
3333
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
3434
};
3535
use arrow::{array::ArrayRef, datatypes::Field};
36-
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
36+
use datafusion_common::{
37+
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38+
};
3739
use datafusion_expr::function::AccumulatorArgs;
3840
use datafusion_expr::function::StateFieldsArgs;
3941
use datafusion_expr::utils::format_state_name;
@@ -125,9 +127,7 @@ impl AggregateUDFImpl for Sum {
125127
}
126128

127129
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128-
if arg_types.len() != 1 {
129-
return exec_err!("SUM expects exactly one argument");
130-
}
130+
let [args] = take_function_args(self.name(), arg_types)?;
131131

132132
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
133133
// smallint, int, bigint, real, double precision, decimal, or interval.
@@ -147,7 +147,7 @@ impl AggregateUDFImpl for Sum {
147147
}
148148
}
149149

150-
Ok(vec![coerced_type(&arg_types[0])?])
150+
Ok(vec![coerced_type(args)?])
151151
}
152152

153153
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {

datafusion/functions-nested/src/cardinality.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use arrow::datatypes::{
2626
DataType::{FixedSizeList, LargeList, List, Map, UInt64},
2727
};
2828
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
29+
use datafusion_common::utils::take_function_args;
2930
use datafusion_common::Result;
3031
use datafusion_common::{exec_err, plan_err};
3132
use datafusion_expr::{
@@ -127,21 +128,18 @@ impl ScalarUDFImpl for Cardinality {
127128

128129
/// Cardinality SQL function
129130
pub fn cardinality_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
130-
if args.len() != 1 {
131-
return exec_err!("cardinality expects one argument");
132-
}
133-
134-
match &args[0].data_type() {
131+
let [array] = take_function_args("cardinality", args)?;
132+
match &array.data_type() {
135133
List(_) => {
136-
let list_array = as_list_array(&args[0])?;
134+
let list_array = as_list_array(&array)?;
137135
generic_list_cardinality::<i32>(list_array)
138136
}
139137
LargeList(_) => {
140-
let list_array = as_large_list_array(&args[0])?;
138+
let list_array = as_large_list_array(&array)?;
141139
generic_list_cardinality::<i64>(list_array)
142140
}
143141
Map(_, _) => {
144-
let map_array = as_map_array(&args[0])?;
142+
let map_array = as_map_array(&array)?;
145143
generic_map_cardinality(map_array)
146144
}
147145
other => {

datafusion/functions-nested/src/concat.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ use arrow::buffer::OffsetBuffer;
2828
use arrow::datatypes::{DataType, Field};
2929
use datafusion_common::Result;
3030
use datafusion_common::{
31-
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
31+
cast::as_generic_list_array,
32+
exec_err, not_impl_err, plan_err,
33+
utils::{list_ndims, take_function_args},
3234
};
3335
use datafusion_expr::{
3436
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
@@ -415,23 +417,19 @@ fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
415417

416418
/// Array_append SQL function
417419
pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
418-
if args.len() != 2 {
419-
return exec_err!("array_append expects two arguments");
420-
}
420+
let [array, _] = take_function_args("array_append", args)?;
421421

422-
match args[0].data_type() {
422+
match array.data_type() {
423423
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true),
424424
_ => general_append_and_prepend::<i32>(args, true),
425425
}
426426
}
427427

428428
/// Array_prepend SQL function
429429
pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
430-
if args.len() != 2 {
431-
return exec_err!("array_prepend expects two arguments");
432-
}
430+
let [_, array] = take_function_args("array_prepend", args)?;
433431

434-
match args[1].data_type() {
432+
match array.data_type() {
435433
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false),
436434
_ => general_append_and_prepend::<i32>(args, false),
437435
}

datafusion/functions-nested/src/dimension.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use arrow::datatypes::{
2828
use std::any::Any;
2929

3030
use datafusion_common::cast::{as_large_list_array, as_list_array};
31-
use datafusion_common::{exec_err, plan_err, Result};
31+
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};
3232

3333
use crate::utils::{compute_array_dims, make_scalar_function};
3434
use datafusion_expr::{
@@ -203,20 +203,18 @@ impl ScalarUDFImpl for ArrayNdims {
203203

204204
/// Array_dims SQL function
205205
pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
206-
if args.len() != 1 {
207-
return exec_err!("array_dims needs one argument");
208-
}
206+
let [array] = take_function_args("array_dims", args)?;
209207

210-
let data = match args[0].data_type() {
208+
let data = match array.data_type() {
211209
List(_) => {
212-
let array = as_list_array(&args[0])?;
210+
let array = as_list_array(&array)?;
213211
array
214212
.iter()
215213
.map(compute_array_dims)
216214
.collect::<Result<Vec<_>>>()?
217215
}
218216
LargeList(_) => {
219-
let array = as_large_list_array(&args[0])?;
217+
let array = as_large_list_array(&array)?;
220218
array
221219
.iter()
222220
.map(compute_array_dims)
@@ -234,9 +232,7 @@ pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
234232

235233
/// Array_ndims SQL function
236234
pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
237-
if args.len() != 1 {
238-
return exec_err!("array_ndims needs one argument");
239-
}
235+
let [array_dim] = take_function_args("array_ndims", args)?;
240236

241237
fn general_list_ndims<O: OffsetSizeTrait>(
242238
array: &GenericListArray<O>,
@@ -254,13 +250,13 @@ pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
254250

255251
Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
256252
}
257-
match args[0].data_type() {
253+
match array_dim.data_type() {
258254
List(_) => {
259-
let array = as_list_array(&args[0])?;
255+
let array = as_list_array(&array_dim)?;
260256
general_list_ndims::<i32>(array)
261257
}
262258
LargeList(_) => {
263-
let array = as_large_list_array(&args[0])?;
259+
let array = as_large_list_array(&array_dim)?;
264260
general_list_ndims::<i64>(array)
265261
}
266262
array_type => exec_err!("array_ndims does not support type {array_type:?}"),

datafusion/functions-nested/src/distance.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use datafusion_common::cast::{
3030
as_int64_array,
3131
};
3232
use datafusion_common::utils::coerced_fixed_size_list_to_list;
33-
use datafusion_common::{exec_err, internal_datafusion_err, Result};
33+
use datafusion_common::{
34+
exec_err, internal_datafusion_err, utils::take_function_args, Result,
35+
};
3436
use datafusion_expr::{
3537
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3638
};
@@ -110,9 +112,7 @@ impl ScalarUDFImpl for ArrayDistance {
110112
}
111113

112114
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
113-
if arg_types.len() != 2 {
114-
return exec_err!("array_distance expects exactly two arguments");
115-
}
115+
let [_, _] = take_function_args(self.name(), arg_types)?;
116116
let mut result = Vec::new();
117117
for arg_type in arg_types {
118118
match arg_type {
@@ -142,11 +142,9 @@ impl ScalarUDFImpl for ArrayDistance {
142142
}
143143

144144
pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
145-
if args.len() != 2 {
146-
return exec_err!("array_distance expects exactly two arguments");
147-
}
145+
let [array1, array2] = take_function_args("array_distance", args)?;
148146

149-
match (&args[0].data_type(), &args[1].data_type()) {
147+
match (&array1.data_type(), &array2.data_type()) {
150148
(List(_), List(_)) => general_array_distance::<i32>(args),
151149
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
152150
(array_type1, array_type2) => {

datafusion/functions-nested/src/empty.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::datatypes::{
2424
DataType::{Boolean, FixedSizeList, LargeList, List},
2525
};
2626
use datafusion_common::cast::as_generic_list_array;
27-
use datafusion_common::{exec_err, plan_err, Result};
27+
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};
2828
use datafusion_expr::{
2929
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3030
};
@@ -117,14 +117,12 @@ impl ScalarUDFImpl for ArrayEmpty {
117117

118118
/// Array_empty SQL function
119119
pub fn array_empty_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
120-
if args.len() != 1 {
121-
return exec_err!("array_empty expects one argument");
122-
}
120+
let [array] = take_function_args("array_empty", args)?;
123121

124-
let array_type = args[0].data_type();
122+
let array_type = array.data_type();
125123
match array_type {
126-
List(_) => general_array_empty::<i32>(&args[0]),
127-
LargeList(_) => general_array_empty::<i64>(&args[0]),
124+
List(_) => general_array_empty::<i32>(array),
125+
LargeList(_) => general_array_empty::<i64>(array),
128126
_ => exec_err!("array_empty does not support type '{array_type:?}'."),
129127
}
130128
}

datafusion/functions-nested/src/except.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeT
2222
use arrow::buffer::OffsetBuffer;
2323
use arrow::datatypes::{DataType, FieldRef};
2424
use arrow::row::{RowConverter, SortField};
25-
use datafusion_common::{exec_err, internal_err, HashSet, Result};
25+
use datafusion_common::utils::take_function_args;
26+
use datafusion_common::{internal_err, HashSet, Result};
2627
use datafusion_expr::{
2728
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2829
};
@@ -124,12 +125,7 @@ impl ScalarUDFImpl for ArrayExcept {
124125

125126
/// Array_except SQL function
126127
pub fn array_except_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
127-
if args.len() != 2 {
128-
return exec_err!("array_except needs two arguments");
129-
}
130-
131-
let array1 = &args[0];
132-
let array2 = &args[1];
128+
let [array1, array2] = take_function_args("array_except", args)?;
133129

134130
match (array1.data_type(), array2.data_type()) {
135131
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),

0 commit comments

Comments
 (0)