Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-arith/src/arity.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Kernels for operating on [`PrimitiveArray`]s
19
20
use arrow_array::builder::BufferBuilder;
21
use arrow_array::*;
22
use arrow_buffer::buffer::NullBuffer;
23
use arrow_buffer::ArrowNativeType;
24
use arrow_buffer::MutableBuffer;
25
use arrow_data::ArrayData;
26
use arrow_schema::ArrowError;
27
28
/// See [`PrimitiveArray::unary`]
29
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
30
where
31
    I: ArrowPrimitiveType,
32
    O: ArrowPrimitiveType,
33
    F: Fn(I::Native) -> O::Native,
34
{
35
    array.unary(op)
36
}
37
38
/// See [`PrimitiveArray::unary_mut`]
39
pub fn unary_mut<I, F>(
40
    array: PrimitiveArray<I>,
41
    op: F,
42
) -> Result<PrimitiveArray<I>, PrimitiveArray<I>>
43
where
44
    I: ArrowPrimitiveType,
45
    F: Fn(I::Native) -> I::Native,
46
{
47
    array.unary_mut(op)
48
}
49
50
/// See [`PrimitiveArray::try_unary`]
51
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>, ArrowError>
52
where
53
    I: ArrowPrimitiveType,
54
    O: ArrowPrimitiveType,
55
    F: Fn(I::Native) -> Result<O::Native, ArrowError>,
56
{
57
    array.try_unary(op)
58
}
59
60
/// See [`PrimitiveArray::try_unary_mut`]
61
pub fn try_unary_mut<I, F>(
62
    array: PrimitiveArray<I>,
63
    op: F,
64
) -> Result<Result<PrimitiveArray<I>, ArrowError>, PrimitiveArray<I>>
65
where
66
    I: ArrowPrimitiveType,
67
    F: Fn(I::Native) -> Result<I::Native, ArrowError>,
68
{
69
    array.try_unary_mut(op)
70
}
71
72
/// Applies a binary infallible function to two [`PrimitiveArray`]s,
73
/// producing a new [`PrimitiveArray`]
74
///
75
/// # Details
76
///
77
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
78
/// the results in a [`PrimitiveArray`].
79
///
80
/// If any index is null in either `a` or `b`, the
81
/// corresponding index in the result will also be null
82
///
83
/// Like [`unary`], the `op` is evaluated for every element in the two arrays,
84
/// including those elements which are NULL. This is beneficial as the cost of
85
/// the operation is low compared to the cost of branching, and especially when
86
/// the operation can be vectorised, however, requires `op` to be infallible for
87
/// all possible values of its inputs
88
///
89
/// # Errors
90
///
91
/// * if the arrays have different lengths.
92
///
93
/// # Example
94
/// ```
95
/// # use arrow_arith::arity::binary;
96
/// # use arrow_array::{Float32Array, Int32Array};
97
/// # use arrow_array::types::Int32Type;
98
/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8), Some(7.2)]);
99
/// let b = Int32Array::from(vec![1, 2, 4, 9]);
100
/// // compute int(a) + b for each element
101
/// let c = binary(&a, &b, |a, b| a as i32 + b).unwrap();
102
/// assert_eq!(c, Int32Array::from(vec![Some(6), None, Some(10), Some(16)]));
103
/// ```
104
0
pub fn binary<A, B, F, O>(
105
0
    a: &PrimitiveArray<A>,
106
0
    b: &PrimitiveArray<B>,
107
0
    op: F,
108
0
) -> Result<PrimitiveArray<O>, ArrowError>
109
0
where
110
0
    A: ArrowPrimitiveType,
111
0
    B: ArrowPrimitiveType,
112
0
    O: ArrowPrimitiveType,
113
0
    F: Fn(A::Native, B::Native) -> O::Native,
