/Users/andrewlamb/Software/arrow-rs/arrow-select/src/take.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines take kernel for [Array] |
19 | | |
20 | | use std::sync::Arc; |
21 | | |
22 | | use arrow_array::builder::{BufferBuilder, UInt32Builder}; |
23 | | use arrow_array::cast::AsArray; |
24 | | use arrow_array::types::*; |
25 | | use arrow_array::*; |
26 | | use arrow_buffer::{ |
27 | | bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, |
28 | | ScalarBuffer, |
29 | | }; |
30 | | use arrow_data::ArrayDataBuilder; |
31 | | use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode}; |
32 | | |
33 | | use num::{One, Zero}; |
34 | | |
35 | | /// Take elements by index from [Array], creating a new [Array] from those indexes. |
36 | | /// |
37 | | /// ```text |
38 | | /// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ |
39 | | /// │ A │ │ 0 │ │ A │ |
40 | | /// ├─────────────────┤ ├─────────┤ ├─────────────────┤ |
41 | | /// │ D │ │ 2 │ │ B │ |
42 | | /// ├─────────────────┤ ├─────────┤ take(values, indices) ├─────────────────┤ |
43 | | /// │ B │ │ 3 │ ─────────────────────────▶ │ C │ |
44 | | /// ├─────────────────┤ ├─────────┤ ├─────────────────┤ |
45 | | /// │ C │ │ 1 │ │ D │ |
46 | | /// ├─────────────────┤ └─────────┘ └─────────────────┘ |
47 | | /// │ E │ |
48 | | /// └─────────────────┘ |
49 | | /// values array indices array result |
50 | | /// ``` |
51 | | /// |
52 | | /// For selecting values by index from multiple arrays see [`crate::interleave`] |
53 | | /// |
54 | | /// Note that this kernel, similar to other kernels in this crate, |
55 | | /// will avoid allocating where not necessary. Consequently |
56 | | /// the returned array may share buffers with the inputs |
57 | | /// |
58 | | /// # Errors |
59 | | /// This function errors whenever: |
60 | | /// * An index cannot be casted to `usize` (typically 32 bit architectures) |
61 | | /// * An index is out of bounds and `options` is set to check bounds. |
62 | | /// |
63 | | /// # Safety |
64 | | /// |
65 | | /// When `options` is not set to check bounds, taking indexes after `len` will panic. |
66 | | /// |
67 | | /// # See also |
68 | | /// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce |
69 | | /// the results into a single array. |
70 | | /// |
71 | | /// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer |
72 | | /// |
73 | | /// # Examples |
74 | | /// ``` |
75 | | /// # use arrow_array::{StringArray, UInt32Array, cast::AsArray}; |
76 | | /// # use arrow_select::take::take; |
77 | | /// let values = StringArray::from(vec!["zero", "one", "two"]); |
78 | | /// |
79 | | /// // Take items at index 2, and 1: |
80 | | /// let indices = UInt32Array::from(vec![2, 1]); |
81 | | /// let taken = take(&values, &indices, None).unwrap(); |
82 | | /// let taken = taken.as_string::<i32>(); |
83 | | /// |
84 | | /// assert_eq!(*taken, StringArray::from(vec!["two", "one"])); |
85 | | /// ``` |
86 | 0 | pub fn take( |
87 | 0 | values: &dyn Array, |
88 | 0 | indices: &dyn Array, |
89 | 0 | options: Option<TakeOptions>, |
90 | 0 | ) -> Result<ArrayRef, ArrowError> { |
91 | 0 | let options = options.unwrap_or_default(); |
92 | 0 | downcast_integer_array!( |
93 | | indices => { |
94 | 0 | if options.check_bounds { |
95 | 0 | check_bounds(values.len(), indices)?; |
96 | 0 | } |
97 | 0 | let indices = indices.to_indices(); |
98 | 0 | take_impl(values, &indices) |
99 | | }, |
100 | 0 | d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}"))) |
101 | | ) |
102 | 0 | } |
103 | | |
104 | | /// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new |
105 | | /// [`Vec<ArrayRef>`] from those indices. |
106 | | /// |
107 | | /// ```text |
108 | | /// ┌────────┬────────┐ |
109 | | /// │ │ │ ┌────────┐ ┌────────┬────────┐ |
110 | | /// │ A │ 1 │ │ │ │ │ │ |
111 | | /// ├────────┼────────┤ │ 0 │ │ A │ 1 │ |
112 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
113 | | /// │ D │ 4 │ │ │ │ │ │ |
114 | | /// ├────────┼────────┤ │ 2 │ take_arrays(values,indices) │ B │ 2 │ |
115 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
116 | | /// │ B │ 2 │ │ │ ───────────────────────────► │ │ │ |
117 | | /// ├────────┼────────┤ │ 3 │ │ C │ 3 │ |
118 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
119 | | /// │ C │ 3 │ │ │ │ │ │ |
120 | | /// ├────────┼────────┤ │ 1 │ │ D │ 4 │ |
121 | | /// │ │ │ └────────┘ └────────┼────────┘ |
122 | | /// │ E │ 5 │ |
123 | | /// └────────┴────────┘ |
124 | | /// values arrays indices array result |
125 | | /// ``` |
126 | | /// |
127 | | /// # Errors |
128 | | /// This function errors whenever: |
129 | | /// * An index cannot be casted to `usize` (typically 32 bit architectures) |
130 | | /// * An index is out of bounds and `options` is set to check bounds. |
131 | | /// |
132 | | /// # Safety |
133 | | /// |
134 | | /// When `options` is not set to check bounds, taking indexes after `len` will panic. |
135 | | /// |
136 | | /// # Examples |
137 | | /// ``` |
138 | | /// # use std::sync::Arc; |
139 | | /// # use arrow_array::{StringArray, UInt32Array, cast::AsArray}; |
140 | | /// # use arrow_select::take::{take, take_arrays}; |
141 | | /// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"])); |
142 | | /// let values = Arc::new(UInt32Array::from(vec![0, 1, 2])); |
143 | | /// |
144 | | /// // Take items at index 2, and 1: |
145 | | /// let indices = UInt32Array::from(vec![2, 1]); |
146 | | /// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap(); |
147 | | /// let taken_string = taken_arrays[0].as_string::<i32>(); |
148 | | /// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"])); |
149 | | /// let taken_values = taken_arrays[1].as_primitive(); |
150 | | /// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1])); |
151 | | /// ``` |
152 | 0 | pub fn take_arrays( |
153 | 0 | arrays: &[ArrayRef], |
154 | 0 | indices: &dyn Array, |
155 | 0 | options: Option<TakeOptions>, |
156 | 0 | ) -> Result<Vec<ArrayRef>, ArrowError> { |
157 | 0 | arrays |
158 | 0 | .iter() |
159 | 0 | .map(|array| take(array.as_ref(), indices, options.clone())) |
160 | 0 | .collect() |
161 | 0 | } |
162 | | |
163 | | /// Verifies that the non-null values of `indices` are all `< len` |
164 | 0 | fn check_bounds<T: ArrowPrimitiveType>( |
165 | 0 | len: usize, |
166 | 0 | indices: &PrimitiveArray<T>, |
167 | 0 | ) -> Result<(), ArrowError> { |
168 | 0 | if indices.null_count() > 0 { |
169 | 0 | indices.iter().flatten().try_for_each(|index| { |
170 | 0 | let ix = index |
171 | 0 | .to_usize() |
172 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
173 | 0 | if ix >= len { |
174 | 0 | return Err(ArrowError::ComputeError(format!( |
175 | 0 | "Array index out of bounds, cannot get item at index {ix} from {len} entries" |
176 | 0 | ))); |
177 | 0 | } |
178 | 0 | Ok(()) |
179 | 0 | }) |
180 | | } else { |
181 | 0 | indices.values().iter().try_for_each(|index| { |
182 | 0 | let ix = index |
183 | 0 | .to_usize() |
184 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
185 | 0 | if ix >= len { |
186 | 0 | return Err(ArrowError::ComputeError(format!( |
187 | 0 | "Array index out of bounds, cannot get item at index {ix} from {len} entries" |
188 | 0 | ))); |
189 | 0 | } |
190 | 0 | Ok(()) |
191 | 0 | }) |
192 | | } |
193 | 0 | } |
194 | | |
// Dispatches `take` to the kernel that matches the data type of `values`.
//
// `indices` is a primitive index array; its entries are converted with
// `as_usize` / `to_usize` by the individual kernels. Nested types (struct,
// union, list, map) recurse back into this function for their children.
//
// Panics via `unimplemented!` for data types without a matching arm.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        // All primitive types share a single generic kernel.
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // A map is taken by viewing it as a list array, running the list
            // kernel, and then re-labelling the result with the original map
            // data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // NOTE(review): validation is skipped here; this presumably relies on
            // the list layout produced above being a valid map layout - confirm.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays  = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // Create the null bit buffer.
            // An output slot is valid only when its index is non-null and the
            // struct itself is valid at that index.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            // A struct without fields is built through a dedicated
            // constructor that takes the length explicitly.
            if fields.is_empty() {
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
            } else {
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
            }
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Take applied to a null array produces a null array.
            if values.len() >= indices.len() {
                // If the existing null array is as big as the indices, we can use a slice of it
                // to avoid allocating a new null array.
                Ok(values.slice(0, indices.len()))
            } else {
                // If the existing null array isn't big enough, create a new one.
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: the type ids and every child are taken with the
            // same indices; no offsets buffer is passed to the result.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and child offset of every selected slot.
            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);

            // For each field, keep only the offsets of slots holding that
            // field's type id and take those entries from the child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // One running counter per type id.
            // NOTE(review): indexing with `i as usize` assumes type ids are
            // non-negative (0..=127); a negative id would index out of range.
            let mut child_offsets = [0; 128];

            // Each output slot gets the next position within its (freshly
            // taken) child, assigned in output order.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
