/Users/andrewlamb/Software/arrow-rs/arrow-array/src/scalar.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::Array; |
19 | | |
20 | | /// An [`Array`] that is possibly a [`Scalar`] |
21 | | /// |
22 | | /// This allows optimised binary kernels where one or more arguments are constant |
23 | | /// |
24 | | /// ``` |
25 | | /// # use arrow_array::*; |
26 | | /// # use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; |
27 | | /// # use arrow_schema::ArrowError; |
28 | | /// # |
29 | | /// fn eq_impl<T: ArrowPrimitiveType>( |
30 | | /// a: &PrimitiveArray<T>, |
31 | | /// a_scalar: bool, |
32 | | /// b: &PrimitiveArray<T>, |
33 | | /// b_scalar: bool, |
34 | | /// ) -> BooleanArray { |
35 | | /// let (array, scalar) = match (a_scalar, b_scalar) { |
36 | | /// (true, true) | (false, false) => { |
37 | | /// let len = a.len().min(b.len()); |
38 | | /// let nulls = NullBuffer::union(a.nulls(), b.nulls()); |
39 | | /// let buffer = BooleanBuffer::collect_bool(len, |idx| a.value(idx) == b.value(idx)); |
40 | | /// return BooleanArray::new(buffer, nulls); |
41 | | /// } |
42 | | /// (true, false) => (b, (a.null_count() == 0).then(|| a.value(0))), |
43 | | /// (false, true) => (a, (b.null_count() == 0).then(|| b.value(0))), |
44 | | /// }; |
45 | | /// match scalar { |
46 | | /// Some(v) => { |
47 | | /// let len = array.len(); |
48 | | /// let nulls = array.nulls().cloned(); |
49 | | /// let buffer = BooleanBuffer::collect_bool(len, |idx| array.value(idx) == v); |
50 | | /// BooleanArray::new(buffer, nulls) |
51 | | /// } |
52 | | /// None => BooleanArray::new_null(array.len()), |
53 | | /// } |
54 | | /// } |
55 | | /// |
56 | | /// pub fn eq(l: &dyn Datum, r: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
57 | | /// let (l_array, l_scalar) = l.get(); |
58 | | /// let (r_array, r_scalar) = r.get(); |
59 | | /// downcast_primitive_array!( |
60 | | /// (l_array, r_array) => Ok(eq_impl(l_array, l_scalar, r_array, r_scalar)), |
61 | | /// (a, b) => Err(ArrowError::NotYetImplemented(format!("{a} == {b}"))), |
62 | | /// ) |
63 | | /// } |
64 | | /// |
65 | | /// // Comparison of two arrays |
66 | | /// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); |
67 | | /// let b = Int32Array::from(vec![1, 2, 4, 7, 3]); |
68 | | /// let r = eq(&a, &b).unwrap(); |
69 | | /// let values: Vec<_> = r.values().iter().collect(); |
70 | | /// assert_eq!(values, &[true, true, false, false, false]); |
71 | | /// |
72 | | /// // Comparison of an array and a scalar |
73 | | /// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); |
74 | | /// let b = Int32Array::new_scalar(1); |
75 | | /// let r = eq(&a, &b).unwrap(); |
76 | | /// let values: Vec<_> = r.values().iter().collect(); |
77 | | /// assert_eq!(values, &[true, false, false, false, false]); |
78 | | pub trait Datum { |
79 | | /// Returns the value for this [`Datum`] and a boolean indicating if the value is scalar |
80 | | fn get(&self) -> (&dyn Array, bool); |
81 | | } |
82 | | |
83 | | impl<T: Array> Datum for T { |
84 | 0 | fn get(&self) -> (&dyn Array, bool) { |
85 | 0 | (self, false) |
86 | 0 | } |
87 | | } |
88 | | |
89 | | impl Datum for dyn Array { |
90 | 0 | fn get(&self) -> (&dyn Array, bool) { |
91 | 0 | (self, false) |
92 | 0 | } |
93 | | } |
94 | | |
95 | | impl Datum for &dyn Array { |
96 | 0 | fn get(&self) -> (&dyn Array, bool) { |
97 | 0 | (*self, false) |
98 | 0 | } |
99 | | } |
100 | | |
101 | | /// A wrapper around a single value [`Array`] that implements |
102 | | /// [`Datum`] and indicates [compute] kernels should treat this array |
103 | | /// as a scalar value (a single value). |
104 | | /// |
105 | | /// Using a [`Scalar`] is often much more efficient than creating an |
106 | | /// [`Array`] with the same (repeated) value. |
107 | | /// |
108 | | /// See [`Datum`] for more information. |
109 | | /// |
110 | | /// # Example |
111 | | /// |
112 | | /// ```rust |
113 | | /// # use arrow_array::{Scalar, Int32Array, ArrayRef}; |
114 | | /// # fn get_array() -> ArrayRef { std::sync::Arc::new(Int32Array::from(vec![42])) } |
115 | | /// // Create a (typed) scalar for Int32Array for the value 42 |
116 | | /// let scalar = Scalar::new(Int32Array::from(vec![42])); |
117 | | /// |
118 | | /// // Create a scalar using PrimitiveArray::new_scalar |
119 | | /// let scalar = Int32Array::new_scalar(42); |
120 | | /// |
121 | | /// // Create a scalar from an ArrayRef (for dynamically typed arrays) |
122 | | /// let array: ArrayRef = get_array(); |
123 | | /// let scalar = Scalar::new(array); |
124 | | /// ``` |
125 | | /// |
126 | | /// [compute]: https://docs.rs/arrow/latest/arrow/compute/index.html |
127 | | #[derive(Debug, Copy, Clone)] |
128 | | pub struct Scalar<T: Array>(T); |
129 | | |
130 | | impl<T: Array> Scalar<T> { |
131 | | /// Create a new [`Scalar`] from an [`Array`] |
132 | | /// |
133 | | /// # Panics |
134 | | /// |
135 | | /// Panics if `array.len() != 1` |
136 | 0 | pub fn new(array: T) -> Self { |
137 | 0 | assert_eq!(array.len(), 1); |
138 | 0 | Self(array) |
139 | 0 | } |
140 | | |
141 | | /// Returns the inner array |
142 | | #[inline] |
143 | | pub fn into_inner(self) -> T { |
144 | | self.0 |
145 | | } |
146 | | } |
147 | | |
148 | | impl<T: Array> Datum for Scalar<T> { |
149 | 0 | fn get(&self) -> (&dyn Array, bool) { |
150 | 0 | (&self.0, true) |
151 | 0 | } |
152 | | } |