114
{
115
0
    if a.len() != b.len() {
116
0
        return Err(ArrowError::ComputeError(
117
0
            "Cannot perform binary operation on arrays of different length".to_string(),
118
0
        ));
119
0
    }
120
121
0
    if a.is_empty() {
122
0
        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
123
0
    }
124
125
0
    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
126
127
0
    let values = a
128
0
        .values()
129
0
        .into_iter()
130
0
        .zip(b.values())
131
0
        .map(|(l, r)| op(*l, *r));
132
133
0
    let buffer: Vec<_> = values.collect();
134
0
    Ok(PrimitiveArray::new(buffer.into(), nulls))
135
0
}
136
137
/// Applies a binary and infallible function to values in two arrays, replacing
138
/// the values in the first array in place.
139
///
140
/// # Details
141
///
142
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in
143
/// `0..len`, modifying the [`PrimitiveArray`] `a` in place, if possible.
144
///
145
/// If any index is null in either `a` or `b`, the corresponding index in the
146
/// result will also be null.
147
///
148
/// # Buffer Reuse
149
///
150
/// If the underlying buffers in `a` are not shared with other arrays, mutates
151
/// the underlying buffer in place, without allocating.
152
///
153
/// If the underlying buffers in `a` are shared, returns `Err(a)`
154
///
155
/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This
156
/// is beneficial when the cost of the operation is low compared to the cost of branching, and
157
/// especially when the operation can be vectorised, however, requires `op` to be infallible
158
/// for all possible values of its inputs
159
///
160
/// # Errors
161
///
162
/// * If the arrays have different lengths
163
/// * If the array is not mutable (see "Buffer Reuse")
164
///
165
/// # See Also
166
///
167
/// * Documentation on [`PrimitiveArray::unary_mut`] for operating on [`ArrayRef`].
168
///
169
/// # Example
170
/// ```
171
/// # use arrow_arith::arity::binary_mut;
172
/// # use arrow_array::{Float32Array, Int32Array};
173
/// # use arrow_array::types::Int32Type;
174
/// // compute a + b for each element
175
/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]);
176
/// let b = Int32Array::from(vec![Some(1), None, Some(2)]);
177
/// // compute a + b, updating the value in a in place if possible
178
/// let a = binary_mut(a, &b, |a, b| a + b as f32).unwrap().unwrap();
179
/// // a is updated in place
180
/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)]));
181
/// ```
182
///
183
/// # Example with shared buffers
184
/// ```
185
/// # use arrow_arith::arity::binary_mut;
186
/// # use arrow_array::Float32Array;
187
/// # use arrow_array::types::Int32Type;
188
/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]);
189
/// let b = Float32Array::from(vec![Some(1.0f32), None, Some(2.0)]);
190
/// // a_cloned shares the buffer with a
191
/// let a_cloned = a.clone();
192
/// // try to update a in place, but it is shared. Returns Err(a)
193
/// let a = binary_mut(a, &b, |a, b| a + b).unwrap_err();
194
/// assert_eq!(a_cloned, a);
195
/// // drop shared reference
196
/// drop(a_cloned);
197
/// // now a is not shared, so we can update it in place
198
/// let a = binary_mut(a, &b, |a, b| a + b).unwrap().unwrap();
199
/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)]));
200
/// ```
201
pub fn binary_mut<T, U, F>(
202
    a: PrimitiveArray<T>,
203
    b: &PrimitiveArray<U>,
204
    op: F,
205
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
206
where
207
    T: ArrowPrimitiveType,
208
    U: ArrowPrimitiveType,
209
    F: Fn(T::Native, U::Native) -> T::Native,
210
{
211
    if a.len() != b.len() {
212
        return Ok(Err(ArrowError::ComputeError(
213
            "Cannot perform binary operation on arrays of different length".to_string(),
214
        )));
215
    }
216
217
    if a.is_empty() {
218
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
219
            &T::DATA_TYPE,
220
        ))));
221
    }
222
223
    let mut builder = a.into_builder()?;
224
225
    builder
226
        .values_slice_mut()
227
        .iter_mut()
228
        .zip(b.values())
229
        .for_each(|(l, r)| *l = op(*l, *r));
230
231
    let array = builder.finish();
232
233
    // The builder has the null buffer from `a`, it is not changed.
234
    let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());
235
236
    let array_builder = array.into_data().into_builder().nulls(nulls);
237
238
    let array_data = unsafe { array_builder.build_unchecked() };
239
    Ok(Ok(PrimitiveArray::<T>::from(array_data)))
240
}
241
242
/// Applies the provided fallible binary operation across `a` and `b`.
243
///
244
/// This will return any error encountered, or collect the results into
245
/// a [`PrimitiveArray`]. If any index is null in either `a`
246
/// or `b`, the corresponding index in the result will also be null
247
///
248
/// Like [`try_unary`] the function is only evaluated for non-null indices
249
///
250
/// # Error
251
///
252
/// Return an error if the arrays have different lengths or
253
/// the operation returns an error
254
0
pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
255
0
    a: A,
