/Users/andrewlamb/Software/arrow-rs/arrow-select/src/zip.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`zip`]: Combine values from two arrays based on boolean mask |
19 | | |
20 | | use crate::filter::SlicesIterator; |
21 | | use arrow_array::*; |
22 | | use arrow_data::transform::MutableArrayData; |
23 | | use arrow_schema::ArrowError; |
24 | | |
25 | | /// Zip two arrays by some boolean mask. |
26 | | /// |
27 | | /// - Where `mask` is `true`, values of `truthy` are taken |
28 | | /// - Where `mask` is `false` or `NULL`, values of `falsy` are taken |
29 | | /// |
30 | | /// # Example: `zip` two arrays |
31 | | /// ``` |
32 | | /// # use std::sync::Arc; |
33 | | /// # use arrow_array::{ArrayRef, BooleanArray, Int32Array}; |
34 | | /// # use arrow_select::zip::zip; |
35 | | /// // mask: [true, true, false, NULL, true] |
36 | | /// let mask = BooleanArray::from(vec![ |
37 | | /// Some(true), Some(true), Some(false), None, Some(true) |
38 | | /// ]); |
39 | | /// // truthy array: [1, NULL, 3, 4, 5] |
40 | | /// let truthy = Int32Array::from(vec![ |
41 | | /// Some(1), None, Some(3), Some(4), Some(5) |
42 | | /// ]); |
43 | | /// // falsy array: [10, 20, 30, 40, 50] |
44 | | /// let falsy = Int32Array::from(vec![ |
45 | | /// Some(10), Some(20), Some(30), Some(40), Some(50) |
46 | | /// ]); |
47 | | /// // zip with this mask select the first, second and last value from `truthy` |
48 | | /// // and the third and fourth value from `falsy` |
49 | | /// let result = zip(&mask, &truthy, &falsy).unwrap(); |
50 | | /// // Expected: [1, NULL, 30, 40, 5] |
51 | | /// let expected: ArrayRef = Arc::new(Int32Array::from(vec![ |
52 | | /// Some(1), None, Some(30), Some(40), Some(5) |
53 | | /// ])); |
54 | | /// assert_eq!(&result, &expected); |
55 | | /// ``` |
56 | | /// |
57 | | /// # Example: `zip` and array with a scalar |
58 | | /// |
59 | | /// Use `zip` to replace certain values in an array with a scalar |
60 | | /// |
61 | | /// ``` |
62 | | /// # use std::sync::Arc; |
63 | | /// # use arrow_array::{ArrayRef, BooleanArray, Int32Array}; |
64 | | /// # use arrow_select::zip::zip; |
65 | | /// // mask: [true, true, false, NULL, true] |
66 | | /// let mask = BooleanArray::from(vec![ |
67 | | /// Some(true), Some(true), Some(false), None, Some(true) |
68 | | /// ]); |
69 | | /// // array: [1, NULL, 3, 4, 5] |
70 | | /// let arr = Int32Array::from(vec![ |
71 | | /// Some(1), None, Some(3), Some(4), Some(5) |
72 | | /// ]); |
73 | | /// // scalar: 42 |
74 | | /// let scalar = Int32Array::new_scalar(42); |
75 | | /// // zip the array with the mask select the first, second and last value from `arr` |
76 | | /// // and fill the third and fourth value with the scalar 42 |
77 | | /// let result = zip(&mask, &arr, &scalar).unwrap(); |
78 | | /// // Expected: [1, NULL, 42, 42, 5] |
79 | | /// let expected: ArrayRef = Arc::new(Int32Array::from(vec![ |
80 | | /// Some(1), None, Some(42), Some(42), Some(5) |
81 | | /// ])); |
82 | | /// assert_eq!(&result, &expected); |
83 | | /// ``` |
84 | | pub fn zip( |
85 | | mask: &BooleanArray, |
86 | | truthy: &dyn Datum, |
87 | | falsy: &dyn Datum, |
88 | | ) -> Result<ArrayRef, ArrowError> { |
89 | | let (truthy, truthy_is_scalar) = truthy.get(); |
90 | | let (falsy, falsy_is_scalar) = falsy.get(); |
91 | | |
92 | | if truthy.data_type() != falsy.data_type() { |
93 | | return Err(ArrowError::InvalidArgumentError( |
94 | | "arguments need to have the same data type".into(), |
95 | | )); |
96 | | } |
97 | | |
98 | | if truthy_is_scalar && truthy.len() != 1 { |
99 | | return Err(ArrowError::InvalidArgumentError( |
100 | | "scalar arrays must have 1 element".into(), |
101 | | )); |
102 | | } |
103 | | if !truthy_is_scalar && truthy.len() != mask.len() { |
104 | | return Err(ArrowError::InvalidArgumentError( |
105 | | "all arrays should have the same length".into(), |
106 | | )); |
107 | | } |
108 | | if falsy_is_scalar && falsy.len() != 1 { |
109 | | return Err(ArrowError::InvalidArgumentError( |
110 | | "scalar arrays must have 1 element".into(), |
111 | | )); |
112 | | } |
113 | | if !falsy_is_scalar && falsy.len() != mask.len() { |
114 | | return Err(ArrowError::InvalidArgumentError( |
115 | | "all arrays should have the same length".into(), |
116 | | )); |
117 | | } |
118 | | |
119 | | let falsy = falsy.to_data(); |
120 | | let truthy = truthy.to_data(); |
121 | | |
122 | | let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len()); |
123 | | |
124 | | // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to |
125 | | // fill with falsy values |
126 | | |
127 | | // keep track of how much is filled |
128 | | let mut filled = 0; |
129 | | |
130 | 0 | SlicesIterator::new(mask).for_each(|(start, end)| { |
131 | | // the gap needs to be filled with falsy values |
132 | 0 | if start > filled { |
133 | 0 | if falsy_is_scalar { |
134 | 0 | for _ in filled..start { |
135 | 0 | // Copy the first item from the 'falsy' array into the output buffer. |
136 | 0 | mutable.extend(1, 0, 1); |
137 | 0 | } |
138 | 0 | } else { |
139 | 0 | mutable.extend(1, filled, start); |
140 | 0 | } |
141 | 0 | } |
142 | | // fill with truthy values |
143 | 0 | if truthy_is_scalar { |
144 | 0 | for _ in start..end { |
145 | 0 | // Copy the first item from the 'truthy' array into the output buffer. |
146 | 0 | mutable.extend(0, 0, 1); |
147 | 0 | } |
148 | 0 | } else { |
149 | 0 | mutable.extend(0, start, end); |
150 | 0 | } |
151 | 0 | filled = end; |
152 | 0 | }); |
153 | | // the remaining part is falsy |
154 | | if filled < mask.len() { |
155 | | if falsy_is_scalar { |
156 | | for _ in filled..mask.len() { |
157 | | // Copy the first item from the 'falsy' array into the output buffer. |
158 | | mutable.extend(1, 0, 1); |
159 | | } |
160 | | } else { |
161 | | mutable.extend(1, filled, mask.len()); |
162 | | } |
163 | | } |
164 | | |
165 | | let data = mutable.freeze(); |
166 | | Ok(make_array(data)) |
167 | | } |
168 | | |
#[cfg(test)]
mod test {
    use super::*;

