-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Refactor distinct aggregate implementations to use common buffer #18348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,12 +15,20 @@ | |
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use arrow::array::{ArrayRef, ArrowNativeTypeOp}; | ||
| use ahash::RandomState; | ||
| use arrow::array::{ | ||
| Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, | ||
| }; | ||
| use arrow::compute::SortOptions; | ||
| use arrow::datatypes::{ | ||
| ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice, | ||
| }; | ||
| use datafusion_common::{exec_err, internal_datafusion_err, Result}; | ||
| use datafusion_common::cast::{as_list_array, as_primitive_array}; | ||
| use datafusion_common::utils::memory::estimate_memory_size; | ||
| use datafusion_common::utils::SingleRowListArrayBuilder; | ||
| use datafusion_common::{ | ||
| exec_err, internal_datafusion_err, HashSet, Result, ScalarValue, | ||
| }; | ||
| use datafusion_expr_common::accumulator::Accumulator; | ||
| use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; | ||
| use std::sync::Arc; | ||
|
|
@@ -167,3 +175,90 @@ impl<T: DecimalType> DecimalAverager<T> { | |
| } | ||
| } | ||
| } | ||
|
|
||
| /// Generic way to collect distinct values for accumulators. | ||
| /// | ||
| /// The intermediate state is represented as a List of scalar values updated by | ||
| /// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values | ||
| /// in the final evaluation step so that we avoid expensive conversions and | ||
| /// allocations during `update_batch`. | ||
| pub struct GenericDistinctBuffer<T: ArrowPrimitiveType> { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Main implementation here; I toyed with the idea of making this implement |
||
| pub values: HashSet<Hashable<T::Native>, RandomState>, | ||
| data_type: DataType, | ||
| } | ||
|
|
||
| impl<T: ArrowPrimitiveType> std::fmt::Debug for GenericDistinctBuffer<T> { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| write!( | ||
| f, | ||
| "GenericDistinctBuffer({}, values={})", | ||
| self.data_type, | ||
| self.values.len() | ||
| ) | ||
| } | ||
| } | ||
|
|
||
| impl<T: ArrowPrimitiveType> GenericDistinctBuffer<T> { | ||
| pub fn new(data_type: DataType) -> Self { | ||
| Self { | ||
| values: HashSet::default(), | ||
| data_type, | ||
| } | ||
| } | ||
|
|
||
| /// Mirrors [`Accumulator::state`]. | ||
| pub fn state(&self) -> Result<Vec<ScalarValue>> { | ||
| let arr = Arc::new( | ||
| PrimitiveArray::<T>::from_iter_values(self.values.iter().map(|v| v.0)) | ||
| // Ideally we'd just use T::DATA_TYPE but this misses things like | ||
| // decimal scale/precision and timestamp timezones, which need to | ||
| // match up with Accumulator::state_fields | ||
| .with_data_type(self.data_type.clone()), | ||
| ); | ||
| Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()]) | ||
| } | ||
|
|
||
| /// Mirrors [`Accumulator::update_batch`]. | ||
| pub fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
| if values.is_empty() { | ||
| return Ok(()); | ||
| } | ||
|
|
||
| debug_assert_eq!( | ||
| values.len(), | ||
| 1, | ||
| "DistinctValuesBuffer::update_batch expects only a single input array" | ||
| ); | ||
|
|
||
| let arr = as_primitive_array::<T>(&values[0])?; | ||
| if arr.null_count() > 0 { | ||
| self.values.extend(arr.iter().flatten().map(Hashable)); | ||
| } else { | ||
| self.values | ||
| .extend(arr.values().iter().cloned().map(Hashable)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice -- this is an elegant way to special case nulls/non nulls |
||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| /// Mirrors [`Accumulator::merge_batch`]. | ||
| pub fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
| if states.is_empty() { | ||
| return Ok(()); | ||
| } | ||
|
|
||
| let array = as_list_array(&states[0])?; | ||
| for list in array.iter().flatten() { | ||
| self.update_batch(&[list])?; | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| /// Mirrors [`Accumulator::size`]. | ||
| pub fn size(&self) -> usize { | ||
| let num_elements = self.values.len(); | ||
| let fixed_size = size_of_val(self) + size_of_val(&self.values); | ||
| estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap() | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if I could pull in
`PrimitiveDistinctCountAccumulator` to the deduplication as well; however, it is specialized for types which don't need to hash through `Hashable` (aka non-float types), and I think there might be a performance hit if I try to force them to use `Hashable` 🤔There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we definitely don't want to be hashing if we can avoid that