Skip to content

Clean up mixed semi join by removing unused parameters and redundant table-swapping logic #19655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 25 additions & 36 deletions cpp/src/join/mixed_join_common_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,18 +50,15 @@ struct expression_equality {
__device__ expression_equality(
cudf::ast::detail::expression_evaluator<has_nulls> const& evaluator,
cudf::ast::detail::IntermediateDataType<has_nulls>* thread_intermediate_storage,
bool const swap_tables,
row_equality const& equality_probe)
: evaluator{evaluator},
thread_intermediate_storage{thread_intermediate_storage},
swap_tables{swap_tables},
equality_probe{equality_probe}
{
}

cudf::ast::detail::IntermediateDataType<has_nulls>* thread_intermediate_storage;
cudf::ast::detail::expression_evaluator<has_nulls> const& evaluator;
bool const swap_tables;
row_equality const& equality_probe;
};

Expand All @@ -71,23 +68,14 @@ struct expression_equality {
* This equality comparator is designed for use with cuco::static_map's APIs. A
* probe hit indicates that the hashes of the keys are equal, at which point
* this comparator checks whether the keys themselves are equal (using the
* provided equality_probe) and then evaluates the conditional expression
* provided equality_probe) and then evaluates the conditional expression.
*/
template <bool has_nulls>
struct single_expression_equality : expression_equality<has_nulls> {
using expression_equality<has_nulls>::expression_equality;

// The parameters are build/probe rather than left/right because the operator
// is called by cuco's kernels with parameters in this order (note that this
// is an implementation detail that we should eventually stop relying on by
// defining operators with suitable heterogeneous typing). Rather than
// converting to left/right semantics, we can operate directly on build/probe
// until we get to the expression evaluator, which needs to convert back to
// left/right semantics because the conditional expression need not be
// commutative.
// TODO: The input types should really be size_type.
__device__ __forceinline__ bool operator()(hash_value_type const build_row_index,
hash_value_type const probe_row_index) const noexcept
__device__ __forceinline__ bool operator()(size_type const left_index,
size_type const right_index) const noexcept
{
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;
Expand All @@ -97,12 +85,11 @@ struct single_expression_equality : expression_equality<has_nulls> {
// 1. The contents of the columns involved in the equality condition are equal.
// 2. The predicate evaluated on the relevant columns (already encoded in the evaluator)
// evaluates to true.
if (this->equality_probe(lhs_index_type{probe_row_index}, rhs_index_type{build_row_index})) {
auto const lrow_idx = this->swap_tables ? build_row_index : probe_row_index;
auto const rrow_idx = this->swap_tables ? probe_row_index : build_row_index;
if (this->equality_probe(lhs_index_type{left_index}, rhs_index_type{right_index})) {
// For single expressions, left/right mapping is direct (no swap needed)
this->evaluator.evaluate(output_dest,
static_cast<size_type>(lrow_idx),
static_cast<size_type>(rrow_idx),
static_cast<size_type>(left_index),
static_cast<size_type>(right_index),
0,
this->thread_intermediate_storage);
return (output_dest.is_valid() && output_dest.value());
Expand All @@ -127,18 +114,20 @@ struct single_expression_equality : expression_equality<has_nulls> {
*/
template <bool has_nulls>
struct pair_expression_equality : public expression_equality<has_nulls> {
using expression_equality<has_nulls>::expression_equality;
__device__ pair_expression_equality(
cudf::ast::detail::expression_evaluator<has_nulls> const& evaluator,
cudf::ast::detail::IntermediateDataType<has_nulls>* thread_intermediate_storage,
bool const swap_tables,
row_equality const& equality_probe)
: expression_equality<has_nulls>{evaluator, thread_intermediate_storage, equality_probe},
swap_tables{swap_tables}
{
}

bool const swap_tables;

// The parameters are build/probe rather than left/right because the operator
// is called by cuco's kernels with parameters in this order (note that this
// is an implementation detail that we should eventually stop relying on by
// defining operators with suitable heterogeneous typing). Rather than
// converting to left/right semantics, we can operate directly on build/probe
// until we get to the expression evaluator, which needs to convert back to
// left/right semantics because the conditional expression need not be
// commutative.
__device__ __forceinline__ bool operator()(pair_type const& build_row,
pair_type const& probe_row) const noexcept
__device__ __forceinline__ bool operator()(pair_type const& left_row,
pair_type const& right_row) const noexcept
{
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;
Expand All @@ -149,10 +138,10 @@ struct pair_expression_equality : public expression_equality<has_nulls> {
// 2. The contents of the columns involved in the equality condition are equal.
// 3. The predicate evaluated on the relevant columns (already encoded in the evaluator)
// evaluates to true.
if ((probe_row.first == build_row.first) &&
this->equality_probe(lhs_index_type{probe_row.second}, rhs_index_type{build_row.second})) {
auto const lrow_idx = this->swap_tables ? build_row.second : probe_row.second;
auto const rrow_idx = this->swap_tables ? probe_row.second : build_row.second;
if ((left_row.first == right_row.first) &&
this->equality_probe(lhs_index_type{left_row.second}, rhs_index_type{right_row.second})) {
auto const lrow_idx = this->swap_tables ? right_row.second : left_row.second;
auto const rrow_idx = this->swap_tables ? left_row.second : right_row.second;
this->evaluator.evaluate(
output_dest, lrow_idx, rrow_idx, 0, this->thread_intermediate_storage);
return (output_dest.is_valid() && output_dest.value());
Expand Down
16 changes: 2 additions & 14 deletions cpp/src/join/mixed_join_kernels_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ template <cudf::size_type block_size, bool has_nulls>
CUDF_KERNEL void __launch_bounds__(block_size)
mixed_join_semi(table_device_view left_table,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_equality const equality_probe,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
Expand All @@ -50,12 +48,8 @@ CUDF_KERNEL void __launch_bounds__(block_size)
// Equality evaluator to use
auto const evaluator = cudf::ast::detail::expression_evaluator<has_nulls>(
left_table, right_table, device_expression_data);

// Make sure to swap_tables here as hash_set will use probe table as the left one
auto constexpr swap_tables = true;
auto const equality = single_expression_equality<has_nulls>{
evaluator, thread_intermediate_storage, swap_tables, equality_probe};

auto const equality =
single_expression_equality<has_nulls>{evaluator, thread_intermediate_storage, equality_probe};
// Create set ref with the new equality comparator
auto const set_ref_equality = set_ref.rebind_key_eq(equality);

Expand All @@ -74,8 +68,6 @@ CUDF_KERNEL void __launch_bounds__(block_size)
void launch_mixed_join_semi(bool has_nulls,
table_device_view left_table,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_equality const equality_probe,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
Expand All @@ -89,8 +81,6 @@ void launch_mixed_join_semi(bool has_nulls,
<<<config.num_blocks, config.num_threads_per_block, shmem_size_per_block, stream.value()>>>(
left_table,
right_table,
probe,
build,
equality_probe,
set_ref,
left_table_keep_mask,
Expand All @@ -100,8 +90,6 @@ void launch_mixed_join_semi(bool has_nulls,
<<<config.num_blocks, config.num_threads_per_block, shmem_size_per_block, stream.value()>>>(
left_table,
right_table,
probe,
build,
equality_probe,
set_ref,
left_table_keep_mask,
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/join/mixed_join_kernels_semi.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ namespace detail {
* @param[in] has_nulls If the input has nulls
* @param[in] left_table The left table
* @param[in] right_table The right table
* @param[in] probe The table with which to probe the hash table for matches.
* @param[in] build The table with which the hash table was built.
* @param[in] equality_probe The equality comparator used when probing the hash table.
* @param[in] set_ref The hash table device view built from the keys of the right (build-side) table.
* @param[out] left_table_keep_mask The result of the join operation with "true" element indicating
Expand All @@ -56,8 +54,6 @@ namespace detail {
void launch_mixed_join_semi(bool has_nulls,
table_device_view left_table,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_equality const equality_probe,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
Expand Down
4 changes: 1 addition & 3 deletions cpp/src/join/mixed_join_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
auto const preprocessed_probe =
cudf::experimental::row::equality::preprocessed_table::create(probe, stream);
auto const row_comparator =
cudf::experimental::row::equality::two_table_comparator{preprocessed_build, preprocessed_probe};
cudf::experimental::row::equality::two_table_comparator{preprocessed_probe, preprocessed_build};
auto const equality_probe = row_comparator.equal_to<false>(has_nulls, compare_nulls);

// Create a hash table containing all keys found in the right table
Expand Down Expand Up @@ -196,8 +196,6 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
launch_mixed_join_semi(has_nulls,
*left_conditional_view,
*right_conditional_view,
*probe_view,
*build_view,
equality_probe,
row_set_ref,
cudf::device_span<bool>(left_table_keep_mask),
Expand Down
Loading