354 | | |
/// Options that define how `take` should behave.
///
/// The derived [`Default`] leaves `check_bounds` set to `false`.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// Perform bounds check before taking indices from values.
    /// If enabled, an `ArrowError` is returned if the indices are out of bounds.
    /// If not enabled, and indices exceed bounds, the kernel will panic.
    pub check_bounds: bool,
}
363 | | |
364 | | #[inline(always)] |
365 | 0 | fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> { |
366 | 0 | index |
367 | 0 | .to_usize() |
368 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string())) |
369 | 0 | } |
370 | | |
371 | | /// `take` implementation for all primitive arrays |
372 | | /// |
373 | | /// This checks if an `indices` slot is populated, and gets the value from `values` |
374 | | /// as the populated index. |
375 | | /// If the `indices` slot is null, a null value is returned. |
376 | | /// For example, given: |
377 | | /// values: [1, 2, 3, null, 5] |
378 | | /// indices: [0, null, 4, 3] |
379 | | /// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)] |
380 | 0 | fn take_primitive<T, I>( |
381 | 0 | values: &PrimitiveArray<T>, |
382 | 0 | indices: &PrimitiveArray<I>, |
383 | 0 | ) -> Result<PrimitiveArray<T>, ArrowError> |
384 | 0 | where |
385 | 0 | T: ArrowPrimitiveType, |
386 | 0 | I: ArrowPrimitiveType, |
387 | | { |
388 | 0 | let values_buf = take_native(values.values(), indices); |
389 | 0 | let nulls = take_nulls(values.nulls(), indices); |
390 | 0 | Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone())) |
391 | 0 | } |
392 | | |
393 | | #[inline(never)] |
394 | 0 | fn take_nulls<I: ArrowPrimitiveType>( |
395 | 0 | values: Option<&NullBuffer>, |
396 | 0 | indices: &PrimitiveArray<I>, |
397 | 0 | ) -> Option<NullBuffer> { |
398 | 0 | match values.filter(|n| n.null_count() > 0) { |
399 | 0 | Some(n) => { |
400 | 0 | let buffer = take_bits(n.inner(), indices); |
401 | 0 | Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0) |
402 | | } |
403 | 0 | None => indices.nulls().cloned(), |
404 | | } |
405 | 0 | } |
406 | | |
407 | | #[inline(never)] |
408 | 0 | fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>( |
409 | 0 | values: &[T], |
410 | 0 | indices: &PrimitiveArray<I>, |
411 | 0 | ) -> ScalarBuffer<T> { |
412 | 0 | match indices.nulls().filter(|n| n.null_count() > 0) { |
413 | 0 | Some(n) => indices |
414 | 0 | .values() |
415 | 0 | .iter() |
416 | 0 | .enumerate() |
417 | 0 | .map(|(idx, index)| match values.get(index.as_usize()) { |
418 | 0 | Some(v) => *v, |
419 | 0 | None => match n.is_null(idx) { |
420 | 0 | true => T::default(), |
421 | 0 | false => panic!("Out-of-bounds index {index:?}"), |
422 | | }, |
423 | 0 | }) |
424 | 0 | .collect(), |
425 | 0 | None => indices |
426 | 0 | .values() |
427 | 0 | .iter() |
428 | 0 | .map(|index| values[index.as_usize()]) |
429 | 0 | .collect(), |
430 | | } |
431 | 0 | } |
432 | | |
433 | | #[inline(never)] |
434 | 0 | fn take_bits<I: ArrowPrimitiveType>( |
435 | 0 | values: &BooleanBuffer, |
436 | 0 | indices: &PrimitiveArray<I>, |
437 | 0 | ) -> BooleanBuffer { |
438 | 0 | let len = indices.len(); |
439 | | |
440 | 0 | match indices.nulls().filter(|n| n.null_count() > 0) { |
441 | 0 | Some(nulls) => { |
442 | 0 | let mut output_buffer = MutableBuffer::new_null(len); |
443 | 0 | let output_slice = output_buffer.as_slice_mut(); |
444 | 0 | nulls.valid_indices().for_each(|idx| { |
445 | 0 | if values.value(indices.value(idx).as_usize()) { |
446 | 0 | bit_util::set_bit(output_slice, idx); |
447 | 0 | } |
448 | 0 | }); |
449 | 0 | BooleanBuffer::new(output_buffer.into(), 0, len) |
450 | | } |
451 | | None => { |
452 | 0 | BooleanBuffer::collect_bool(len, |idx: usize| { |
453 | | // SAFETY: idx<indices.len() |
454 | 0 | values.value(unsafe { indices.value_unchecked(idx).as_usize() }) |
455 | 0 | }) |
456 | | } |
457 | | } |
458 | 0 | } |
459 | | |
460 | | /// `take` implementation for boolean arrays |
461 | 0 | fn take_boolean<IndexType: ArrowPrimitiveType>( |
462 | 0 | values: &BooleanArray, |
463 | 0 | indices: &PrimitiveArray<IndexType>, |
464 | 0 | ) -> BooleanArray { |
465 | 0 | let val_buf = take_bits(values.values(), indices); |
466 | 0 | let null_buf = take_nulls(values.nulls(), indices); |
467 | 0 | BooleanArray::new(val_buf, null_buf) |
468 | 0 | } |
469 | | |
/// `take` implementation for variable-length byte arrays (string and binary)
///
/// Builds a fresh offsets buffer and a fresh contiguous values buffer. Each of
/// the four branches below makes two passes over `indices`: the first computes
/// the output offsets (and therefore the total byte capacity), the second
/// copies the selected bytes. The branches differ only in which validity
/// information has to be consulted.
///
/// # Errors
/// Returns [`ArrowError::OffsetOverflowError`] when the accumulated byte
/// length does not fit in `T::Offset`.
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    // Offsets always start with a leading zero entry.
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    // Running total of output bytes; doubles as the next offset.
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 {
        // Neither input has nulls: every index selects a real value.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
        }
        (offsets, values)
    } else if indices.null_count() == 0 {
        // Only the values have nulls: null values contribute zero bytes.
        offsets.reserve(indices.len());
        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for index in indices.values() {
            let index = index.as_usize();
            if array.is_valid(index) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    } else if array.null_count() == 0 {
        // Only the indices have nulls: null index slots contribute zero bytes
        // and their raw value is never used to address `array`.
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if indices.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            if indices.is_valid(i) {
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
            }
        }
        (offsets, values)
    } else {
        // Both inputs have nulls: the combined output validity computed above
        // decides which slots carry bytes.
        // NOTE(review): the `unwrap` assumes `take_nulls` returned `Some` here,
        // i.e. that the result actually contains a null - confirm.
        let nulls = nulls.as_ref().unwrap();
        offsets.reserve(indices.len());
        for (i, index) in indices.values().iter().enumerate() {
            let index = index.as_usize();
            if nulls.is_valid(i) {
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
            }
            offsets.push(
                T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
            );
        }
        let mut values = Vec::with_capacity(capacity);

        for (i, index) in indices.values().iter().enumerate() {
            // check index is valid before using index. The value in
            // NULL index slots may not be within bounds of array
            let index = index.as_usize();
            if nulls.is_valid(i) {
                values.extend_from_slice(array.value(index).as_ref());
            }
        }
        (offsets, values)
    };

    // Final check that the total value length is representable as an offset.
    T::Offset::from_usize(values.len())
        .ok_or_else(|| ArrowError::OffsetOverflowError(values.len()))?;

    // Validation is skipped: the offsets were accumulated from a non-negative
    // running total above, and the bytes are copied from `array`.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
