diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 9675d03a0161..688563baecfa 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -146,3 +146,8 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "pad" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs new file mode 100644 index 000000000000..5ff1e2fb860d --- /dev/null +++ b/datafusion/functions/benches/pad.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +use arrow::datatypes::Int64Type; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode::{lpad, rpad}; +use rand::distributions::{Distribution, Uniform}; +use rand::Rng; +use std::sync::Arc; + +struct Filter { + dist: Dist, +} + +impl Distribution for Filter +where + Dist: Distribution, +{ + fn sample(&self, rng: &mut R) -> T { + self.dist.sample(rng) + } +} + +pub fn create_primitive_array( + size: usize, + null_density: f32, + len: usize, +) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let dist = Filter { + dist: Uniform::new_inclusive::(0, len as i64), + }; + + let mut rng = rand::thread_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.sample(&dist)) + } + }) + .collect() +} + +fn create_args( + size: usize, + str_len: usize, + use_string_view: bool, +) -> Vec { + let length_array = Arc::new(create_primitive_array::(size, 0.0, str_len)); + + if !use_string_view { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let fill_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } else { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + let fill_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 2048] { + let mut group = c.benchmark_group("lpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, true); + group.bench_function(BenchmarkId::new("stringview type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + group.finish(); + + let mut group = c.benchmark_group("rpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + }); + // + // let args = create_args::(size, 32, true); + // group.bench_function(BenchmarkId::new("stringview type", size), |b| { + // b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + // }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 5caa6acd6745..521cdc5d0ff0 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -16,11 +16,12 @@ // under the License. use std::any::Any; +use std::fmt::Write; use std::sync::Arc; use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, - OffsetSizeTrait, StringViewArray, + Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, + GenericStringBuilder, Int64Array, OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use unicode_segmentation::UnicodeSegmentation; @@ -87,14 +88,18 @@ impl ScalarUDFImpl for LPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_scalar_function(lpad, vec![])(args) + match args[0].data_type() { + Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(args), + LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } } } /// Extends the string to length 'length' by prepending the characters fill (a space by default). /// If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { +pub fn lpad(args: &[ArrayRef]) -> Result { if args.len() <= 1 || args.len() > 3 { return exec_err!( "lpad was called with {} arguments. It requires at least 2 and at most 3.", @@ -104,49 +109,28 @@ pub fn lpad(args: &[ArrayRef]) -> Result { let length_array = as_int64_array(&args[1])?; - match args[0].data_type() { - Utf8 => match args.len() { - 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i32>( - args[0].as_string::(), - length_array, - None, - ), - 3 => lpad_with_replace::<&GenericStringArray, i32>( - args[0].as_string::(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - LargeUtf8 => match args.len() { - 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i64>( - args[0].as_string::(), - length_array, - None, - ), - 3 => lpad_with_replace::<&GenericStringArray, i64>( - args[0].as_string::(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - Utf8View => match args.len() { - 2 => lpad_impl::<&StringViewArray, &GenericStringArray, i32>( - args[0].as_string_view(), - length_array, - None, - ), - 3 => lpad_with_replace::<&StringViewArray, i32>( - args[0].as_string_view(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - other => { - exec_err!("Unsupported data type {other:?} for function lpad") - } + match (args.len(), args[0].data_type()) { + (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray, T>( + args[0].as_string_view(), + length_array, + None, + ), + (2, Utf8 | LargeUtf8) => lpad_impl::< + &GenericStringArray, + &GenericStringArray, + T, + >(args[0].as_string::(), length_array, None), + (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>( + args[0].as_string_view(), + length_array, + &args[2], + ), + (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray, T>( + args[0].as_string::(), + length_array, + &args[2], + ), + (_, _) => unreachable!(), } } @@ -159,20 +143,20 @@ where V: StringArrayType<'a>, { match fill_array.data_type() { - Utf8 => lpad_impl::, T>( + Utf8View => lpad_impl::( string_array, length_array, - Some(fill_array.as_string::()), + Some(fill_array.as_string_view()), ), LargeUtf8 => lpad_impl::, T>( string_array, length_array, Some(fill_array.as_string::()), ), - Utf8View => lpad_impl::( + Utf8 => lpad_impl::, T>( string_array, length_array, - Some(fill_array.as_string_view()), + Some(fill_array.as_string::()), ), other => { exec_err!("Unsupported data type {other:?} for function lpad") @@ -190,87 +174,86 @@ where V2: StringArrayType<'a>, T: OffsetSizeTrait, { - if fill_array.is_none() { - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); - } + let array = if fill_array.is_none() { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } + for (string, length) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(length)) = (string, length) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; + builder.write_str(string)?; + builder.append_value(""); + } + } else { + builder.append_null(); + } + } + + builder.finish() } else { - let result = string_array + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + for ((string, length), fill) in string_array .iter() .zip(length_array.iter()) .zip(fill_array.unwrap().iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); - } + { + if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill_chars.is_empty() { + builder.append_value(string); + } else { + for l in 0..length - graphemes.len() { + let c = *fill_chars.get(l % fill_chars.len()).unwrap(); + builder.write_char(c)?; } + builder.write_str(string)?; + builder.append_value(""); } - _ => Ok(None), - }) - .collect::>>()?; + } else { + builder.append_null(); + } + } - Ok(Arc::new(result) as ArrayRef) - } + builder.finish() + }; + + Ok(Arc::new(array) as ArrayRef) } trait StringArrayType<'a>: ArrayAccessor + Sized { fn iter(&self) -> ArrayIter; } -impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { +impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { fn iter(&self) -> ArrayIter { - GenericStringArray::::iter(self) + GenericStringArray::::iter(self) } } impl<'a> StringArrayType<'a> for &'a StringViewArray {