Skip to content

Commit 4961ca6

Browse files
committed
Fix join on arrays of unhashable types
Update can_hash to match currently supported hashes.
1 parent 1e69946 commit 4961ca6

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

datafusion/expr/src/utils.rs

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::{
2929
};
3030
use datafusion_expr_common::signature::{Signature, TypeSignature};
3131

32-
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
32+
use arrow::datatypes::{DataType, Field, Schema};
3333
use datafusion_common::tree_node::{
3434
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
3535
};
@@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr(
958958

959959
/// Can this data type be used in hash join equal conditions?
960960
/// Data types here come from function 'create_hashes'; if more data types are supported
961-
/// in equal_rows(hash join), add those data types here to generate join logical plan.
961+
/// in create_hashes, add those data types here to generate join logical plan.
962962
pub fn can_hash(data_type: &DataType) -> bool {
963963
match data_type {
964964
DataType::Null => true,
@@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool {
971971
DataType::UInt16 => true,
972972
DataType::UInt32 => true,
973973
DataType::UInt64 => true,
974+
DataType::Float16 => true,
974975
DataType::Float32 => true,
975976
DataType::Float64 => true,
976-
DataType::Timestamp(time_unit, _) => match time_unit {
977-
TimeUnit::Second => true,
978-
TimeUnit::Millisecond => true,
979-
TimeUnit::Microsecond => true,
980-
TimeUnit::Nanosecond => true,
981-
},
977+
DataType::Decimal128(_, _) => true,
978+
DataType::Decimal256(_, _) => true,
979+
DataType::Timestamp(_, _) => true,
982980
DataType::Utf8 => true,
983981
DataType::LargeUtf8 => true,
984982
DataType::Utf8View => true,
985-
DataType::Decimal128(_, _) => true,
983+
DataType::Binary => true,
984+
DataType::LargeBinary => true,
985+
DataType::BinaryView => true,
986986
DataType::Date32 => true,
987987
DataType::Date64 => true,
988+
DataType::Time32(_) => true,
989+
DataType::Time64(_) => true,
990+
DataType::Duration(_) => true,
991+
DataType::Interval(_) => true,
988992
DataType::FixedSizeBinary(_) => true,
989-
DataType::Dictionary(key_type, value_type)
990-
if *value_type.as_ref() == DataType::Utf8 =>
991-
{
992-
DataType::is_dictionary_key_type(key_type)
993+
DataType::Dictionary(key_type, value_type) => {
994+
DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
993995
}
994-
DataType::List(_) => true,
995-
DataType::LargeList(_) => true,
996-
DataType::FixedSizeList(_, _) => true,
996+
DataType::List(value_type) => can_hash(value_type.data_type()),
997+
DataType::LargeList(value_type) => can_hash(value_type.data_type()),
998+
DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
999+
DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
9971000
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
998-
_ => false,
1001+
1002+
DataType::ListView(_)
1003+
| DataType::LargeListView(_)
1004+
| DataType::Union(_, _)
1005+
| DataType::RunEndEncoded(_, _) => false,
9991006
}
10001007
}
10011008

@@ -1403,6 +1410,7 @@ mod tests {
14031410
test::function_stub::max_udaf, test::function_stub::min_udaf,
14041411
test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
14051412
};
1413+
use arrow::datatypes::{UnionFields, UnionMode};
14061414

14071415
#[test]
14081416
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
@@ -1805,4 +1813,21 @@ mod tests {
18051813
assert!(accum.contains(&Column::from_name("a")));
18061814
Ok(())
18071815
}
1816+
1817+
#[test]
1818+
fn test_can_hash() {
1819+
let union_fields: UnionFields = [
1820+
(0, Arc::new(Field::new("A", DataType::Int32, true))),
1821+
(1, Arc::new(Field::new("B", DataType::Float64, true))),
1822+
]
1823+
.into_iter()
1824+
.collect();
1825+
1826+
let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1827+
assert!(!can_hash(&union_type));
1828+
1829+
let list_union_type =
1830+
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1831+
assert!(!can_hash(&list_union_type));
1832+
}
18081833
}

0 commit comments

Comments
 (0)