575 | | |
576 | | /// `take` implementation for byte view arrays |
577 | 0 | fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>( |
578 | 0 | array: &GenericByteViewArray<T>, |
579 | 0 | indices: &PrimitiveArray<IndexType>, |
580 | 0 | ) -> Result<GenericByteViewArray<T>, ArrowError> { |
581 | 0 | let new_views = take_native(array.views(), indices); |
582 | 0 | let new_nulls = take_nulls(array.nulls(), indices); |
583 | | // Safety: array.views was valid, and take_native copies only valid values, and verifies bounds |
584 | 0 | Ok(unsafe { |
585 | 0 | GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls) |
586 | 0 | }) |
587 | 0 | } |
588 | | |
589 | | /// `take` implementation for list arrays |
590 | | /// |
591 | | /// Calculates the index and indexed offset for the inner array, |
592 | | /// applying `take` on the inner array, then reconstructing a list array |
593 | | /// with the indexed offsets |
594 | 0 | fn take_list<IndexType, OffsetType>( |
595 | 0 | values: &GenericListArray<OffsetType::Native>, |
596 | 0 | indices: &PrimitiveArray<IndexType>, |
597 | 0 | ) -> Result<GenericListArray<OffsetType::Native>, ArrowError> |
598 | 0 | where |
599 | 0 | IndexType: ArrowPrimitiveType, |
600 | 0 | OffsetType: ArrowPrimitiveType, |
601 | 0 | OffsetType::Native: OffsetSizeTrait, |
602 | 0 | PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>, |
603 | | { |
604 | | // TODO: Some optimizations can be done here such as if it is |
605 | | // taking the whole list or a contiguous sublist |
606 | 0 | let (list_indices, offsets, null_buf) = |
607 | 0 | take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?; |
608 | | |
609 | 0 | let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?; |
610 | 0 | let value_offsets = Buffer::from_vec(offsets); |
611 | | // create a new list with taken data and computed null information |
612 | 0 | let list_data = ArrayDataBuilder::new(values.data_type().clone()) |
613 | 0 | .len(indices.len()) |
614 | 0 | .null_bit_buffer(Some(null_buf.into())) |
615 | 0 | .offset(0) |
616 | 0 | .add_child_data(taken.into_data()) |
617 | 0 | .add_buffer(value_offsets); |
618 | | |
619 | 0 | let list_data = unsafe { list_data.build_unchecked() }; |
620 | | |
621 | 0 | Ok(GenericListArray::<OffsetType::Native>::from(list_data)) |
622 | 0 | } |
623 | | |
624 | | /// `take` implementation for `FixedSizeListArray` |
625 | | /// |
626 | | /// Calculates the index and indexed offset for the inner array, |
627 | | /// applying `take` on the inner array, then reconstructing a list array |
628 | | /// with the indexed offsets |
629 | 0 | fn take_fixed_size_list<IndexType: ArrowPrimitiveType>( |
630 | 0 | values: &FixedSizeListArray, |
631 | 0 | indices: &PrimitiveArray<IndexType>, |
632 | 0 | length: <UInt32Type as ArrowPrimitiveType>::Native, |
633 | 0 | ) -> Result<FixedSizeListArray, ArrowError> { |
634 | 0 | let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?; |
635 | 0 | let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?; |
636 | | |
637 | | // determine null count and null buffer, which are a function of `values` and `indices` |
638 | 0 | let num_bytes = bit_util::ceil(indices.len(), 8); |
639 | 0 | let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); |
640 | 0 | let null_slice = null_buf.as_slice_mut(); |
641 | | |
642 | 0 | for i in 0..indices.len() { |
643 | 0 | let index = indices |
644 | 0 | .value(i) |
645 | 0 | .to_usize() |
646 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
647 | 0 | if !indices.is_valid(i) || values.is_null(index) { |
648 | 0 | bit_util::unset_bit(null_slice, i); |
649 | 0 | } |
650 | | } |
651 | | |
652 | 0 | let list_data = ArrayDataBuilder::new(values.data_type().clone()) |
653 | 0 | .len(indices.len()) |
654 | 0 | .null_bit_buffer(Some(null_buf.into())) |
655 | 0 | .offset(0) |
656 | 0 | .add_child_data(taken.into_data()); |
657 | | |
658 | 0 | let list_data = unsafe { list_data.build_unchecked() }; |
659 | | |
660 | 0 | Ok(FixedSizeListArray::from(list_data)) |
661 | 0 | } |
662 | | |
663 | 0 | fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>( |
664 | 0 | values: &FixedSizeBinaryArray, |
665 | 0 | indices: &PrimitiveArray<IndexType>, |
666 | 0 | size: i32, |
667 | 0 | ) -> Result<FixedSizeBinaryArray, ArrowError> { |
668 | 0 | let nulls = values.nulls(); |
669 | 0 | let array_iter = indices |
670 | 0 | .values() |
671 | 0 | .iter() |
672 | 0 | .map(|idx| { |
673 | 0 | let idx = maybe_usize::<IndexType::Native>(*idx)?; |
674 | 0 | if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) { |
675 | 0 | Ok(Some(values.value(idx))) |
676 | | } else { |
677 | 0 | Ok(None) |
678 | | } |
679 | 0 | }) |
680 | 0 | .collect::<Result<Vec<_>, ArrowError>>()? |
681 | 0 | .into_iter(); |
682 | | |
683 | 0 | FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size) |
684 | 0 | } |
685 | | |
686 | | /// `take` implementation for dictionary arrays |
687 | | /// |
688 | | /// applies `take` to the keys of the dictionary array and returns a new dictionary array |
689 | | /// with the same dictionary values and reordered keys |
690 | 0 | fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>( |
691 | 0 | values: &DictionaryArray<T>, |
692 | 0 | indices: &PrimitiveArray<I>, |
693 | 0 | ) -> Result<DictionaryArray<T>, ArrowError> { |
694 | 0 | let new_keys = take_primitive(values.keys(), indices)?; |
695 | 0 | Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) }) |
696 | 0 | } |
697 | | |
698 | | /// `take` implementation for run arrays |
699 | | /// |
700 | | /// Finds physical indices for the given logical indices and builds output run array |
701 | | /// by taking values in the input run_array.values at the physical indices. |
702 | | /// The output run array will be run encoded on the physical indices and not on output values. |
703 | | /// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]` |
704 | | /// would be converted to `physical_indices=[1,1,3,3]` which will be used to build |
705 | | /// output `RunArray{ run_ends=[2,4], values=[2,2] }`. |
706 | 0 | fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>( |
707 | 0 | run_array: &RunArray<T>, |
708 | 0 | logical_indices: &PrimitiveArray<I>, |
709 | 0 | ) -> Result<RunArray<T>, ArrowError> { |
710 | | // get physical indices for the input logical indices |
711 | 0 | let physical_indices = run_array.get_physical_indices(logical_indices.values())?; |
712 | | |
713 | | // Run encode the physical indices into new_run_ends_builder |
714 | | // Keep track of the physical indices to take in take_value_indices |
715 | | // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`. |
716 | 0 | let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1); |
717 | 0 | let mut take_value_indices = BufferBuilder::<I::Native>::new(1); |
718 | 0 | let mut new_physical_len = 1; |
719 | 0 | for ix in 1..physical_indices.len() { |
720 | 0 | if physical_indices[ix] != physical_indices[ix - 1] { |
721 | 0 | take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap()); |
722 | 0 | new_run_ends_builder.append(T::Native::from_usize(ix).unwrap()); |
723 | 0 | new_physical_len += 1; |
724 | 0 | } |
725 | | } |
726 | 0 | take_value_indices |
727 | 0 | .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap()); |
728 | 0 | new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap()); |
729 | 0 | let new_run_ends = unsafe { |
730 | | // Safety: |
731 | | // The function builds a valid run_ends array and hence need not be validated. |
732 | 0 | ArrayDataBuilder::new(T::DATA_TYPE) |
733 | 0 | .len(new_physical_len) |
734 | 0 | .null_count(0) |
735 | 0 | .add_buffer(new_run_ends_builder.finish()) |
736 | 0 | .build_unchecked() |
737 | | }; |
738 | | |
739 | 0 | let take_value_indices: PrimitiveArray<I> = unsafe { |
740 | | // Safety: |
741 | | // The function builds a valid take_value_indices array and hence need not be validated. |
742 | 0 | ArrayDataBuilder::new(I::DATA_TYPE) |
743 | 0 | .len(new_physical_len) |
744 | 0 | .null_count(0) |
745 | 0 | .add_buffer(take_value_indices.finish()) |
746 | 0 | .build_unchecked() |
747 | 0 | .into() |
748 | | }; |
749 | | |
750 | 0 | let new_values = take(run_array.values(), &take_value_indices, None)?; |
751 | | |
752 | 0 | let builder = ArrayDataBuilder::new(run_array.data_type().clone()) |
753 | 0 | .len(physical_indices.len()) |
754 | 0 | .add_child_data(new_run_ends) |
755 | 0 | .add_child_data(new_values.into_data()); |
756 | 0 | let array_data = unsafe { |
757 | | // Safety: |
758 | | // This function builds a valid run array and hence can skip validation. |
759 | 0 | builder.build_unchecked() |
760 | | }; |
761 | 0 | Ok(array_data.into()) |
762 | 0 | } |
763 | | |
764 | | /// Takes/filters a list array's inner data using the offsets of the list array. |
765 | | /// |
766 | | /// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns |
767 | | /// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2 |
768 | | /// elements) |
769 | | #[allow(clippy::type_complexity)] |
770 | 0 | fn take_value_indices_from_list<IndexType, OffsetType>( |
771 | 0 | list: &GenericListArray<OffsetType::Native>, |
772 | 0 | indices: &PrimitiveArray<IndexType>, |
773 | 0 | ) -> Result< |
774 | 0 | ( |
775 | 0 | PrimitiveArray<OffsetType>, |
776 | 0 | Vec<OffsetType::Native>, |
777 | 0 | MutableBuffer, |
778 | 0 | ), |
779 | 0 | ArrowError, |
780 | 0 | > |
781 | 0 | where |
782 | 0 | IndexType: ArrowPrimitiveType, |
783 | 0 | OffsetType: ArrowPrimitiveType, |
784 | 0 | OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One, |
785 | 0 | PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>, |
786 | | { |
787 | | // TODO: benchmark this function, there might be a faster unsafe alternative |
788 | 0 | let offsets: &[OffsetType::Native] = list.value_offsets(); |
789 | | |
790 | 0 | let mut new_offsets = Vec::with_capacity(indices.len()); |
791 | 0 | let mut values = Vec::new(); |
792 | 0 | let mut current_offset = OffsetType::Native::zero(); |
793 | | // add first offset |
794 | 0 | new_offsets.push(OffsetType::Native::zero()); |
795 | | |
796 | | // Initialize null buffer |
797 | 0 | let num_bytes = bit_util::ceil(indices.len(), 8); |
798 | 0 | let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); |
799 | 0 | let null_slice = null_buf.as_slice_mut(); |
800 | | |
801 | | // compute the value indices, and set offsets accordingly |
802 | 0 | for i in 0..indices.len() { |
803 | 0 | if indices.is_valid(i) { |
804 | 0 | let ix = indices |
805 | 0 | .value(i) |
806 | 0 | .to_usize() |
807 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
808 | 0 | let start = offsets[ix]; |
809 | 0 | let end = offsets[ix + 1]; |
810 | 0 | current_offset += end - start; |
811 | 0 | new_offsets.push(current_offset); |
812 | | |
813 | 0 | let mut curr = start; |
814 | | |
815 | | // if start == end, this slot is empty |
816 | 0 | while curr < end { |
817 | 0 | values.push(curr); |
818 | 0 | curr += One::one(); |
819 | 0 | } |
820 | 0 | if !list.is_valid(ix) { |
821 | 0 | bit_util::unset_bit(null_slice, i); |
822 | 0 | } |
823 | 0 | } else { |
824 | 0 | bit_util::unset_bit(null_slice, i); |
825 | 0 | new_offsets.push(current_offset); |
826 | 0 | } |
827 | | } |
828 | | |
829 | 0 | Ok(( |
830 | 0 | PrimitiveArray::<OffsetType>::from(values), |
831 | 0 | new_offsets, |
832 | 0 | null_buf, |
833 | 0 | )) |
834 | 0 | } |
835 | | |
836 | | /// Takes/filters a fixed size list array's inner data using the offsets of the list array. |
837 | 0 | fn take_value_indices_from_fixed_size_list<IndexType>( |
838 | 0 | list: &FixedSizeListArray, |
839 | 0 | indices: &PrimitiveArray<IndexType>, |
840 | 0 | length: <UInt32Type as ArrowPrimitiveType>::Native, |
841 | 0 | ) -> Result<PrimitiveArray<UInt32Type>, ArrowError> |
842 | 0 | where |
843 | 0 | IndexType: ArrowPrimitiveType, |
844 | | { |
845 | 0 | let mut values = UInt32Builder::with_capacity(length as usize * indices.len()); |
846 | | |
847 | 0 | for i in 0..indices.len() { |
848 | 0 | if indices.is_valid(i) { |
849 | 0 | let index = indices |
850 | 0 | .value(i) |
851 | 0 | .to_usize() |
852 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
853 | 0 | let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native; |
854 | | |
855 | | // Safety: Range always has known length. |
856 | 0 | unsafe { |
857 | 0 | values.append_trusted_len_iter(start..start + length); |
858 | 0 | } |
859 | 0 | } else { |
860 | 0 | values.append_nulls(length as usize); |
861 | 0 | } |
862 | | } |
863 | | |
864 | 0 | Ok(values.finish()) |
865 | 0 | } |
866 | | |
/// To avoid generating take implementations for every index type, instead we
/// only generate for UInt32 and UInt64 and coerce inputs to these types
trait ToIndices {
    /// The index type the implementing array is coerced to.
    type T: ArrowPrimitiveType;