256
0
    b: B,
257
0
    op: F,
258
0
) -> Result<PrimitiveArray<O>, ArrowError>
259
0
where
260
0
    O: ArrowPrimitiveType,
261
0
    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
262
{
263
0
    if a.len() != b.len() {
264
0
        return Err(ArrowError::ComputeError(
265
0
            "Cannot perform a binary operation on arrays of different length".to_string(),
266
0
        ));
267
0
    }
268
0
    if a.is_empty() {
269
0
        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
270
0
    }
271
0
    let len = a.len();
272
273
0
    if a.null_count() == 0 && b.null_count() == 0 {
274
0
        try_binary_no_nulls(len, a, b, op)
275
    } else {
276
0
        let nulls =
277
0
            NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
278
279
0
        let mut buffer = BufferBuilder::<O::Native>::new(len);
280
0
        buffer.append_n_zeroed(len);
281
0
        let slice = buffer.as_slice_mut();
282
283
0
        nulls.try_for_each_valid_idx(|idx| {
284
            unsafe {
285
0
                *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))?
286
            };
287
0
            Ok::<_, ArrowError>(())
288
0
        })?;
289
290
0
        let values = buffer.finish().into();
291
0
        Ok(PrimitiveArray::new(values, Some(nulls)))
292
    }
293
0
}
294
295
/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable
296
/// [`PrimitiveArray`] `a` with the results.
297
///
298
/// Returns any error encountered, or collects the results into a [`PrimitiveArray`] as return
299
/// value. If any index is null in either `a` or `b`, the corresponding index in the result will
300
/// also be null.
301
///
302
/// Like [`try_unary`] the function is only evaluated for non-null indices.
303
///
304
/// See [`binary_mut`] for errors and buffer reuse information.
305
pub fn try_binary_mut<T, F>(
306
    a: PrimitiveArray<T>,
307
    b: &PrimitiveArray<T>,
308
    op: F,
309
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
310
where
311
    T: ArrowPrimitiveType,
312
    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
313
{
314
    if a.len() != b.len() {
315
        return Ok(Err(ArrowError::ComputeError(
316
            "Cannot perform binary operation on arrays of different length".to_string(),
317
        )));
318
    }
319
    let len = a.len();
320
321
    if a.is_empty() {
322
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
323
            &T::DATA_TYPE,
324
        ))));
325
    }
326
327
    if a.null_count() == 0 && b.null_count() == 0 {
328
        try_binary_no_nulls_mut(len, a, b, op)
329
    } else {
330
        let nulls =
331
            create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
332
                .unwrap();
333
334
        let mut builder = a.into_builder()?;
335
336
        let slice = builder.values_slice_mut();
337
338
        let r = nulls.try_for_each_valid_idx(|idx| {
339
            unsafe {
340
                *slice.get_unchecked_mut(idx) =
341
                    op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
342
            };
343
            Ok::<_, ArrowError>(())
344
        });
345
        if let Err(err) = r {
346
            return Ok(Err(err));
347
        }
348
        let array_builder = builder.finish().into_data().into_builder();
349
        let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() };
350
        Ok(Ok(PrimitiveArray::<T>::from(array_data)))
351
    }
352
}
353
354
/// Computes the union of the nulls in two optional [`NullBuffer`] which
355
/// is not shared with the input buffers.
356
///
357
/// The union of the nulls is the same as `NullBuffer::union(lhs, rhs)` but
358
/// it does not increase the reference count of the null buffer.
359
fn create_union_null_buffer(
360
    lhs: Option<&NullBuffer>,
361
    rhs: Option<&NullBuffer>,
362
) -> Option<NullBuffer> {
363
    match (lhs, rhs) {
364
        (Some(lhs), Some(rhs)) => Some(NullBuffer::new(lhs.inner() & rhs.inner())),
365
        (Some(n), None) | (None, Some(n)) => Some(NullBuffer::new(n.inner() & n.inner())),
366
        (None, None) => None,
367
    }
368
}
369
370
/// This intentional inline(never) attribute helps LLVM optimize the loop.
371
#[inline(never)]
372
0
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
373
0
    len: usize,