    /// Array input shared by the tests: `[5, NULL, 7, NULL, 1]`.
    fn sample_array() -> Int32Array {
        Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)])
    }

    /// Scalar input shared by the tests: a single `42`.
    fn scalar_42() -> Scalar<Int32Array> {
        Scalar::new(Int32Array::from_value(42, 1))
    }

    /// Runs `zip` with the given mask and inputs and compares the result,
    /// downcast to an `Int32Array`, against `expected`.
    fn check_zip(
        mask: Vec<bool>,
        truthy: &dyn Datum,
        falsy: &dyn Datum,
        expected: Vec<Option<i32>>,
    ) {
        let mask = BooleanArray::from(mask);
        let out = zip(&mask, truthy, falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(expected);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_one() {
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        check_zip(
            vec![true, true, false, false, true],
            &sample_array(),
            &b,
            vec![Some(5), None, Some(6), Some(7), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_two() {
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        check_zip(
            vec![false, false, true, true, false],
            &sample_array(),
            &b,
            vec![None, Some(3), Some(7), None, Some(3)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_1() {
        check_zip(
            vec![true, true, false, false, true],
            &sample_array(),
            &scalar_42(),
            vec![Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_2() {
        check_zip(
            vec![false, false, true, true, false],
            &sample_array(),
            &scalar_42(),
            vec![Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_1() {
        check_zip(
            vec![true, true, false, false, true],
            &scalar_42(),
            &sample_array(),
            vec![Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_2() {
        check_zip(
            vec![false, false, true, true, false],
            &scalar_42(),
            &sample_array(),
            vec![Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_both() {
        let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
        check_zip(
            vec![true, true, false, false, true],
            &scalar_42(),
            &scalar_falsy,
            vec![Some(42), Some(42), Some(123), Some(123), Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_1() {
        let scalar_falsy = Scalar::new(Int32Array::new_null(1));
        check_zip(
            vec![true, true, false, false, true],
            &scalar_42(),
            &scalar_falsy,
            vec![Some(42), Some(42), None, None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_2() {
        let scalar_falsy = Scalar::new(Int32Array::new_null(1));
        check_zip(
            vec![false, false, true, true, false],
            &scalar_42(),
            &scalar_falsy,
            vec![None, None, Some(42), Some(42), None],
        );
    }
}