    /// Returns the indices of `self` as a [`PrimitiveArray`] of [`Self::T`].
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
874 | | |
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by wrapping a clone of its values
/// buffer in a `ScalarBuffer` typed for `$o`; the validity bitmap is carried over.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let raw = self.values().inner().clone();
                let values = ScalarBuffer::new(raw, 0, self.len());
                let nulls = self.nulls().cloned();
                PrimitiveArray::new(values, nulls)
            }
        }
    };
}
887 | | |
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` when `$t` is already one of the
/// supported index types: the array is returned as a clone of itself.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
899 | | |
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by casting every value to the
/// native type of `$o` with `as`; the validity bitmap is carried over.
///
/// Because the conversion uses `as`, negative values of a signed `$t` become large
/// unsigned values rather than an error.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Must agree with the return type of `to_indices`, so it follows `$o`
            // instead of being pinned to one concrete type.
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
912 | | |
913 | | to_indices_widening!(UInt8Type, UInt32Type); |
914 | | to_indices_widening!(Int8Type, UInt32Type); |
915 | | |
916 | | to_indices_widening!(UInt16Type, UInt32Type); |
917 | | to_indices_widening!(Int16Type, UInt32Type); |
918 | | |
919 | | to_indices_identity!(UInt32Type); |
920 | | to_indices_reinterpret!(Int32Type, UInt32Type); |
921 | | |
922 | | to_indices_identity!(UInt64Type); |
923 | | to_indices_reinterpret!(Int64Type, UInt64Type); |
924 | | |
925 | | /// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes. |
926 | | /// |
927 | | /// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`]. |
928 | | /// |
929 | | /// # Example |
930 | | /// ``` |
931 | | /// # use std::sync::Arc; |
932 | | /// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch}; |
933 | | /// # use arrow_schema::{DataType, Field, Schema}; |
934 | | /// # use arrow_select::take::take_record_batch; |
935 | | /// |
936 | | /// let schema = Arc::new(Schema::new(vec![ |
937 | | /// Field::new("a", DataType::Int32, true), |
938 | | /// Field::new("b", DataType::Utf8, true), |
939 | | /// ])); |
940 | | /// let batch = RecordBatch::try_new( |
941 | | /// schema.clone(), |
942 | | /// vec![ |
943 | | /// Arc::new(Int32Array::from_iter_values(0..20)), |
944 | | /// Arc::new(StringArray::from_iter_values( |
945 | | /// (0..20).map(|i| format!("str-{}", i)), |
946 | | /// )), |
947 | | /// ], |
948 | | /// ) |
949 | | /// .unwrap(); |
950 | | /// |
951 | | /// let indices = UInt32Array::from(vec![1, 5, 10]); |
952 | | /// let taken = take_record_batch(&batch, &indices).unwrap(); |
953 | | /// |
954 | | /// let expected = RecordBatch::try_new( |
955 | | /// schema, |
956 | | /// vec![ |
957 | | /// Arc::new(Int32Array::from(vec![1, 5, 10])), |
958 | | /// Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])), |
959 | | /// ], |
960 | | /// ) |
961 | | /// .unwrap(); |
962 | | /// assert_eq!(taken, expected); |
963 | | /// ``` |
964 | 0 | pub fn take_record_batch( |
965 | 0 | record_batch: &RecordBatch, |
966 | 0 | indices: &dyn Array, |
967 | 0 | ) -> Result<RecordBatch, ArrowError> { |
968 | 0 | let columns = record_batch |
969 | 0 | .columns() |
970 | 0 | .iter() |
971 | 0 | .map(|c| take(c, indices, None)) |
972 | 0 | .collect::<Result<Vec<_>, _>>()?; |
973 | 0 | RecordBatch::try_new(record_batch.schema(), columns) |
974 | 0 | } |
975 | | |
976 | | #[cfg(test)] |
977 | | mod tests { |
978 | | use super::*; |
979 | | use arrow_array::builder::*; |
980 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
981 | | use arrow_data::ArrayData; |
982 | | use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; |
983 | | |
984 | | fn test_take_decimal_arrays( |
985 | | data: Vec<Option<i128>>, |
986 | | index: &UInt32Array, |
987 | | options: Option<TakeOptions>, |
988 | | expected_data: Vec<Option<i128>>, |
989 | | precision: &u8, |
990 | | scale: &i8, |
991 | | ) -> Result<(), ArrowError> { |
992 | | let output = data |
993 | | .into_iter() |
994 | | .collect::<Decimal128Array>() |
995 | | .with_precision_and_scale(*precision, *scale) |
996 | | .unwrap(); |
997 | | |
998 | | let expected = expected_data |
999 | | .into_iter() |
1000 | | .collect::<Decimal128Array>() |
1001 | | .with_precision_and_scale(*precision, *scale) |
1002 | | .unwrap(); |
1003 | | |
1004 | | let expected = Arc::new(expected) as ArrayRef; |
1005 | | let output = take(&output, index, options).unwrap(); |
1006 | | assert_eq!(&output, &expected); |
1007 | | Ok(()) |
1008 | | } |
1009 | | |
1010 | | fn test_take_boolean_arrays( |
1011 | | data: Vec<Option<bool>>, |
1012 | | index: &UInt32Array, |
1013 | | options: Option<TakeOptions>, |
1014 | | expected_data: Vec<Option<bool>>, |
1015 | | ) { |
1016 | | let output = BooleanArray::from(data); |
1017 | | let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef; |
1018 | | let output = take(&output, index, options).unwrap(); |
1019 | | assert_eq!(&output, &expected) |
1020 | | } |
1021 | | |
1022 | | fn test_take_primitive_arrays<T>( |
1023 | | data: Vec<Option<T::Native>>, |
1024 | | index: &UInt32Array, |
1025 | | options: Option<TakeOptions>, |
1026 | | expected_data: Vec<Option<T::Native>>, |
1027 | | ) -> Result<(), ArrowError> |
1028 | | where |
1029 | | T: ArrowPrimitiveType, |
1030 | | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1031 | | { |
1032 | | let output = PrimitiveArray::<T>::from(data); |
1033 | | let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef; |
1034 | | let output = take(&output, index, options)?; |
1035 | | assert_eq!(&output, &expected); |
1036 | | Ok(()) |
1037 | | } |
1038 | | |
1039 | | fn test_take_primitive_arrays_non_null<T>( |
1040 | | data: Vec<T::Native>, |
1041 | | index: &UInt32Array, |
1042 | | options: Option<TakeOptions>, |
1043 | | expected_data: Vec<Option<T::Native>>, |
1044 | | ) -> Result<(), ArrowError> |
1045 | | where |
1046 | | T: ArrowPrimitiveType, |
1047 | | PrimitiveArray<T>: From<Vec<T::Native>>, |
1048 | | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1049 | | { |
1050 | | let output = PrimitiveArray::<T>::from(data); |
1051 | | let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef; |
1052 | | let output = take(&output, index, options)?; |
1053 | | assert_eq!(&output, &expected); |
1054 | | Ok(()) |
1055 | | } |
1056 | | |
1057 | | fn test_take_impl_primitive_arrays<T, I>( |
1058 | | data: Vec<Option<T::Native>>, |
1059 | | index: &PrimitiveArray<I>, |
1060 | | options: Option<TakeOptions>, |
1061 | | expected_data: Vec<Option<T::Native>>, |
1062 | | ) where |
1063 | | T: ArrowPrimitiveType, |
1064 | | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1065 | | I: ArrowPrimitiveType, |
1066 | | { |
1067 | | let output = PrimitiveArray::<T>::from(data); |
1068 | | let expected = PrimitiveArray::<T>::from(expected_data); |
1069 | | let output = take(&output, index, options).unwrap(); |
1070 | | let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
1071 | | assert_eq!(output, &expected) |
1072 | | } |
1073 | | |
1074 | | // create a simple struct for testing purposes |
1075 | | fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray { |
1076 | | let mut struct_builder = StructBuilder::new( |
1077 | | Fields::from(vec![ |
1078 | | Field::new("a", DataType::Boolean, true), |
1079 | | Field::new("b", DataType::Int32, true), |
1080 | | ]), |
1081 | | vec![ |
1082 | | Box::new(BooleanBuilder::with_capacity(values.len())), |
1083 | | Box::new(Int32Builder::with_capacity(values.len())), |
1084 | | ], |
1085 | | ); |
1086 | | |
1087 | | for value in values { |
1088 | | struct_builder |
1089 | | .field_builder::<BooleanBuilder>(0) |
1090 | | .unwrap() |
1091 | | .append_option(value.and_then(|v| v.0)); |
1092 | | struct_builder |
1093 | | .field_builder::<Int32Builder>(1) |
1094 | | .unwrap() |
1095 | | .append_option(value.and_then(|v| v.1)); |
1096 | | struct_builder.append(value.is_some()); |
1097 | | } |
1098 | | struct_builder.finish() |
1099 | | } |
1100 | | |
/// Non-null indices applied to decimal values that contain nulls.
#[test]
fn test_take_decimal128_non_null_indices() {
    let precision: u8 = 10;
    let scale: i8 = 5;
    let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let input = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];

    test_take_decimal_arrays(input, &index, None, expected, &precision, &scale).unwrap();
}
1116 | | |
/// Nullable indices applied to decimal values without nulls.
#[test]
fn test_take_decimal128() {
    let precision: u8 = 10;
    let scale: i8 = 5;
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let input = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_decimal_arrays(input, &index, None, expected, &precision, &scale).unwrap();
}
1132 | | |
/// Non-null indices applied to Int8 values that contain nulls.
#[test]
fn test_take_primitive_non_null_indices() {
    let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let input = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(input, &index, None, expected).unwrap();
}
1144 | | |
/// Nullable indices applied to Int8 values without nulls.
#[test]
fn test_take_primitive_non_null_values() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let input = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_primitive_arrays::<Int8Type>(input, &index, None, expected).unwrap();
}
1156 | | |
/// Neither the indices nor the Int8 values contain nulls.
#[test]
fn test_take_primitive_non_null() {
    let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let input = vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)];
    let expected = vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(input, &index, None, expected).unwrap();
}
1168 | | |
/// Sliced (offset) nullable indices applied to Int64 values without nulls.
#[test]
fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
    let all_indices = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = all_indices.slice(2, 4);
    let index = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    // The slice keeps the last four entries.
    let expected_index = UInt32Array::from(vec![Some(2), Some(3), None, None]);
    assert_eq!(index, &expected_index);

    let input = vec![0, 10, 20, 30, 40, 50];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays_non_null::<Int64Type>(input, index, None, expected).unwrap();
}
1188 | | |
/// Sliced (offset) nullable indices applied to Int64 values that contain nulls.
#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
    let all_indices = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    let sliced = all_indices.slice(2, 4);
    let index = sliced.as_any().downcast_ref::<UInt32Array>().unwrap();

    // The slice keeps the last four entries.
    let expected_index = UInt32Array::from(vec![Some(2), Some(3), None, None]);
    assert_eq!(index, &expected_index);

    let input = vec![None, None, Some(20), Some(30), Some(40), Some(50)];
    let expected = vec![Some(20), Some(30), None, None];
    test_take_primitive_arrays::<Int64Type>(input, index, None, expected).unwrap();
}
1208 | | |
/// Runs `take` with the same nullable index array over every primitive type family.
#[test]
fn test_take_primitive() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes `index` from `$input` typed as `$ty` and compares against `$expected`.
    macro_rules! check {
        ($ty:ty, $input:expr, $expected:expr) => {
            test_take_primitive_arrays::<$ty>($input, &index, None, $expected).unwrap()
        };
    }

    // Same data for every listed type, using only non-negative values.
    macro_rules! check_non_negative {
        ($($ty:ty),+ $(,)?) => {
            $(
                check!(
                    $ty,
                    vec![Some(0), None, Some(2), Some(3), None],
                    vec![Some(3), None, None, Some(3), Some(2)]
                );
            )+
        };
    }

    // Same data for every listed type, including a negative value.
    macro_rules! check_with_negative {
        ($($ty:ty),+ $(,)?) => {
            $(
                check!(
                    $ty,
                    vec![Some(0), None, Some(2), Some(-15), None],
                    vec![Some(-15), None, None, Some(-15), Some(2)]
                );
            )+
        };
    }

    // Same floating point data for every listed type.
    macro_rules! check_floats {
        ($($ty:ty),+ $(,)?) => {
            $(
                check!(
                    $ty,
                    vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
                    vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
                );
            )+
        };
    }

    // int8, int16, int32, int64, uint8, uint16, uint32
    check_non_negative!(
        Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type,
    );

    // int64 (with a negative value), interval_year_month
    check_with_negative!(Int64Type, IntervalYearMonthType);

    // interval_day_time
    let v1 = IntervalDayTime::new(0, 0);
    let v2 = IntervalDayTime::new(2, 0);
    let v3 = IntervalDayTime::new(-15, 0);
    check!(
        IntervalDayTimeType,
        vec![Some(v1), None, Some(v2), Some(v3), None],
        vec![Some(v3), None, None, Some(v3), Some(v2)]
    );

    // interval_month_day_nano
    let v1 = IntervalMonthDayNano::new(0, 0, 0);
    let v2 = IntervalMonthDayNano::new(2, 0, 0);
    let v3 = IntervalMonthDayNano::new(-15, 0, 0);
    check!(
        IntervalMonthDayNanoType,
        vec![Some(v1), None, Some(v2), Some(v3), None],
        vec![Some(v3), None, None, Some(v3), Some(v2)]
    );

    // duration_second, duration_millisecond, duration_microsecond, duration_nanosecond
    check_with_negative!(
        DurationSecondType,
        DurationMillisecondType,
        DurationMicrosecondType,
        DurationNanosecondType,
    );

    // float32, float64
    check_floats!(Float32Type, Float64Type);
}
1372 | | |
/// The timezone of a timestamp array must survive `take`.
#[test]
fn test_take_preserve_timezone() {
    let index = Int64Array::from(vec![Some(0), None]);

    let timestamps = vec![1_639_715_368_000_000_000, 1_639_715_368_000_000_000];
    let input = TimestampNanosecondArray::from(timestamps).with_timezone("UTC".to_string());

    let result = take(&input, &index, None).unwrap();
    if let DataType::Timestamp(TimeUnit::Nanosecond, tz) = result.data_type() {
        assert_eq!(tz.clone(), Some("UTC".into()))
    } else {
        panic!()
    }
}
1390 | | |
/// `take` with Int64 indices over several primitive value types.
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes `index` from `$input` typed as `$ty` and compares against `$expected`.
    macro_rules! check {
        ($ty:ty, $input:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$ty, Int64Type>($input, &index, None, $expected)
        };
    }

    // int16
    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // int64
    check!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // uint64
    check!(
        UInt64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // duration_millisecond
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // float32
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1435 | | |
/// `take` with UInt8 indices over several primitive value types.
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Takes `index` from `$input` typed as `$ty` and compares against `$expected`.
    macro_rules! check {
        ($ty:ty, $input:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$ty, UInt8Type>($input, &index, None, $expected)
        };
    }

    // int16
    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );

    // duration_millisecond
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );

    // float32
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1464 | | |
/// Nullable indices applied to boolean values that contain nulls.
#[test]
fn test_take_bool() {
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let input = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![Some(false), None, None, Some(false), Some(true)];

    // boolean
    test_take_boolean_arrays(input, &index, None, expected);
}
1476 | | |
/// Null index slots may hold out-of-bounds values; they must be ignored when taking
/// from boolean values that contain nulls.
#[test]
fn test_take_bool_nullable_index() {
    // indices where the masked invalid elements would be out of bounds
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let index = UInt32Array::from(index_data);

    let input = vec![Some(true), None, Some(false)];
    let expected = vec![None, Some(true), None, None, None, Some(false)];
    test_take_boolean_arrays(input, &index, None, expected);
}
1499 | | |
/// Null index slots may hold out-of-bounds values; they must be ignored when taking
/// from boolean values without nulls.
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // indices where the masked invalid elements would be out of bounds
    let validity = Buffer::from_iter(vec![false, true, false, true, false, true]);
    let raw_indices = Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2]);
    let index_data = ArrayData::try_new(
        DataType::UInt32,
        6,
        Some(validity),
        0,
        vec![raw_indices],
        vec![],
    )
    .unwrap();
    let index = UInt32Array::from(index_data);

    let input = vec![Some(true), Some(true), Some(false)];
    let expected = vec![None, Some(true), None, Some(true), None, Some(false)];
    test_take_boolean_arrays(input, &index, None, expected);
}
1522 | | |
/// Sliced (offset) nullable indices applied to boolean values that contain nulls.
#[test]
fn test_take_bool_with_offset() {
    let all_indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    let sliced = all_indices.slice(2, 4);
    let index = sliced
        .as_any()
        .downcast_ref::<PrimitiveArray<UInt32Type>>()
        .unwrap();

    let input = vec![Some(false), None, Some(true), Some(false), None];
    let expected = vec![None, Some(false), Some(true), None];

    // boolean
    test_take_boolean_arrays(input, index, None, expected);
}
1540 | | |
1541 | | fn _test_take_string<'a, K>() |
1542 | | where |
1543 | | K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static, |
1544 | | { |
1545 | | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]); |
1546 | | |
1547 | | let array = K::from(vec![ |
1548 | | Some("one"), |
1549 | | None, |
1550 | | Some("three"), |
1551 | | Some("four"), |
1552 | | Some("five"), |
1553 | | ]); |
1554 | | let actual = take(&array, &index, None).unwrap(); |
1555 | | assert_eq!(actual.len(), index.len()); |
1556 | | |
1557 | | let actual = actual.as_any().downcast_ref::<K>().unwrap(); |
1558 | | |
1559 | | let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]); |
1560 | | |
1561 | | assert_eq!(actual, &expected); |
1562 | | } |
1563 | | |
// Runs the shared string take check against `StringArray`.
#[test]
fn test_take_string() {
    _test_take_string::<StringArray>()
}
1568 | | |
// Runs the shared string take check against `LargeStringArray`.
#[test]
fn test_take_large_string() {
    _test_take_string::<LargeStringArray>()
}
1573 | | |
/// Sliced (offset) Int32 indices applied to a string array that contains nulls.
#[test]
fn test_take_slice_string() {
    let input = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
    let all_indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    let index = all_indices.slice(1, 4);

    let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
    let actual = take(&input, &index, None).unwrap();
    assert_eq!(actual.as_ref(), &expected);
}
1583 | | |
1584 | | fn _test_byte_view<T>() |
1585 | | where |
1586 | | T: ByteViewType, |
1587 | | str: AsRef<T::Native>, |
1588 | | T::Native: PartialEq, |
1589 | | { |
1590 | | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]); |
1591 | | let array = { |
1592 | | // ["hello", "world", null, "large payload over 12 bytes", "lulu"] |
1593 | | let mut builder = GenericByteViewBuilder::<T>::new(); |
1594 | | builder.append_value("hello"); |
1595 | | builder.append_value("world"); |
1596 | | builder.append_null(); |
1597 | | builder.append_value("large payload over 12 bytes"); |
1598 | | builder.append_value("lulu"); |
1599 | | builder.finish() |
1600 | | }; |
1601 | | |
1602 | | let actual = take(&array, &index, None).unwrap(); |
1603 | | |
1604 | | assert_eq!(actual.len(), index.len()); |
1605 | | |
1606 | | let expected = { |
1607 | | // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null] |
1608 | | let mut builder = GenericByteViewBuilder::<T>::new(); |
1609 | | builder.append_value("large payload over 12 bytes"); |
1610 | | builder.append_null(); |
1611 | | builder.append_value("world"); |
1612 | | builder.append_value("large payload over 12 bytes"); |
1613 | | builder.append_value("lulu"); |
1614 | | builder.append_null(); |
1615 | | builder.finish() |
1616 | | }; |
1617 | | |
1618 | | assert_eq!(actual.as_ref(), &expected); |
1619 | | } |
1620 | | |
// Runs the shared byte view take check for `StringViewType`.
#[test]
fn test_take_string_view() {
    _test_byte_view::<StringViewType>()
}
1625 | | |
// Runs the shared byte view take check for `BinaryViewType`.
#[test]
fn test_take_binary_view() {
    _test_byte_view::<BinaryViewType>()
}
1630 | | |
// Expands to a check that takes nullable indices from a list array whose child
// values have no nulls. `$offset_type` is the native offset type, `$list_data_type`
// the `DataType` variant and `$list_array_type` the concrete list array type.
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Input list array: [[0,0,0], [-1,-2,-1], [], [2,3]]
        let child_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let offsets = Buffer::from_slice_ref(&offsets);
        let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
        let list_data_type = DataType::$list_data_type(item_field);
        let input_data = ArrayData::builder(list_data_type.clone())
            .len(4)
            .add_buffer(offsets)
            .add_child_data(child_data)
            .build()
            .unwrap();
        let input = $list_array_type::from(input_data);

        // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

        let taken = take(&input, &index, None).unwrap();
        let actual: &$list_array_type =
            taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected list array: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
        let expected_child = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
        let expected_data = ArrayData::builder(list_data_type)
            .len(5)
            // null buffer remains the same as only the indices have nulls
            .nulls(index.nulls().cloned())
            .add_buffer(expected_offsets)
            .add_child_data(expected_child)
            .build()
            .unwrap();
        let expected = $list_array_type::from(expected_data);

        assert_eq!(actual, &expected);
    }};
}
1685 | | |
// Test body shared by the list `take` tests where the child values contain nulls
// but every list entry itself is valid (validity byte 0b11111111).
//
// Arguments:
//   $offset_type     - native integer type used for the offsets buffer
//   $list_data_type  - `DataType` variant wrapping the nullable child field
//   $list_array_type - concrete list array type built from the data and used for the downcast
//
// Takes [2, null, 1, 3, 0]; the expected validity is copied from the index array.
1686 | | macro_rules! test_take_list_with_value_nulls { |
1687 | | ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{ |
1688 | | // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]] |
1689 | | let value_data = Int32Array::from(vec![ |
1690 | | Some(0), |
1691 | | None, |
1692 | | Some(0), |
1693 | | Some(-1), |
1694 | | Some(-2), |
1695 | | Some(3), |
1696 | | None, |
1697 | | Some(5), |
1698 | | None, |
1699 | | ]) |
1700 | | .into_data(); |
1701 | | // Construct offsets |
1702 | | let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9]; |
1703 | | let value_offsets = Buffer::from_slice_ref(&value_offsets); |
1704 | | // Construct a list array from the above two |
1705 | | let list_data_type = |
1706 | | DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); |
1707 | | let list_data = ArrayData::builder(list_data_type.clone()) |
1708 | | .len(4) |
1709 | | .add_buffer(value_offsets) |
1710 | | .null_bit_buffer(Some(Buffer::from([0b11111111]))) |
1711 | | .add_child_data(value_data) |
1712 | | .build() |
1713 | | .unwrap(); |
1714 | | let list_array = $list_array_type::from(list_data); |
1715 | | |
1716 | | // index returns: [[null], null, [-1,-2,3], [5,null], [0,null,0]] |
1717 | | let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); |
1718 | | |
1719 | | let a = take(&list_array, &index, None).unwrap(); |
1720 | | let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap(); |
1721 | | |
1722 | | // construct a value array with expected results: |
1723 | | // [[null], null, [-1,-2,3], [5,null], [0,null,0]] |
1724 | | let expected_data = Int32Array::from(vec![ |
1725 | | None, |
1726 | | Some(-1), |
1727 | | Some(-2), |
1728 | | Some(3), |
1729 | | Some(5), |
1730 | | None, |
1731 | | Some(0), |
1732 | | None, |
1733 | | Some(0), |
1734 | | ]) |
1735 | | .into_data(); |
1736 | | // construct offsets |
1737 | | let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9]; |
1738 | | let expected_offsets = Buffer::from_slice_ref(&expected_offsets); |
1739 | | // construct list array from the two |
1740 | | let expected_list_data = ArrayData::builder(list_data_type) |
1741 | | .len(5) |
1742 | | // null buffer remains the same as only the indices have nulls |
1743 | | .nulls(index.nulls().cloned()) |
1744 | | .add_buffer(expected_offsets) |
1745 | | .add_child_data(expected_data) |
1746 | | .build() |
1747 | | .unwrap(); |
1748 | | let expected_list_array = $list_array_type::from(expected_list_data); |
1749 | | |
1750 | | assert_eq!(a, &expected_list_array); |
1751 | | }}; |
1752 | | } |
1753 | | |
// Test body shared by the list `take` tests where both the list entries and the
// child values contain nulls (validity byte 0b11111011 clears bit 2, i.e. list entry 2).
//
// Arguments:
//   $offset_type     - native integer type used for the offsets buffer
//   $list_data_type  - `DataType` variant wrapping the nullable child field
//   $list_array_type - concrete list array type built from the data and used for the downcast
//
// Takes [2, null, 1, 3, 0]; the expected validity is built by hand with bits 2, 3 and 4 set,
// so output slots 0 (null list entry) and 1 (null index) are null.
1754 | | macro_rules! test_take_list_with_nulls { |
1755 | | ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{ |
1756 | | // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]] |
1757 | | let value_data = Int32Array::from(vec![ |
1758 | | Some(0), |
1759 | | None, |
1760 | | Some(0), |
1761 | | Some(-1), |
1762 | | Some(-2), |
1763 | | Some(3), |
1764 | | Some(5), |
1765 | | None, |
1766 | | ]) |
1767 | | .into_data(); |
1768 | | // Construct offsets |
1769 | | let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8]; |
1770 | | let value_offsets = Buffer::from_slice_ref(&value_offsets); |
1771 | | // Construct a list array from the above two |
1772 | | let list_data_type = |
1773 | | DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); |
1774 | | let list_data = ArrayData::builder(list_data_type.clone()) |
1775 | | .len(4) |
1776 | | .add_buffer(value_offsets) |
1777 | | .null_bit_buffer(Some(Buffer::from([0b11111011]))) |
1778 | | .add_child_data(value_data) |
1779 | | .build() |
1780 | | .unwrap(); |
1781 | | let list_array = $list_array_type::from(list_data); |
1782 | | |
1783 | | // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]] |
1784 | | let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); |
1785 | | |
1786 | | let a = take(&list_array, &index, None).unwrap(); |
1787 | | let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap(); |
1788 | | |
1789 | | // construct a value array with expected results: |
1790 | | // [null, null, [-1,-2,3], [5,null], [0,null,0]] |
1791 | | let expected_data = Int32Array::from(vec![ |
1792 | | Some(-1), |
1793 | | Some(-2), |
1794 | | Some(3), |
1795 | | Some(5), |
1796 | | None, |
1797 | | Some(0), |
1798 | | None, |
1799 | | Some(0), |
1800 | | ]) |
1801 | | .into_data(); |
1802 | | // construct offsets |
1803 | | let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8]; |
1804 | | let expected_offsets = Buffer::from_slice_ref(&expected_offsets); |
1805 | | // construct list array from the two |
1806 | | let mut null_bits: [u8; 1] = [0; 1]; |
1807 | | bit_util::set_bit(&mut null_bits, 2); |
1808 | | bit_util::set_bit(&mut null_bits, 3); |
1809 | | bit_util::set_bit(&mut null_bits, 4); |
1810 | | let expected_list_data = ArrayData::builder(list_data_type) |
1811 | | .len(5) |
1812 | | // null buffer must be recalculated as both values and indices have nulls |
1813 | | .null_bit_buffer(Some(Buffer::from(null_bits))) |
1814 | | .add_buffer(expected_offsets) |
1815 | | .add_child_data(expected_data) |
1816 | | .build() |
1817 | | .unwrap(); |
1818 | | let expected_list_array = $list_array_type::from(expected_list_data); |
1819 | | |
1820 | | assert_eq!(a, &expected_list_array); |
1821 | | }}; |
1822 | | } |
1823 | | |
// Helper for the fixed-size-list tests.
//
// Builds a `FixedSizeListArray` of element type `T` from `input_data` with list size
// `length`, calls `take_fixed_size_list` with `indices` (passing `length as u32`), and
// asserts the output equals the array built the same way from `expected_data`.
// Panics (via `unwrap`) if `take_fixed_size_list` returns an error.
1824 | | fn do_take_fixed_size_list_test<T>( |
1825 | | length: <Int32Type as ArrowPrimitiveType>::Native, |
1826 | | input_data: Vec<Option<Vec<Option<T::Native>>>>, |
1827 | | indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>, |
1828 | | expected_data: Vec<Option<Vec<Option<T::Native>>>>, |
1829 | | ) where |
1830 | | T: ArrowPrimitiveType, |
1831 | | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1832 | | { |
1833 | | let indices = UInt32Array::from(indices); |
1834 | | |
1835 | | let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length); |
1836 | | |
1837 | | let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap(); |
1838 | | |
1839 | | let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length); |
1840 | | |
1841 | | assert_eq!(&output, &expected) |
1842 | | } |
1843 | | |
// Expands `test_take_list!` with i32 offsets, `DataType::List` and `ListArray`.
1844 | | #[test] |
1845 | | fn test_take_list() { |
1846 | | test_take_list!(i32, List, ListArray); |
1847 | | } |
1848 | | |
// Expands `test_take_list!` with i64 offsets, `DataType::LargeList` and `LargeListArray`.
1849 | | #[test] |
1850 | | fn test_take_large_list() { |
1851 | | test_take_list!(i64, LargeList, LargeListArray); |
1852 | | } |
1853 | | |
// Expands `test_take_list_with_value_nulls!` with i32 offsets, `DataType::List` and `ListArray`.
1854 | | #[test] |
1855 | | fn test_take_list_with_value_nulls() { |
1856 | | test_take_list_with_value_nulls!(i32, List, ListArray); |
1857 | | } |
1858 | | |
// Expands `test_take_list_with_value_nulls!` with i64 offsets, `DataType::LargeList` and `LargeListArray`.
1859 | | #[test] |
1860 | | fn test_take_large_list_with_value_nulls() { |
1861 | | test_take_list_with_value_nulls!(i64, LargeList, LargeListArray); |
1862 | | } |
1863 | | |
// Expands `test_take_list_with_nulls!` with i32 offsets, `DataType::List` and `ListArray`.
1864 | | #[test] |
1865 | | fn test_test_take_list_with_nulls() { |
1866 | | test_take_list_with_nulls!(i32, List, ListArray); |
1867 | | } |
1868 | | |
// Expands `test_take_list_with_nulls!` with i64 offsets, `DataType::LargeList` and `LargeListArray`.
1869 | | #[test] |
1870 | | fn test_test_take_large_list_with_nulls() { |
1871 | | test_take_list_with_nulls!(i64, LargeList, LargeListArray); |
1872 | | } |
1873 | | |
// Drives `do_take_fixed_size_list_test` through three scenarios:
//   1. Int32, list size 3, nulls inside the lists, indices reversing the input order;
//   2. UInt8, list size 1, picking three of eight single-element lists;
//   3. UInt64, list size 3, one null list entry that is selected twice.
1874 | | #[test] |
1875 | | fn test_take_fixed_size_list() { |
1876 | | do_take_fixed_size_list_test::<Int32Type>( |
1877 | | 3, |
1878 | | vec![ |
1879 | | Some(vec![None, Some(1), Some(2)]), |
1880 | | Some(vec![Some(3), Some(4), None]), |
1881 | | Some(vec![Some(6), Some(7), Some(8)]), |
1882 | | ], |
1883 | | vec![2, 1, 0], |
1884 | | vec![ |
1885 | | Some(vec![Some(6), Some(7), Some(8)]), |
1886 | | Some(vec![Some(3), Some(4), None]), |
1887 | | Some(vec![None, Some(1), Some(2)]), |
1888 | | ], |
1889 | | ); |
1890 | | |
1891 | | do_take_fixed_size_list_test::<UInt8Type>( |
1892 | | 1, |
1893 | | vec![ |
1894 | | Some(vec![Some(1)]), |
1895 | | Some(vec![Some(2)]), |
1896 | | Some(vec![Some(3)]), |
1897 | | Some(vec![Some(4)]), |
1898 | | Some(vec![Some(5)]), |
1899 | | Some(vec![Some(6)]), |
1900 | | Some(vec![Some(7)]), |
1901 | | Some(vec![Some(8)]), |
1902 | | ], |
1903 | | vec![2, 7, 0], |
1904 | | vec![ |
1905 | | Some(vec![Some(3)]), |
1906 | | Some(vec![Some(8)]), |
1907 | | Some(vec![Some(1)]), |
1908 | | ], |
1909 | | ); |
1910 | | |
1911 | | do_take_fixed_size_list_test::<UInt64Type>( |
1912 | | 3, |
1913 | | vec![ |
1914 | | Some(vec![Some(10), Some(11), Some(12)]), |
1915 | | Some(vec![Some(13), Some(14), Some(15)]), |
1916 | | None, |
1917 | | Some(vec![Some(16), Some(17), Some(18)]), |
1918 | | ], |
1919 | | vec![3, 2, 1, 2, 0], |
1920 | | vec![ |
1921 | | Some(vec![Some(16), Some(17), Some(18)]), |
1922 | | None, |
1923 | | Some(vec![Some(13), Some(14), Some(15)]), |
1924 | | None, |
1925 | | Some(vec![Some(10), Some(11), Some(12)]), |
1926 | | ], |
1927 | | ); |
1928 | | } |
1929 | | |
// Takes index 1000 from a 3-entry list array without passing `TakeOptions`, and expects
// the call to panic with the "index out of bounds" message given in `should_panic`.
// NOTE(review): the expected message reports "len is 4" for a 3-entry list; presumably
// the length of the 4-element offsets buffer — confirm against the kernel.
1930 | | #[test] |
1931 | | #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")] |
1932 | | fn test_take_list_out_of_bounds() { |
1933 | | // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]] |
1934 | | let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data(); |
1935 | | // Construct offsets |
1936 | | let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); |
1937 | | // Construct a list array from the above two |
1938 | | let list_data_type = |
1939 | | DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1940 | | let list_data = ArrayData::builder(list_data_type) |
1941 | | .len(3) |
1942 | | .add_buffer(value_offsets) |
1943 | | .add_child_data(value_data) |
1944 | | .build() |
1945 | | .unwrap(); |
1946 | | let list_array = ListArray::from(list_data); |
1947 | | |
1948 | | let index = UInt32Array::from(vec![1000]); |
1949 | | |
1950 | | // A panic is expected here since we have not supplied the check_bounds |
1951 | | // option. |
1952 | | take(&list_array, &index, None).unwrap(); |
1953 | | } |
1954 | | |
// Builds a two-entry map array (entry offsets [0, 3, 4]) from string keys and Int32 values,
// takes entry 0, and expects a one-entry map holding the first three key/value pairs.
1955 | | #[test] |
1956 | | fn test_take_map() { |
1957 | | let values = Int32Array::from(vec![1, 2, 3, 4]); |
1958 | | let array = |
1959 | | MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4]) |
1960 | | .unwrap(); |
1961 | | |
1962 | | let index = UInt32Array::from(vec![0]); |
1963 | | |
1964 | | let result = take(&array, &index, None).unwrap(); |
1965 | | let expected: ArrayRef = Arc::new( |
1966 | | MapArray::new_from_strings( |
1967 | | vec!["a", "b", "c"].into_iter(), |
1968 | | &values.slice(0, 3), |
1969 | | &[0, 3], |
1970 | | ) |
1971 | | .unwrap(), |
1972 | | ); |
1973 | | assert_eq!(&expected, &result); |
1974 | | } |
1975 | | |
// Two-part test of `take` on struct arrays with non-null indices.
// Part 1: a five-row struct (last row null) is taken with [0, 3, 1, 0, 2, 4]; the result must
//         have the index length, exactly one null, and equal the hand-built expected struct.
// Part 2: a struct array with no fields and only a null buffer is taken with [0, 2, 1, 4]; the
//         result must equal an empty-field struct carrying the correspondingly permuted nulls.
1976 | | #[test] |
1977 | | fn test_take_struct() { |
1978 | | let array = create_test_struct(vec![ |
1979 | | Some((Some(true), Some(42))), |
1980 | | Some((Some(false), Some(28))), |
1981 | | Some((Some(false), Some(19))), |
1982 | | Some((Some(true), Some(31))), |
1983 | | None, |
1984 | | ]); |
1985 | | |
1986 | | let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]); |
1987 | | let actual = take(&array, &index, None).unwrap(); |
1988 | | let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap(); |
1989 | | assert_eq!(index.len(), actual.len()); |
1990 | | assert_eq!(1, actual.null_count()); |
1991 | | |
1992 | | let expected = create_test_struct(vec![ |
1993 | | Some((Some(true), Some(42))), |
1994 | | Some((Some(true), Some(31))), |
1995 | | Some((Some(false), Some(28))), |
1996 | | Some((Some(true), Some(42))), |
1997 | | Some((Some(false), Some(19))), |
1998 | | None, |
1999 | | ]); |
2000 | | |
2001 | | assert_eq!(&expected, actual); |
2002 | | |
2003 | | let nulls = NullBuffer::from(&[false, true, false, true, false, true]); |
2004 | | let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls)); |
2005 | | let index = UInt32Array::from(vec![0, 2, 1, 4]); |
2006 | | let actual = take(&empty_struct_arr, &index, None).unwrap(); |
2007 | | |
2008 | | let expected_nulls = NullBuffer::from(&[false, false, true, false]); |
2009 | | let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls)); |
2010 | | assert_eq!(&expected_struct_arr, actual.as_struct()); |
2011 | | } |
2012 | | |
// Takes from a five-row struct array (last row null) using indices that themselves contain
// nulls, and checks that nulls from both sources appear in the output: three nulls in total.
2013 | | #[test] |
2014 | | fn test_take_struct_with_null_indices() { |
2015 | | let array = create_test_struct(vec![ |
2016 | | Some((Some(true), Some(42))), |
2017 | | Some((Some(false), Some(28))), |
2018 | | Some((Some(false), Some(19))), |
2019 | | Some((Some(true), Some(31))), |
2020 | | None, |
2021 | | ]); |
2022 | | |
2023 | | let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]); |
2024 | | let actual = take(&array, &index, None).unwrap(); |
2025 | | let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap(); |
2026 | | assert_eq!(index.len(), actual.len()); |
2027 | | assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array |
2028 | | |
2029 | | let expected = create_test_struct(vec![ |
2030 | | None, |
2031 | | Some((Some(true), Some(31))), |
2032 | | Some((Some(false), Some(28))), |
2033 | | None, |
2034 | | Some((Some(true), Some(42))), |
2035 | | None, |
2036 | | ]); |
2037 | | |
2038 | | assert_eq!(&expected, actual); |
2039 | | } |
2040 | | |
// With `check_bounds: true`, taking index 6 from a five-element Int64 array must yield an
// error rather than a value; the expected-values argument is a placeholder here.
2041 | | #[test] |
2042 | | fn test_take_out_of_bounds() { |
2043 | | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]); |
2044 | | let take_opt = TakeOptions { check_bounds: true }; |
2045 | | |
2046 | | // int64 |
2047 | | let result = test_take_primitive_arrays::<Int64Type>( |
2048 | | vec![Some(0), None, Some(2), Some(3), None], |
2049 | | &index, |
2050 | | Some(take_opt), |
2051 | | vec![None], |
2052 | | ); |
2053 | | assert!(result.is_err()); |
2054 | | } |
2055 | | |
// Without `TakeOptions`, taking index 1000 from a four-element Int64 array is expected to
// panic with the "index out of bounds" message given in `should_panic`.
2056 | | #[test] |
2057 | | #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")] |
2058 | | fn test_take_out_of_bounds_panic() { |
2059 | | let index = UInt32Array::from(vec![Some(1000)]); |
2060 | | |
2061 | | test_take_primitive_arrays::<Int64Type>( |
2062 | | vec![Some(0), Some(1), Some(2), Some(3)], |
2063 | | &index, |
2064 | | None, |
2065 | | vec![None], |
2066 | | ) |
2067 | | .unwrap(); |
2068 | | } |
2069 | | |
// Taking three indices (one null, one far past the end) from a 2-element `NullArray` without
// bounds checking must succeed and produce a `NullArray` of length 3.
2070 | | #[test] |
2071 | | fn test_null_array_smaller_than_indices() { |
2072 | | let values = NullArray::new(2); |
2073 | | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2074 | | |
2075 | | let result = take(&values, &indices, None).unwrap(); |
2076 | | let expected: ArrayRef = Arc::new(NullArray::new(3)); |
2077 | | assert_eq!(&result, &expected); |
2078 | | } |
2079 | | |
// Taking three indices (one null, one past the end) from a 5-element `NullArray` without
// bounds checking must succeed and produce a `NullArray` of length 3.
2080 | | #[test] |
2081 | | fn test_null_array_larger_than_indices() { |
2082 | | let values = NullArray::new(5); |
2083 | | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2084 | | |
2085 | | let result = take(&values, &indices, None).unwrap(); |
2086 | | let expected: ArrayRef = Arc::new(NullArray::new(3)); |
2087 | | assert_eq!(&result, &expected); |
2088 | | } |
2089 | | |
// With `check_bounds: true`, index 15 into a 5-element `NullArray` must return an error whose
// rendered message matches the exact string asserted below.
2090 | | #[test] |
2091 | | fn test_null_array_indices_out_of_bounds() { |
2092 | | let values = NullArray::new(5); |
2093 | | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2094 | | |
2095 | | let result = take(&values, &indices, Some(TakeOptions { check_bounds: true })); |
2096 | | assert_eq!( |
2097 | | result.unwrap_err().to_string(), |
2098 | | "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries" |
2099 | | ); |
2100 | | } |
2101 | | |
// Takes from an Int16-keyed string dictionary array that contains repeated values and one null.
// Checks that the dictionary values of the result equal those of the input
// (["foo", "bar", ""]) and that only the keys are gathered, with nulls for the null index
// and for the null input slot.
2102 | | #[test] |
2103 | | fn test_take_dict() { |
2104 | | let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new(); |
2105 | | |
2106 | | dict_builder.append("foo").unwrap(); |
2107 | | dict_builder.append("bar").unwrap(); |
2108 | | dict_builder.append("").unwrap(); |
2109 | | dict_builder.append_null(); |
2110 | | dict_builder.append("foo").unwrap(); |
2111 | | dict_builder.append("bar").unwrap(); |
2112 | | dict_builder.append("bar").unwrap(); |
2113 | | dict_builder.append("foo").unwrap(); |
2114 | | |
2115 | | let array = dict_builder.finish(); |
2116 | | let dict_values = array.values().clone(); |
2117 | | let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap(); |
2118 | | |
2119 | | let indices = UInt32Array::from(vec![ |
2120 | | Some(0), // first "foo" |
2121 | | Some(7), // last "foo" |
2122 | | None, // null index should return null |
2123 | | Some(5), // second "bar" |
2124 | | Some(6), // another "bar" |
2125 | | Some(2), // empty string |
2126 | | Some(3), // input is null at this index |
2127 | | ]); |
2128 | | |
2129 | | let result = take(&array, &indices, None).unwrap(); |
2130 | | let result = result |
2131 | | .as_any() |
2132 | | .downcast_ref::<DictionaryArray<Int16Type>>() |
2133 | | .unwrap(); |
2134 | | |
2135 | | let result_values: StringArray = result.values().to_data().into(); |
2136 | | |
2137 | | // dictionary values should stay the same |
2138 | | let expected_values = StringArray::from(vec!["foo", "bar", ""]); |
2139 | | assert_eq!(&expected_values, dict_values); |
2140 | | assert_eq!(&expected_values, &result_values); |
2141 | | |
2142 | | let expected_keys = Int16Array::from(vec![ |
2143 | | Some(0), |
2144 | | Some(0), |
2145 | | None, |
2146 | | Some(1), |
2147 | | Some(1), |
2148 | | Some(2), |
2149 | | None, |
2150 | | ]); |
2151 | | assert_eq!(result.keys(), &expected_keys); |
2152 | | } |
2153 | | |
// Helper that builds a `GenericListArray<S>` with primitive child type `T` from nested vectors.
// An outer `None` becomes a null list entry; every inner value is wrapped in `Some`, so the
// child values themselves are never null.
2154 | | fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S> |
2155 | | where |
2156 | | S: OffsetSizeTrait + 'static, |
2157 | | T: ArrowPrimitiveType, |
2158 | | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
2159 | | { |
2160 | | GenericListArray::from_iter_primitive::<T, _, _>( |
2161 | | data.iter() |
2162 | | .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))), |
2163 | | ) |
2164 | | } |
2165 | | |
// Checks `take_value_indices_from_list` on an i32-offset list: for list indices [2, 0] it must
// return the child value indices [5..=9, 0, 1], the new offsets [0, 5, 7], and a validity
// byte with all bits set.
2166 | | #[test] |
2167 | | fn test_take_value_index_from_list() { |
2168 | | let list = build_generic_list::<i32, Int32Type>(vec![ |
2169 | | Some(vec![0, 1]), |
2170 | | Some(vec![2, 3, 4]), |
2171 | | Some(vec![5, 6, 7, 8, 9]), |
2172 | | ]); |
2173 | | let indices = UInt32Array::from(vec![2, 0]); |
2174 | | |
2175 | | let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap(); |
2176 | | |
2177 | | assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1])); |
2178 | | assert_eq!(offsets, vec![0, 5, 7]); |
2179 | | assert_eq!(null_buf.as_slice(), &[0b11111111]); |
2180 | | } |
2181 | | |
// Same check as the i32-offset variant, but for an i64-offset list: the returned child value
// indices are an `Int64Array`, with offsets [0, 5, 7] and a validity byte with all bits set.
2182 | | #[test] |
2183 | | fn test_take_value_index_from_large_list() { |
2184 | | let list = build_generic_list::<i64, Int32Type>(vec![ |
2185 | | Some(vec![0, 1]), |
2186 | | Some(vec![2, 3, 4]), |
2187 | | Some(vec![5, 6, 7, 8, 9]), |
2188 | | ]); |
2189 | | let indices = UInt32Array::from(vec![2, 0]); |
2190 | | |
2191 | | let (indexed, offsets, null_buf) = |
2192 | | take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap(); |
2193 | | |
2194 | | assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1])); |
2195 | | assert_eq!(offsets, vec![0, 5, 7]); |
2196 | | assert_eq!(null_buf.as_slice(), &[0b11111111]); |
2197 | | } |
2198 | | |
// Builds a run-end encoded Int32 array from 13 logical values, takes seven logical indices
// with `take_run`, and checks the resulting run ends and run values.
// NOTE(review): `run_ends().len()` is asserted to be 7 while `run_ends().values()` has five
// entries — presumably `len()` reports the logical length; confirm against the run-end API.
2199 | | #[test] |
2200 | | fn test_take_runs() { |
2201 | | let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2]; |
2202 | | |
2203 | | let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new(); |
2204 | | builder.extend(logical_array.into_iter().map(Some)); |
2205 | | let run_array = builder.finish(); |
2206 | | |
2207 | | let take_indices: PrimitiveArray<Int32Type> = |
2208 | | vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect(); |
2209 | | |
2210 | | let take_out = take_run(&run_array, &take_indices).unwrap(); |
2211 | | |
2212 | | assert_eq!(take_out.len(), 7); |
2213 | | assert_eq!(take_out.run_ends().len(), 7); |
2214 | | assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]); |
2215 | | |
2216 | | let take_out_values = take_out.values().as_primitive::<Int32Type>(); |
2217 | | assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]); |
2218 | | } |
2219 | | |
// Checks `take_value_indices_from_fixed_size_list` with list size 3: each list index i must
// expand to the child indices 3*i, 3*i+1, 3*i+2, in the order the list indices are given
// (including the null list entry at index 2 and a repeated index).
2220 | | #[test] |
2221 | | fn test_take_value_index_from_fixed_list() { |
2222 | | let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>( |
2223 | | vec![ |
2224 | | Some(vec![Some(1), Some(2), None]), |
2225 | | Some(vec![Some(4), None, Some(6)]), |
2226 | | None, |
2227 | | Some(vec![None, Some(8), Some(9)]), |
2228 | | ], |
2229 | | 3, |
2230 | | ); |
2231 | | |
2232 | | let indices = UInt32Array::from(vec![2, 1, 0]); |
2233 | | let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); |
2234 | | |
2235 | | assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])); |
2236 | | |
2237 | | let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]); |
2238 | | let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); |
2239 | | |
2240 | | assert_eq!( |
2241 | | indexed, |
2242 | | UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2]) |
2243 | | ); |
2244 | | } |
2245 | | |
// Index slots that are null may hold arbitrary (here out-of-range) raw values; `take` on a
// primitive array must not fail on them and must emit null for those slots.
2246 | | #[test] |
2247 | | fn test_take_null_indices() { |
2248 | | // Build indices with values that are out of bounds, but masked by null mask |
2249 | | let indices = Int32Array::new( |
2250 | | vec![1, 2, 400, 400].into(), |
2251 | | Some(NullBuffer::from(vec![true, true, false, false])), |
2252 | | ); |
2253 | | let values = Int32Array::from(vec![1, 23, 4, 5]); |
2254 | | let r = take(&values, &indices, None).unwrap(); |
2255 | | let values = r |
2256 | | .as_primitive::<Int32Type>() |
2257 | | .into_iter() |
2258 | | .collect::<Vec<_>>(); |
2259 | | assert_eq!(&values, &[Some(23), Some(4), None, None]) |
2260 | | } |
2261 | | |
// Takes [0, null] from a fixed-size list (size 2) over [0, 1, 2, 3] and checks the child
// values of the result: the first list's values followed by two null child slots for the
// null index.
2262 | | #[test] |
2263 | | fn test_take_fixed_size_list_null_indices() { |
2264 | | let indices = Int32Array::from_iter([Some(0), None]); |
2265 | | let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3])); |
2266 | | let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true)); |
2267 | | let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap(); |
2268 | | |
2269 | | let r = take(&values, &indices, None).unwrap(); |
2270 | | let values = r |
2271 | | .as_fixed_size_list() |
2272 | | .values() |
2273 | | .as_primitive::<Int32Type>() |
2274 | | .into_iter() |
2275 | | .collect::<Vec<_>>(); |
2276 | | assert_eq!(values, &[Some(0), Some(1), None, None]) |
2277 | | } |
2278 | | |
// Same idea as the primitive null-index test, for a string array: null index slots carry the
// out-of-range raw value 400, and the result must be null there instead of failing.
2279 | | #[test] |
2280 | | fn test_take_bytes_null_indices() { |
2281 | | let indices = Int32Array::new( |
2282 | | vec![0, 1, 400, 400].into(), |
2283 | | Some(NullBuffer::from_iter(vec![true, true, false, false])), |
2284 | | ); |
2285 | | let values = StringArray::from(vec![Some("foo"), None]); |
2286 | | let r = take(&values, &indices, None).unwrap(); |
2287 | | let values = r.as_string::<i32>().iter().collect::<Vec<_>>(); |
2288 | | assert_eq!(&values, &[Some("foo"), None, None, None]) |
2289 | | } |
2290 | | |
// Builds a union array without an offsets buffer (`None` passed to `try_new`) over a struct
// child and a string child, with every type id set to 1 (the string child). After taking
// [0, 3, 1, 0, 2, 4], only the string child of the result is inspected and compared.
2291 | | #[test] |
2292 | | fn test_take_union_sparse() { |
2293 | | let structs = create_test_struct(vec![ |
2294 | | Some((Some(true), Some(42))), |
2295 | | Some((Some(false), Some(28))), |
2296 | | Some((Some(false), Some(19))), |
2297 | | Some((Some(true), Some(31))), |
2298 | | None, |
2299 | | ]); |
2300 | | let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]); |
2301 | | let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>(); |
2302 | | |
2303 | | let union_fields = [ |
2304 | | ( |
2305 | | 0, |
2306 | | Arc::new(Field::new("f1", structs.data_type().clone(), true)), |
2307 | | ), |
2308 | | ( |
2309 | | 1, |
2310 | | Arc::new(Field::new("f2", strings.data_type().clone(), true)), |
2311 | | ), |
2312 | | ] |
2313 | | .into_iter() |
2314 | | .collect(); |
2315 | | let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)]; |
2316 | | let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); |
2317 | | |
2318 | | let indices = vec![0, 3, 1, 0, 2, 4]; |
2319 | | let index = UInt32Array::from(indices.clone()); |
2320 | | let actual = take(&array, &index, None).unwrap(); |
2321 | | let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap(); |
2322 | | let strings = actual.child(1); |
2323 | | let strings = strings.as_any().downcast_ref::<StringArray>().unwrap(); |
2324 | | |
2325 | | let actual = strings.iter().collect::<Vec<_>>(); |
2326 | | let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")]; |
2327 | | assert_eq!(expected, actual); |
2328 | | } |
2329 | | |
// Builds a union array with an explicit offsets buffer over a UInt32 child and a string child,
// takes [0, 3, 1, 0, 2, 4], and checks all four parts of the result: offsets, type ids, and
// both children. The expected children (`taken_ints`, `taken_strings`) hold only the taken
// values, laid out per child in output order.
2330 | | #[test] |
2331 | | fn test_take_union_dense() { |
2332 | | let type_ids = vec![0, 1, 1, 0, 0, 1, 0]; |
2333 | | let offsets = vec![0, 0, 1, 1, 2, 2, 3]; |
2334 | | let ints = vec![10, 20, 30, 40]; |
2335 | | let strings = vec![Some("a"), None, Some("c"), Some("d")]; |
2336 | | |
2337 | | let indices = vec![0, 3, 1, 0, 2, 4]; |
2338 | | |
2339 | | let taken_type_ids = vec![0, 0, 1, 0, 1, 0]; |
2340 | | let taken_offsets = vec![0, 1, 0, 2, 1, 3]; |
2341 | | let taken_ints = vec![10, 20, 10, 30]; |
2342 | | let taken_strings = vec![Some("a"), None]; |
2343 | | |
2344 | | let type_ids = <ScalarBuffer<i8>>::from(type_ids); |
2345 | | let offsets = <ScalarBuffer<i32>>::from(offsets); |
2346 | | let ints = UInt32Array::from(ints); |
2347 | | let strings = StringArray::from(strings); |
2348 | | |
2349 | | let union_fields = [ |
2350 | | ( |
2351 | | 0, |
2352 | | Arc::new(Field::new("f1", ints.data_type().clone(), true)), |
2353 | | ), |
2354 | | ( |
2355 | | 1, |
2356 | | Arc::new(Field::new("f2", strings.data_type().clone(), true)), |
2357 | | ), |
2358 | | ] |
2359 | | .into_iter() |
2360 | | .collect(); |
2361 | | |
2362 | | let array = UnionArray::try_new( |
2363 | | union_fields, |
2364 | | type_ids, |
2365 | | Some(offsets), |
2366 | | vec![Arc::new(ints), Arc::new(strings)], |
2367 | | ) |
2368 | | .unwrap(); |
2369 | | |
2370 | | let index = UInt32Array::from(indices); |
2371 | | |
2372 | | let actual = take(&array, &index, None).unwrap(); |
2373 | | let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap(); |
2374 | | |
2375 | | assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets))); |
2376 | | assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids)); |
2377 | | assert_eq!( |
2378 | | UInt32Array::from(actual.child(0).to_data()), |
2379 | | UInt32Array::from(taken_ints) |
2380 | | ); |
2381 | | assert_eq!( |
2382 | | StringArray::from(actual.child(1).to_data()), |
2383 | | StringArray::from(taken_strings) |
2384 | | ); |
2385 | | } |
2386 | | |
// Builds a dense union with `UnionBuilder::new_dense`, takes [2, 0, 1, 2], and compares the
// result's `ArrayData` with that of a second dense union built directly from the values the
// take is expected to select, in output order.
2387 | | #[test] |
2388 | | fn test_take_union_dense_using_builder() { |
2389 | | let mut builder = UnionBuilder::new_dense(); |
2390 | | |
2391 | | builder.append::<Int32Type>("a", 1).unwrap(); |
2392 | | builder.append::<Float64Type>("b", 3.0).unwrap(); |
2393 | | builder.append::<Int32Type>("a", 4).unwrap(); |
2394 | | builder.append::<Int32Type>("a", 5).unwrap(); |
2395 | | builder.append::<Float64Type>("b", 2.0).unwrap(); |
2396 | | |
2397 | | let union = builder.build().unwrap(); |
2398 | | |
2399 | | let indices = UInt32Array::from(vec![2, 0, 1, 2]); |
2400 | | |
2401 | | let mut builder = UnionBuilder::new_dense(); |
2402 | | |
2403 | | builder.append::<Int32Type>("a", 4).unwrap(); |
2404 | | builder.append::<Int32Type>("a", 1).unwrap(); |
2405 | | builder.append::<Float64Type>("b", 3.0).unwrap(); |
2406 | | builder.append::<Int32Type>("a", 4).unwrap(); |
2407 | | |
2408 | | let taken = builder.build().unwrap(); |
2409 | | |
2410 | | assert_eq!( |
2411 | | taken.to_data(), |
2412 | | take(&union, &indices, None).unwrap().to_data() |
2413 | | ); |
2414 | | } |
2415 | | |
// Regression test (the name references issue 6206): a union with offsets and a single Int64
// child, where every row has type id 0, is taken with Int64 indices [0, 2, 4]. Only the
// success of `take` and the output length (3) are asserted.
2416 | | #[test] |
2417 | | fn test_take_union_dense_all_match_issue_6206() { |
2418 | | let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]); |
2419 | | let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); |
2420 | | |
2421 | | let array = UnionArray::try_new( |
2422 | | fields, |
2423 | | ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]), |
2424 | | Some(ScalarBuffer::from_iter(0_i32..5)), |
2425 | | vec![ints], |
2426 | | ) |
2427 | | .unwrap(); |
2428 | | |
2429 | | let indicies = Int64Array::from(vec![0, 2, 4]); |
2430 | | let array = take(&array, &indicies, None).unwrap(); |
2431 | | assert_eq!(array.len(), 3); |
2432 | | } |
2433 | | |
// Repeats index 0 `i32::MAX >> 4` times against a single 26-byte string, so the gathered
// byte length would exceed what an i32 offset can hold; `take` must report
// `ArrowError::OffsetOverflowError` instead of succeeding.
2434 | | #[test] |
2435 | | fn test_take_bytes_offset_overflow() { |
2436 | | let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]); |
2437 | | let text = ('a'..='z').collect::<String>(); |
2438 | | let values = StringArray::from(vec![Some(text.clone())]); |
2439 | | assert!(matches!( |
2440 | | take(&values, &indices, None), |
2441 | | Err(ArrowError::OffsetOverflowError(_)) |
2442 | | )); |
2443 | | } |
2444 | | } |