Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-string/src/binary_like.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Provide SQL's CONTAINS, STARTS_WITH, ENDS_WITH operators for Arrow's binary arrays
19
20
use crate::binary_predicate::BinaryPredicate;
21
22
use arrow_array::cast::AsArray;
23
use arrow_array::*;
24
use arrow_schema::*;
25
use arrow_select::take::take;
26
27
/// The binary comparison operators implemented in this module.
#[derive(Debug)]
pub(crate) enum Op {
    /// The right-hand bytes occur somewhere within the left-hand bytes.
    Contains,
    /// The left-hand bytes begin with the right-hand bytes.
    StartsWith,
    /// The left-hand bytes end with the right-hand bytes.
    EndsWith,
}
33
34
impl TryFrom<crate::like::Op> for Op {
35
    type Error = ArrowError;
36
37
0
    fn try_from(value: crate::like::Op) -> Result<Self, Self::Error> {
38
0
        match value {
39
0
            crate::like::Op::Contains => Ok(Op::Contains),
40
0
            crate::like::Op::StartsWith => Ok(Op::StartsWith),
41
0
            crate::like::Op::EndsWith => Ok(Op::EndsWith),
42
0
            _ => Err(ArrowError::InvalidArgumentError(format!(
43
0
                "Invalid binary operation: {value}"
44
0
            ))),
45
        }
46
0
    }
47
}
48
49
impl std::fmt::Display for Op {
50
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51
0
        match self {
52
0
            Op::Contains => write!(f, "CONTAINS"),
53
0
            Op::StartsWith => write!(f, "STARTS_WITH"),
54
0
            Op::EndsWith => write!(f, "ENDS_WITH"),
55
        }
56
0
    }
57
}
58
59
0
pub(crate) fn binary_apply<'a, 'i, T: BinaryArrayType<'a> + 'a>(
60
0
    op: Op,
61
0
    l: T,
62
0
    l_s: bool,
63
0
    l_v: Option<&'a dyn AnyDictionaryArray>,
64
0
    r: T,
65
0
    r_s: bool,
66
0
    r_v: Option<&'a dyn AnyDictionaryArray>,
67
0
) -> Result<BooleanArray, ArrowError> {
68
0
    let l_len = l_v.map(|l| l.len()).unwrap_or(l.len());
69
0
    if r_s {
70
0
        let idx = match r_v {
71
0
            Some(dict) if dict.null_count() != 0 => return Ok(BooleanArray::new_null(l_len)),
72
0
            Some(dict) => dict.normalized_keys()[0],
73
0
            None => 0,
74
        };
75
0
        if r.is_null(idx) {
76
0
            return Ok(BooleanArray::new_null(l_len));
77
0
        }
78
0
        op_scalar::<T>(op, l, l_v, r.value(idx))
79
    } else {
80
0
        match (l_s, l_v, r_v) {
81
            (true, None, None) => {
82
0
                let v = l.is_valid(0).then(|| l.value(0));
83
0
                op_binary(op, std::iter::repeat(v), r.iter())
84
            }
85
0
            (true, Some(l_v), None) => {
86
0
                let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
87
0
                let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
88
0
                op_binary(op, std::iter::repeat(v), r.iter())
89
            }
90
0
            (true, None, Some(r_v)) => {
91
0
                let v = l.is_valid(0).then(|| l.value(0));
92
0
                op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
93
            }
94
0
            (true, Some(l_v), Some(r_v)) => {
95
0
                let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
96
0
                let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
97
0
                op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
98
            }
99
0
            (false, None, None) => op_binary(op, l.iter(), r.iter()),
100
0
            (false, Some(l_v), None) => op_binary(op, vectored_iter(l, l_v), r.iter()),
101
0
            (false, None, Some(r_v)) => op_binary(op, l.iter(), vectored_iter(r, r_v)),
102
0
            (false, Some(l_v), Some(r_v)) => {
103
0
                op_binary(op, vectored_iter(l, l_v), vectored_iter(r, r_v))
104
            }
105
        }
106
    }
107
0
}
108
109
#[inline(never)]
110
0
fn op_scalar<'a, T: BinaryArrayType<'a>>(
111
0
    op: Op,
112
0
    l: T,
113
0
    l_v: Option<&dyn AnyDictionaryArray>,
114
0
    r: &[u8],
115
0
) -> Result<BooleanArray, ArrowError> {
116
0
    let r = match op {
117
0
        Op::Contains => BinaryPredicate::contains(r).evaluate_array(l, false),
118
0
        Op::StartsWith => BinaryPredicate::StartsWith(r).evaluate_array(l, false),
119
0
        Op::EndsWith => BinaryPredicate::EndsWith(r).evaluate_array(l, false),
120
    };
121
122
0
    Ok(match l_v {
123
0
        Some(v) => take(&r, v.keys(), None)?.as_boolean().clone(),
124
0
        None => r,
125
    })
126
0
}
127
128
0
fn vectored_iter<'a, T: BinaryArrayType<'a> + 'a>(
129
0
    a: T,
130
0
    a_v: &'a dyn AnyDictionaryArray,
131
0
) -> impl Iterator<Item = Option<&'a [u8]>> + 'a {
132
0
    let nulls = a_v.nulls();
133
0
    let keys = a_v.normalized_keys();
134
0
    keys.into_iter().enumerate().map(move |(idx, key)| {
135
0
        if nulls.map(|n| n.is_null(idx)).unwrap_or_default() || a.is_null(key) {
136
0
            return None;
137
0
        }
138
0
        Some(a.value(key))
139
0
    })
140
0
}
141
142
#[inline(never)]
143
0
fn op_binary<'a>(
144
0
    op: Op,
145
0
    l: impl Iterator<Item = Option<&'a [u8]>>,
146
0
    r: impl Iterator<Item = Option<&'a [u8]>>,
147
0
) -> Result<BooleanArray, ArrowError> {
148
0
    match op {
149
0
        Op::Contains => Ok(l
150
0
            .zip(r)
151
0
            .map(|(l, r)| Some(bytes_contains(l?, r?)))
152
0
            .collect()),
153
0
        Op::StartsWith => Ok(l
154
0
            .zip(r)
155
0
            .map(|(l, r)| Some(BinaryPredicate::StartsWith(r?).evaluate(l?)))
156
0
            .collect()),
157
0
        Op::EndsWith => Ok(l
158
0
            .zip(r)
159
0
            .map(|(l, r)| Some(BinaryPredicate::EndsWith(r?).evaluate(l?)))
160
0
            .collect()),
161
    }
162
0
}
163
164
0
fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool {
165
0
    memchr::memmem::find(haystack, needle).is_some()
166
0
}