-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Optimize COUNT( DISTINCT ...) for strings (up to 9x faster)
#8849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
9c44d04
6cb8bbe
9d662a7
1744cb3
e3b0568
12cf50c
4f9a3f0
626b1cb
2e80cb7
d2d1d6d
ebb8726
98a9cd1
07831fa
62c8084
e3b65c8
3f0e9a9
4bc483a
0475687
a764e99
bde49c6
a101b62
0f2fa02
489e130
c39988a
0e33b12
b3bcc68
d7efcf6
a80b39c
3e9289a
7b9d067
d405744
3a6a066
f177aed
8640907
214ba5b
1e10b9c
f5e268d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,3 @@ | ||
| SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; | ||
| SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; | ||
| SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; | ||
| SELECT "BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true } | |
| blake3 = { version = "1.0", optional = true } | ||
| chrono = { workspace = true } | ||
| datafusion-common = { workspace = true } | ||
| datafusion-execution = { workspace = true } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needed to use RawTableAlloc trait |
||
| datafusion-expr = { workspace = true } | ||
| half = { version = "2.1", default-features = false } | ||
| hashbrown = { version = "0.14", features = ["raw"] } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,12 +23,14 @@ use arrow_array::types::{ | |
| TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, | ||
| TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
| }; | ||
| use arrow_array::PrimitiveArray; | ||
| use arrow_array::{PrimitiveArray, StringArray}; | ||
| use arrow_buffer::{BufferBuilder, MutableBuffer, OffsetBuffer}; | ||
|
|
||
| use std::any::Any; | ||
| use std::cmp::Eq; | ||
| use std::fmt::Debug; | ||
| use std::hash::Hash; | ||
| use std::mem; | ||
| use std::sync::Arc; | ||
|
|
||
| use ahash::RandomState; | ||
|
|
@@ -38,9 +40,10 @@ use std::collections::HashSet; | |
| use crate::aggregate::utils::{down_cast_any_ref, Hashable}; | ||
| use crate::expressions::format_state_name; | ||
| use crate::{AggregateExpr, PhysicalExpr}; | ||
| use datafusion_common::cast::{as_list_array, as_primitive_array}; | ||
| use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; | ||
| use datafusion_common::utils::array_into_list_array; | ||
| use datafusion_common::{Result, ScalarValue}; | ||
| use datafusion_execution::memory_pool::proxy::RawTableAllocExt; | ||
| use datafusion_expr::Accumulator; | ||
|
|
||
| type DistinctScalarValues = ScalarValue; | ||
|
|
@@ -152,6 +155,8 @@ impl AggregateExpr for DistinctCount { | |
| Float32 => float_distinct_count_accumulator!(Float32Type), | ||
| Float64 => float_distinct_count_accumulator!(Float64Type), | ||
|
|
||
| Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())), | ||
alamb marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| _ => Ok(Box::new(DistinctCountAccumulator { | ||
| values: HashSet::default(), | ||
| state_data_type: self.state_data_type.clone(), | ||
|
|
@@ -244,7 +249,7 @@ impl Accumulator for DistinctCountAccumulator { | |
| assert_eq!(states.len(), 1, "array_agg states must be singleton!"); | ||
| let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; | ||
| for scalars in scalar_vec.into_iter() { | ||
| self.values.extend(scalars) | ||
| self.values.extend(scalars); | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
@@ -438,6 +443,212 @@ where | |
| } | ||
| } | ||
|
|
||
| #[derive(Debug)] | ||
| struct StringDistinctCountAccumulator(SSOStringHashSet); | ||
| impl StringDistinctCountAccumulator { | ||
| fn new() -> Self { | ||
| Self(SSOStringHashSet::new()) | ||
| } | ||
| } | ||
|
|
||
| impl Accumulator for StringDistinctCountAccumulator { | ||
| fn state(&self) -> Result<Vec<ScalarValue>> { | ||
| let arr = self.0.state(); | ||
| let list = Arc::new(array_into_list_array(Arc::new(arr))); | ||
| Ok(vec![ScalarValue::List(list)]) | ||
| } | ||
|
|
||
| fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
| if values.is_empty() { | ||
| return Ok(()); | ||
| } | ||
|
|
||
| let arr = as_string_array(&values[0])?; | ||
| arr.iter().for_each(|value| { | ||
| if let Some(value) = value { | ||
| self.0.insert(value); | ||
| } | ||
| }); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
| if states.is_empty() { | ||
| return Ok(()); | ||
| } | ||
| assert_eq!( | ||
| states.len(), | ||
| 1, | ||
| "count_distinct states must be single array" | ||
| ); | ||
|
|
||
| let arr = as_list_array(&states[0])?; | ||
| arr.iter().try_for_each(|maybe_list| { | ||
| if let Some(list) = maybe_list { | ||
| let list = as_string_array(&list)?; | ||
|
|
||
| list.iter().for_each(|value| { | ||
| if let Some(value) = value { | ||
| self.0.insert(value); | ||
| } | ||
| }) | ||
| }; | ||
| Ok(()) | ||
| }) | ||
| } | ||
|
|
||
| fn evaluate(&self) -> Result<ScalarValue> { | ||
| Ok(ScalarValue::Int64(Some(self.0.len() as i64))) | ||
| } | ||
|
|
||
| fn size(&self) -> usize { | ||
| // Size of accumulator | ||
| // + SSOStringHashSet size | ||
| std::mem::size_of_val(self) + self.0.size() | ||
| } | ||
| } | ||
|
|
||
| const SHORT_STRING_LEN: usize = mem::size_of::<usize>(); | ||
|
|
||
| #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] | ||
|
||
| struct SSOStringHeader { | ||
| /// hash of the string value (used when resizing table) | ||
| hash: u64, | ||
| /// length of the string | ||
| len: usize, | ||
| /// short strings are stored inline, long strings are stored in the buffer | ||
| offset_or_inline: usize, | ||
| } | ||
|
|
||
| impl SSOStringHeader { | ||
| fn evaluate(&self, buffer: &[u8]) -> String { | ||
| if self.len <= SHORT_STRING_LEN { | ||
| self.offset_or_inline.to_string() | ||
| } else { | ||
| let offset = self.offset_or_inline; | ||
| // SAFETY: buffer is only appended to, and we correctly inserted values | ||
| unsafe { | ||
| std::str::from_utf8_unchecked( | ||
| buffer.get_unchecked(offset..offset + self.len), | ||
| ) | ||
| } | ||
| .to_string() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Short String Optimizated HashSet for String | ||
| // Equivalent to HashSet<String> but with better memory usage | ||
| #[derive(Default)] | ||
| struct SSOStringHashSet { | ||
| /// Core of the HashSet, it stores both the short and long string headers | ||
| header_set: HashSet<SSOStringHeader>, | ||
| /// Used to check if the long string already exists | ||
| long_string_map: hashbrown::raw::RawTable<SSOStringHeader>, | ||
| /// Total size of the map in bytes | ||
| map_size: usize, | ||
| /// Buffer containing all long strings | ||
| buffer: BufferBuilder<u8>, | ||
| /// The random state used to generate hashes | ||
| state: RandomState, | ||
| /// Used for capacity calculation, equivalent to the sum of all string lengths | ||
| size_hint: usize, | ||
| } | ||
|
|
||
| impl SSOStringHashSet { | ||
| fn new() -> Self { | ||
| Self::default() | ||
| } | ||
|
|
||
| fn insert(&mut self, value: &str) { | ||
| let value_len = value.len(); | ||
| self.size_hint += value_len; | ||
| let value_bytes = value.as_bytes(); | ||
|
|
||
| if value_len <= SHORT_STRING_LEN { | ||
|
||
| let inline = value_bytes | ||
| .iter() | ||
| .fold(0usize, |acc, &x| acc << 8 | x as usize); | ||
| let short_string_header = SSOStringHeader { | ||
| hash: 0, // no need for short string cases | ||
| len: value_len, | ||
| offset_or_inline: inline, | ||
| }; | ||
| self.header_set.insert(short_string_header); | ||
| } else { | ||
| let hash = self.state.hash_one(value_bytes); | ||
|
|
||
| let entry = self.long_string_map.get_mut(hash, |header| { | ||
| // if hash matches, check if the bytes match | ||
| let offset = header.offset_or_inline; | ||
| let len = header.len; | ||
|
|
||
| // SAFETY: buffer is only appended to, and we correctly inserted values | ||
| let existing_value = | ||
| unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) }; | ||
|
|
||
| value_bytes == existing_value | ||
| }); | ||
|
|
||
| if entry.is_none() { | ||
| let offset = self.buffer.len(); | ||
| self.buffer.append_slice(value_bytes); | ||
| let header = SSOStringHeader { | ||
| hash, | ||
| len: value_len, | ||
| offset_or_inline: offset, | ||
| }; | ||
| self.long_string_map.insert_accounted( | ||
alamb marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| header, | ||
| |header| header.hash, | ||
| &mut self.map_size, | ||
| ); | ||
| self.header_set.insert(header); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Returns a StringArray with the current state of the set | ||
| fn state(&self) -> StringArray { | ||
| let mut offsets = Vec::with_capacity(self.size_hint + 1); | ||
| offsets.push(0); | ||
|
|
||
| let mut values = MutableBuffer::new(0); | ||
| let buffer = self.buffer.as_slice(); | ||
|
|
||
| for header in self.header_set.iter() { | ||
| let s = header.evaluate(buffer); | ||
| values.extend_from_slice(s.as_bytes()); | ||
| offsets.push(values.len() as i32); | ||
| } | ||
|
|
||
| let value_offsets = OffsetBuffer::<i32>::new(offsets.into()); | ||
| StringArray::new(value_offsets, values.into(), None) | ||
| } | ||
|
|
||
| fn len(&self) -> usize { | ||
| self.header_set.len() | ||
| } | ||
|
|
||
| fn size(&self) -> usize { | ||
| self.header_set.len() * mem::size_of::<SSOStringHeader>() | ||
| + self.map_size | ||
| + self.buffer.len() | ||
| } | ||
| } | ||
|
|
||
| impl Debug for SSOStringHashSet { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.debug_struct("SSOStringHashSet") | ||
| .field("header_set", &self.header_set) | ||
| // TODO: Print long_string_map | ||
| .field("map_size", &self.map_size) | ||
| .field("buffer", &self.buffer) | ||
| .field("state", &self.state) | ||
| .finish() | ||
| } | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use crate::expressions::NoOp; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.