374
0
    a: A,
375
0
    b: B,
376
0
    op: F,
377
0
) -> Result<PrimitiveArray<O>, ArrowError>
378
0
where
379
0
    O: ArrowPrimitiveType,
380
0
    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
381
{
382
0
    let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
383
0
    for idx in 0..len {
384
        unsafe {
385
0
            buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
386
        };
387
    }
388
0
    Ok(PrimitiveArray::new(buffer.into(), None))
389
0
}
390
391
/// This intentional inline(never) attribute helps LLVM optimize the loop.
392
#[inline(never)]
393
fn try_binary_no_nulls_mut<T, F>(
394
    len: usize,
395
    a: PrimitiveArray<T>,
396
    b: &PrimitiveArray<T>,
397
    op: F,
398
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
399
where
400
    T: ArrowPrimitiveType,
401
    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
402
{
403
    let mut builder = a.into_builder()?;
404
    let slice = builder.values_slice_mut();
405
406
    for idx in 0..len {
407
        unsafe {
408
            match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
409
                Ok(value) => *slice.get_unchecked_mut(idx) = value,
410
                Err(err) => return Ok(Err(err)),
411
            };
412
        };
413
    }
414
    Ok(Ok(builder.finish()))
415
}
416
417
#[cfg(test)]
418
mod tests {
419
    use super::*;
420
    use arrow_array::types::*;
421
    use std::sync::Arc;
422
423
    #[test]
424
    #[allow(deprecated)]
425
    fn test_unary_f64_slice() {
426
        let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
427
        let input_slice = input.slice(1, 4);
428
        let result = unary(&input_slice, |n| n.round());
429
        assert_eq!(
430
            result,
431
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
432
        );
433
    }
434
435
    #[test]
436
    fn test_binary_mut() {
437
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
438
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
439
        let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap();
440
441
        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
442
        assert_eq!(c, expected);
443
    }
444
445
    #[test]
446
    fn test_binary_mut_null_buffer() {
447
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
448
449
        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
450
451
        let r1 = binary_mut(a, &b, |a, b| a + b).unwrap();
452
453
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
454
        let b = Int32Array::new(
455
            vec![10, 11, 12, 13, 14].into(),
456
            Some(vec![true, true, true, true, true].into()),
457
        );
458
459
        // unwrap here means that no copying occurred
460
        let r2 = binary_mut(a, &b, |a, b| a + b).unwrap();
461
        assert_eq!(r1.unwrap(), r2.unwrap());
462
    }
463
464
    #[test]
465
    fn test_try_binary_mut() {
466
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
467
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
468
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
469
470
        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
471
        assert_eq!(c, expected);
472
473
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
474
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
475
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
476
        let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
477
        assert_eq!(c, expected);
478
479
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
480
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
481
        let _ = try_binary_mut(a, &b, |l, r| {
482
            if l == 1 {
483
                Err(ArrowError::InvalidArgumentError(
484
                    "got error".parse().unwrap(),
485
                ))
486
            } else {
487
                Ok(l + r)
488
            }
489
        })
490
        .unwrap()
491
        .expect_err("should got error");
492
    }
493
494
    #[test]
495
    fn test_try_binary_mut_null_buffer() {
496
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
497
498
        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
499
500
        let r1 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();
501
502
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
503
        let b = Int32Array::new(
504
            vec![10, 11, 12, 13, 14].into(),
505
            Some(vec![true, true, true, true, true].into()),
506
        );
507
508
        // unwrap here means that no copying occurred
509
        let r2 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();
510
        assert_eq!(r1.unwrap(), r2.unwrap());
511
    }
512
513
    #[test]
514
    fn test_unary_dict_mut() {
515
        let values = Int32Array::from(vec![Some(10), Some(20), None]);
516
        let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
517
        let dictionary = DictionaryArray::new(keys, Arc::new(values));
518
519
        let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap();
520
        let typed = updated.downcast_dict::<Int32Array>().unwrap();
521
        assert_eq!(typed.value(0), 11);
522
        assert_eq!(typed.value(1), 11);
523
        assert_eq!(typed.value(2), 21);
524
525
        let values = updated.values();
526
        assert!(values.is_null(2));
527
    }
528
}