Skip to content

Commit f7cae50

Browse files
stuartcarniealamb
authored andcommitted
feat: Support binary data types for SortMergeJoin on clause (apache#17431)
* feat: Support binary data types for `SortMergeJoin` `on` clause * Add sql level tests for merge join on binary keys --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent c57e72c commit f7cae50

File tree

2 files changed

+209
-4
lines changed

2 files changed

+209
-4
lines changed

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,10 @@ fn compare_join_arrays(
25032503
DataType::Utf8 => compare_value!(StringArray),
25042504
DataType::Utf8View => compare_value!(StringViewArray),
25052505
DataType::LargeUtf8 => compare_value!(LargeStringArray),
2506+
DataType::Binary => compare_value!(BinaryArray),
2507+
DataType::BinaryView => compare_value!(BinaryViewArray),
2508+
DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
2509+
DataType::LargeBinary => compare_value!(LargeBinaryArray),
25062510
DataType::Decimal128(..) => compare_value!(Decimal128Array),
25072511
DataType::Timestamp(time_unit, None) => match time_unit {
25082512
TimeUnit::Second => compare_value!(TimestampSecondArray),
@@ -2571,6 +2575,10 @@ fn is_join_arrays_equal(
25712575
DataType::Utf8 => compare_value!(StringArray),
25722576
DataType::Utf8View => compare_value!(StringViewArray),
25732577
DataType::LargeUtf8 => compare_value!(LargeStringArray),
2578+
DataType::Binary => compare_value!(BinaryArray),
2579+
DataType::BinaryView => compare_value!(BinaryViewArray),
2580+
DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
2581+
DataType::LargeBinary => compare_value!(LargeBinaryArray),
25742582
DataType::Decimal128(..) => compare_value!(Decimal128Array),
25752583
DataType::Timestamp(time_unit, None) => match time_unit {
25762584
TimeUnit::Second => compare_value!(TimestampSecondArray),
@@ -2600,7 +2608,8 @@ mod tests {
26002608

26012609
use arrow::array::{
26022610
builder::{BooleanBuilder, UInt64Builder},
2603-
BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array,
2611+
BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
2612+
Int32Array, RecordBatch, UInt64Array,
26042613
};
26052614
use arrow::compute::{concat_batches, filter_record_batch, SortOptions};
26062615
use arrow::datatypes::{DataType, Field, Schema};
@@ -2694,6 +2703,56 @@ mod tests {
26942703
TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
26952704
}
26962705

2706+
fn build_binary_table(
2707+
a: (&str, &Vec<&[u8]>),
2708+
b: (&str, &Vec<i32>),
2709+
c: (&str, &Vec<i32>),
2710+
) -> Arc<dyn ExecutionPlan> {
2711+
let schema = Schema::new(vec![
2712+
Field::new(a.0, DataType::Binary, false),
2713+
Field::new(b.0, DataType::Int32, false),
2714+
Field::new(c.0, DataType::Int32, false),
2715+
]);
2716+
2717+
let batch = RecordBatch::try_new(
2718+
Arc::new(schema),
2719+
vec![
2720+
Arc::new(BinaryArray::from(a.1.clone())),
2721+
Arc::new(Int32Array::from(b.1.clone())),
2722+
Arc::new(Int32Array::from(c.1.clone())),
2723+
],
2724+
)
2725+
.unwrap();
2726+
2727+
let schema = batch.schema();
2728+
TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2729+
}
2730+
2731+
fn build_fixed_size_binary_table(
2732+
a: (&str, &Vec<&[u8]>),
2733+
b: (&str, &Vec<i32>),
2734+
c: (&str, &Vec<i32>),
2735+
) -> Arc<dyn ExecutionPlan> {
2736+
let schema = Schema::new(vec![
2737+
Field::new(a.0, DataType::FixedSizeBinary(3), false),
2738+
Field::new(b.0, DataType::Int32, false),
2739+
Field::new(c.0, DataType::Int32, false),
2740+
]);
2741+
2742+
let batch = RecordBatch::try_new(
2743+
Arc::new(schema),
2744+
vec![
2745+
Arc::new(FixedSizeBinaryArray::from(a.1.clone())),
2746+
Arc::new(Int32Array::from(b.1.clone())),
2747+
Arc::new(Int32Array::from(c.1.clone())),
2748+
],
2749+
)
2750+
.unwrap();
2751+
2752+
let schema = batch.schema();
2753+
TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2754+
}
2755+
26972756
/// returns a table with 3 columns of i32 in memory
26982757
pub fn build_table_i32_nullable(
26992758
a: (&str, &Vec<Option<i32>>),
@@ -3932,6 +3991,100 @@ mod tests {
39323991
Ok(())
39333992
}
39343993

3994+
#[tokio::test]
3995+
async fn join_binary() -> Result<()> {
3996+
let left = build_binary_table(
3997+
(
3998+
"a1",
3999+
&vec![
4000+
&[0xc0, 0xff, 0xee],
4001+
&[0xde, 0xca, 0xde],
4002+
&[0xfa, 0xca, 0xde],
4003+
],
4004+
),
4005+
("b1", &vec![5, 10, 15]), // this has a repetition
4006+
("c1", &vec![7, 8, 9]),
4007+
);
4008+
let right = build_binary_table(
4009+
(
4010+
"a1",
4011+
&vec![
4012+
&[0xc0, 0xff, 0xee],
4013+
&[0xde, 0xca, 0xde],
4014+
&[0xfa, 0xca, 0xde],
4015+
],
4016+
),
4017+
("b2", &vec![105, 110, 115]),
4018+
("c2", &vec![70, 80, 90]),
4019+
);
4020+
4021+
let on = vec![(
4022+
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
4023+
Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
4024+
)];
4025+
4026+
let (_, batches) = join_collect(left, right, on, Inner).await?;
4027+
4028+
// The output order is important as SMJ preserves sortedness
4029+
assert_snapshot!(batches_to_string(&batches), @r#"
4030+
+--------+----+----+--------+-----+----+
4031+
| a1 | b1 | c1 | a1 | b2 | c2 |
4032+
+--------+----+----+--------+-----+----+
4033+
| c0ffee | 5 | 7 | c0ffee | 105 | 70 |
4034+
| decade | 10 | 8 | decade | 110 | 80 |
4035+
| facade | 15 | 9 | facade | 115 | 90 |
4036+
+--------+----+----+--------+-----+----+
4037+
"#);
4038+
Ok(())
4039+
}
4040+
4041+
#[tokio::test]
4042+
async fn join_fixed_size_binary() -> Result<()> {
4043+
let left = build_fixed_size_binary_table(
4044+
(
4045+
"a1",
4046+
&vec![
4047+
&[0xc0, 0xff, 0xee],
4048+
&[0xde, 0xca, 0xde],
4049+
&[0xfa, 0xca, 0xde],
4050+
],
4051+
),
4052+
("b1", &vec![5, 10, 15]), // this has a repetition
4053+
("c1", &vec![7, 8, 9]),
4054+
);
4055+
let right = build_fixed_size_binary_table(
4056+
(
4057+
"a1",
4058+
&vec![
4059+
&[0xc0, 0xff, 0xee],
4060+
&[0xde, 0xca, 0xde],
4061+
&[0xfa, 0xca, 0xde],
4062+
],
4063+
),
4064+
("b2", &vec![105, 110, 115]),
4065+
("c2", &vec![70, 80, 90]),
4066+
);
4067+
4068+
let on = vec![(
4069+
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
4070+
Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
4071+
)];
4072+
4073+
let (_, batches) = join_collect(left, right, on, Inner).await?;
4074+
4075+
// The output order is important as SMJ preserves sortedness
4076+
assert_snapshot!(batches_to_string(&batches), @r#"
4077+
+--------+----+----+--------+-----+----+
4078+
| a1 | b1 | c1 | a1 | b2 | c2 |
4079+
+--------+----+----+--------+-----+----+
4080+
| c0ffee | 5 | 7 | c0ffee | 105 | 70 |
4081+
| decade | 10 | 8 | decade | 110 | 80 |
4082+
| facade | 15 | 9 | facade | 115 | 90 |
4083+
+--------+----+----+--------+-----+----+
4084+
"#);
4085+
Ok(())
4086+
}
4087+
39354088
#[tokio::test]
39364089
async fn join_left_sort_order() -> Result<()> {
39374090
let left = build_table(

datafusion/sqllogictest/test_files/sort_merge_join.slt

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,9 +833,61 @@ t2 as (
833833
11 14
834834
12 15
835835

836-
# return sql params back to default values
837836
statement ok
838-
set datafusion.optimizer.prefer_hash_join = true;
837+
set datafusion.execution.batch_size = 8192;
839838

839+
840+
######
841+
## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys
842+
######
840843
statement ok
841-
set datafusion.execution.batch_size = 8192;
844+
create table t1(x varchar, id1 int) as values ('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5);
845+
846+
statement ok
847+
create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50);
848+
849+
# Binary join keys
850+
query ?I?I
851+
with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1),
852+
t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2)
853+
select * from t1 join t2 on t1.x = t2.y order by id1, id2
854+
----
855+
6262 2 6262 20
856+
6565 5 6565 10
857+
858+
# LargeBinary join keys
859+
query ?I?I
860+
with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1),
861+
t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2)
862+
select * from t1 join t2 on t1.x = t2.y order by id1, id2
863+
----
864+
6262 2 6262 20
865+
6565 5 6565 10
866+
867+
# BinaryView join keys
868+
query ?I?I
869+
with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1),
870+
t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2)
871+
select * from t1 join t2 on t1.x = t2.y order by id1, id2
872+
----
873+
6262 2 6262 20
874+
6565 5 6565 10
875+
876+
# FixedSizeBinary join keys
877+
query ?I?I
878+
with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1),
879+
t2 as (select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2)
880+
select * from t1 join t2 on t1.x = t2.y order by id1, id2
881+
----
882+
6262 2 6262 20
883+
6565 5 6565 10
884+
885+
statement ok
886+
drop table t1;
887+
888+
statement ok
889+
drop table t2;
890+
891+
# return sql params back to default values
892+
statement ok
893+
set datafusion.optimizer.prefer_hash_join = true;

0 commit comments

Comments
 (0)