/Users/andrewlamb/Software/arrow-rs/arrow-select/src/take.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines take kernel for [Array] |
19 | | |
20 | | use std::sync::Arc; |
21 | | |
22 | | use arrow_array::builder::{BufferBuilder, UInt32Builder}; |
23 | | use arrow_array::cast::AsArray; |
24 | | use arrow_array::types::*; |
25 | | use arrow_array::*; |
26 | | use arrow_buffer::{ |
27 | | ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer, |
28 | | bit_util, |
29 | | }; |
30 | | use arrow_data::ArrayDataBuilder; |
31 | | use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode}; |
32 | | |
33 | | use num_traits::{One, Zero}; |
34 | | |
35 | | /// Take elements by index from [Array], creating a new [Array] from those indexes. |
36 | | /// |
37 | | /// ```text |
38 | | /// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ |
39 | | /// │ A │ │ 0 │ │ A │ |
40 | | /// ├─────────────────┤ ├─────────┤ ├─────────────────┤ |
41 | | /// │ D │ │ 2 │ │ B │ |
42 | | /// ├─────────────────┤ ├─────────┤ take(values, indices) ├─────────────────┤ |
43 | | /// │ B │ │ 3 │ ─────────────────────────▶ │ C │ |
44 | | /// ├─────────────────┤ ├─────────┤ ├─────────────────┤ |
45 | | /// │ C │ │ 1 │ │ D │ |
46 | | /// ├─────────────────┤ └─────────┘ └─────────────────┘ |
47 | | /// │ E │ |
48 | | /// └─────────────────┘ |
49 | | /// values array indices array result |
50 | | /// ``` |
51 | | /// |
52 | | /// For selecting values by index from multiple arrays see [`crate::interleave`] |
53 | | /// |
54 | | /// Note that this kernel, similar to other kernels in this crate, |
55 | | /// will avoid allocating where not necessary. Consequently |
56 | | /// the returned array may share buffers with the inputs |
57 | | /// |
58 | | /// # Errors |
59 | | /// This function errors whenever: |
60 | | /// * An index cannot be casted to `usize` (typically 32 bit architectures) |
61 | | /// * An index is out of bounds and `options` is set to check bounds. |
62 | | /// |
63 | | /// # Safety |
64 | | /// |
65 | | /// When `options` is not set to check bounds, taking indexes after `len` will panic. |
66 | | /// |
67 | | /// # See also |
68 | | /// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce |
69 | | /// the results into a single array. |
70 | | /// |
71 | | /// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer |
72 | | /// |
73 | | /// # Examples |
74 | | /// ``` |
75 | | /// # use arrow_array::{StringArray, UInt32Array, cast::AsArray}; |
76 | | /// # use arrow_select::take::take; |
77 | | /// let values = StringArray::from(vec!["zero", "one", "two"]); |
78 | | /// |
79 | | /// // Take items at index 2, and 1: |
80 | | /// let indices = UInt32Array::from(vec![2, 1]); |
81 | | /// let taken = take(&values, &indices, None).unwrap(); |
82 | | /// let taken = taken.as_string::<i32>(); |
83 | | /// |
84 | | /// assert_eq!(*taken, StringArray::from(vec!["two", "one"])); |
85 | | /// ``` |
86 | 80 | pub fn take( |
87 | 80 | values: &dyn Array, |
88 | 80 | indices: &dyn Array, |
89 | 80 | options: Option<TakeOptions>, |
90 | 80 | ) -> Result<ArrayRef, ArrowError> { |
91 | 80 | let options = options.unwrap_or_default(); |
92 | 80 | downcast_integer_array!( |
93 | | indices => { |
94 | 0 | if options.check_bounds { |
95 | 0 | check_bounds(values.len(), indices)?; |
96 | 0 | } |
97 | 0 | let indices = indices.to_indices(); |
98 | 0 | take_impl(values, &indices) |
99 | | }, |
100 | 0 | d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}"))) |
101 | | ) |
102 | 80 | } |
103 | | |
104 | | /// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new |
105 | | /// [`Vec<ArrayRef>`] from those indices. |
106 | | /// |
107 | | /// ```text |
108 | | /// ┌────────┬────────┐ |
109 | | /// │ │ │ ┌────────┐ ┌────────┬────────┐ |
110 | | /// │ A │ 1 │ │ │ │ │ │ |
111 | | /// ├────────┼────────┤ │ 0 │ │ A │ 1 │ |
112 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
113 | | /// │ D │ 4 │ │ │ │ │ │ |
114 | | /// ├────────┼────────┤ │ 2 │ take_arrays(values,indices) │ B │ 2 │ |
115 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
116 | | /// │ B │ 2 │ │ │ ───────────────────────────► │ │ │ |
117 | | /// ├────────┼────────┤ │ 3 │ │ C │ 3 │ |
118 | | /// │ │ │ ├────────┤ ├────────┼────────┤ |
119 | | /// │ C │ 3 │ │ │ │ │ │ |
120 | | /// ├────────┼────────┤ │ 1 │ │ D │ 4 │ |
121 | | /// │ │ │ └────────┘ └────────┼────────┘ |
122 | | /// │ E │ 5 │ |
123 | | /// └────────┴────────┘ |
124 | | /// values arrays indices array result |
125 | | /// ``` |
126 | | /// |
127 | | /// # Errors |
128 | | /// This function errors whenever: |
129 | | /// * An index cannot be casted to `usize` (typically 32 bit architectures) |
130 | | /// * An index is out of bounds and `options` is set to check bounds. |
131 | | /// |
132 | | /// # Safety |
133 | | /// |
134 | | /// When `options` is not set to check bounds, taking indexes after `len` will panic. |
135 | | /// |
136 | | /// # Examples |
137 | | /// ``` |
138 | | /// # use std::sync::Arc; |
139 | | /// # use arrow_array::{StringArray, UInt32Array, cast::AsArray}; |
140 | | /// # use arrow_select::take::{take, take_arrays}; |
141 | | /// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"])); |
142 | | /// let values = Arc::new(UInt32Array::from(vec![0, 1, 2])); |
143 | | /// |
144 | | /// // Take items at index 2, and 1: |
145 | | /// let indices = UInt32Array::from(vec![2, 1]); |
146 | | /// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap(); |
147 | | /// let taken_string = taken_arrays[0].as_string::<i32>(); |
148 | | /// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"])); |
149 | | /// let taken_values = taken_arrays[1].as_primitive(); |
150 | | /// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1])); |
151 | | /// ``` |
152 | 0 | pub fn take_arrays( |
153 | 0 | arrays: &[ArrayRef], |
154 | 0 | indices: &dyn Array, |
155 | 0 | options: Option<TakeOptions>, |
156 | 0 | ) -> Result<Vec<ArrayRef>, ArrowError> { |
157 | 0 | arrays |
158 | 0 | .iter() |
159 | 0 | .map(|array| take(array.as_ref(), indices, options.clone())) |
160 | 0 | .collect() |
161 | 0 | } |
162 | | |
163 | | /// Verifies that the non-null values of `indices` are all `< len` |
164 | 2 | fn check_bounds<T: ArrowPrimitiveType>( |
165 | 2 | len: usize, |
166 | 2 | indices: &PrimitiveArray<T>, |
167 | 2 | ) -> Result<(), ArrowError> { |
168 | 2 | if indices.null_count() > 0 { |
169 | 6 | indices2 .iter().flatten().try_for_each2 (|index| { |
170 | 6 | let ix = index |
171 | 6 | .to_usize() |
172 | 6 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed"0 .to_string0 ()))?0 ; |
173 | 6 | if ix >= len { |
174 | 2 | return Err(ArrowError::ComputeError(format!( |
175 | 2 | "Array index out of bounds, cannot get item at index {ix} from {len} entries" |
176 | 2 | ))); |
177 | 4 | } |
178 | 4 | Ok(()) |
179 | 6 | }) |
180 | | } else { |
181 | 0 | indices.values().iter().try_for_each(|index| { |
182 | 0 | let ix = index |
183 | 0 | .to_usize() |
184 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?; |
185 | 0 | if ix >= len { |
186 | 0 | return Err(ArrowError::ComputeError(format!( |
187 | 0 | "Array index out of bounds, cannot get item at index {ix} from {len} entries" |
188 | 0 | ))); |
189 | 0 | } |
190 | 0 | Ok(()) |
191 | 0 | }) |
192 | | } |
193 | 2 | } |
194 | | |
195 | | #[inline(never)] |
196 | 104 | fn take_impl<IndexType: ArrowPrimitiveType>( |
197 | 104 | values: &dyn Array, |
198 | 104 | indices: &PrimitiveArray<IndexType>, |
199 | 104 | ) -> Result<ArrayRef, ArrowError> { |
200 | 104 | downcast_primitive_array! { |
201 | 4 | values => Ok(Arc::new(take_primitive(values, indices)?0 )), |
202 | | DataType::Boolean => { |
203 | 7 | let values = values.as_any().downcast_ref::<BooleanArray>().unwrap(); |
204 | 7 | Ok(Arc::new(take_boolean(values, indices))) |
205 | | } |
206 | | DataType::Utf8 => { |
207 | 10 | Ok(Arc::new9 (take_bytes(values.as_string::<i32>(), indices)?1 )) |
208 | | } |
209 | | DataType::LargeUtf8 => { |
210 | 1 | Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?0 )) |
211 | | } |
212 | | DataType::Utf8View => { |
213 | 1 | Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?0 )) |
214 | | } |
215 | | DataType::List(_) => { |
216 | 4 | Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?0 )) |
217 | | } |
218 | | DataType::LargeList(_) => { |
219 | 3 | Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?0 )) |
220 | | } |
221 | | DataType::ListView(_) => { |
222 | 4 | Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?0 )) |
223 | | } |
224 | | DataType::LargeListView(_) => { |
225 | 4 | Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?0 )) |
226 | | } |
227 | 1 | DataType::FixedSizeList(_, length) => { |
228 | 1 | let values = values |
229 | 1 | .as_any() |
230 | 1 | .downcast_ref::<FixedSizeListArray>() |
231 | 1 | .unwrap(); |
232 | 1 | Ok(Arc::new(take_fixed_size_list( |
233 | 1 | values, |
234 | 1 | indices, |
235 | 1 | *length as u32, |
236 | 0 | )?)) |
237 | | } |
238 | | DataType::Map(_, _) => { |
239 | 1 | let list_arr = ListArray::from(values.as_map().clone()); |
240 | 1 | let list_data = take_list::<_, Int32Type>(&list_arr, indices)?0 ; |
241 | 1 | let builder = list_data.into_data().into_builder().data_type(values.data_type().clone()); |
242 | 1 | Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() }))) |
243 | | } |
244 | 5 | DataType::Struct(fields) => { |
245 | 5 | let array: &StructArray = values.as_struct(); |
246 | 5 | let arrays = array |
247 | 5 | .columns() |
248 | 5 | .iter() |
249 | 8 | .map5 (|a| take_impl(a.as_ref(), indices)) |
250 | 5 | .collect::<Result<Vec<ArrayRef>, _>>()?0 ; |
251 | 5 | let fields: Vec<(FieldRef, ArrayRef)> = |
252 | 5 | fields.iter().cloned().zip(arrays).collect(); |
253 | | |
254 | | // Create the null bit buffer. |
255 | 5 | let is_valid: Buffer = indices |
256 | 5 | .iter() |
257 | 25 | .map5 (|index| { |
258 | 25 | if let Some(index23 ) = index { |
259 | 23 | array.is_valid(index.to_usize().unwrap()) |
260 | | } else { |
261 | 2 | false |
262 | | } |
263 | 25 | }) |
264 | 5 | .collect(); |
265 | | |
266 | 5 | if fields.is_empty() { |
267 | 1 | let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len())); |
268 | 1 | Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls)))) |
269 | | } else { |
270 | 4 | Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) |
271 | | } |
272 | | } |
273 | 1 | DataType::Dictionary(_, _) => downcast_dictionary_array! { |
274 | 0 | values => Ok(Arc::new(take_dict(values, indices)?)), |
275 | 0 | t => unimplemented!("Take not supported for dictionary type {:?}", t) |
276 | | } |
277 | 0 | DataType::RunEndEncoded(_, _) => downcast_run_array! { |
278 | 0 | values => Ok(Arc::new(take_run(values, indices)?)), |
279 | 0 | t => unimplemented!("Take not supported for run type {:?}", t) |
280 | | } |
281 | | DataType::Binary => { |
282 | 0 | Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?)) |
283 | | } |
284 | | DataType::LargeBinary => { |
285 | 0 | Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?)) |
286 | | } |
287 | | DataType::BinaryView => { |
288 | 1 | Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?0 )) |
289 | | } |
290 | 0 | DataType::FixedSizeBinary(size) => { |
291 | 0 | let values = values |
292 | 0 | .as_any() |
293 | 0 | .downcast_ref::<FixedSizeBinaryArray>() |
294 | 0 | .unwrap(); |
295 | 0 | Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?)) |
296 | | } |
297 | | DataType::Null => { |
298 | | // Take applied to a null array produces a null array. |
299 | 2 | if values.len() >= indices.len() { |
300 | | // If the existing null array is as big as the indices, we can use a slice of it |
301 | | // to avoid allocating a new null array. |
302 | 1 | Ok(values.slice(0, indices.len())) |
303 | | } else { |
304 | | // If the existing null array isn't big enough, create a new one. |
305 | 1 | Ok(new_null_array(&DataType::Null, indices.len())) |
306 | | } |
307 | | } |
308 | 1 | DataType::Union(fields, UnionMode::Sparse) => { |
309 | 1 | let mut children = Vec::with_capacity(fields.len()); |
310 | 1 | let values = values.as_any().downcast_ref::<UnionArray>().unwrap(); |
311 | 1 | let type_ids = take_native(values.type_ids(), indices); |
312 | 2 | for (type_id, _field) in fields1 .iter1 () { |
313 | 2 | let values = values.child(type_id); |
314 | 2 | let values = take_impl(values, indices)?0 ; |
315 | 2 | children.push(values); |
316 | | } |
317 | 1 | let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?0 ; |
318 | 1 | Ok(Arc::new(array)) |
319 | | } |
320 | 3 | DataType::Union(fields, UnionMode::Dense) => { |
321 | 3 | let values = values.as_any().downcast_ref::<UnionArray>().unwrap(); |
322 | | |
323 | 3 | let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?0 ; |
324 | 3 | let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?0 ; |
325 | | |
326 | 3 | let children = fields.iter() |
327 | 5 | .map3 (|(field_type_id, _)| { |
328 | 23 | let mask5 = BooleanArray::from_unary5 (&type_ids5 , |value_type_id| value_type_id == field_type_id); |
329 | | |
330 | 5 | let indices = crate::filter::filter(&offsets, &mask)?0 ; |
331 | | |
332 | 5 | let values = values.child(field_type_id); |
333 | | |
334 | 5 | take_impl(values, indices.as_primitive::<Int32Type>()) |
335 | 5 | }) |
336 | 3 | .collect::<Result<_, _>>()?0 ; |
337 | | |
338 | 3 | let mut child_offsets = [0; 128]; |
339 | | |
340 | 3 | let offsets = type_ids.values() |
341 | 3 | .iter() |
342 | 13 | .map3 (|&i| { |
343 | 13 | let offset = child_offsets[i as usize]; |
344 | | |
345 | 13 | child_offsets[i as usize] += 1; |
346 | | |
347 | 13 | offset |
348 | 13 | }) |
349 | 3 | .collect(); |
350 | | |
351 | 3 | let (_, type_ids, _) = type_ids.into_parts(); |
352 | | |
353 | 3 | let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?0 ; |
354 | | |
355 | 3 | Ok(Arc::new(array)) |
356 | | } |
357 | 0 | t => unimplemented!("Take not supported for data type {:?}", t) |
358 | | } |
359 | 104 | } |
360 | | |
361 | | /// Options that define how `take` should behave |
362 | | #[derive(Clone, Debug, Default)] |
363 | | pub struct TakeOptions { |
364 | | /// Perform bounds check before taking indices from values. |
365 | | /// If enabled, an `ArrowError` is returned if the indices are out of bounds. |
366 | | /// If not enabled, and indices exceed bounds, the kernel will panic. |
367 | | pub check_bounds: bool, |
368 | | } |
369 | | |
370 | | #[inline(always)] |
371 | 0 | fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> { |
372 | 0 | index |
373 | 0 | .to_usize() |
374 | 0 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string())) |
375 | 0 | } |
376 | | |
377 | | /// `take` implementation for all primitive arrays |
378 | | /// |
379 | | /// This checks if an `indices` slot is populated, and gets the value from `values` |
380 | | /// as the populated index. |
381 | | /// If the `indices` slot is null, a null value is returned. |
382 | | /// For example, given: |
383 | | /// values: [1, 2, 3, null, 5] |
384 | | /// indices: [0, null, 4, 3] |
385 | | /// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)] |
386 | 56 | fn take_primitive<T, I>( |
387 | 56 | values: &PrimitiveArray<T>, |
388 | 56 | indices: &PrimitiveArray<I>, |
389 | 56 | ) -> Result<PrimitiveArray<T>, ArrowError> |
390 | 56 | where |
391 | 56 | T: ArrowPrimitiveType, |
392 | 56 | I: ArrowPrimitiveType, |
393 | | { |
394 | 56 | let values_buf = take_native(values.values(), indices); |
395 | 56 | let nulls = take_nulls(values.nulls(), indices); |
396 | 56 | Ok(PrimitiveArray::try_new(values_buf, nulls)?0 .with_data_type(values.data_type().clone())) |
397 | 56 | } |
398 | | |
399 | | #[inline(never)] |
400 | 83 | fn take_nulls<I: ArrowPrimitiveType>( |
401 | 83 | values: Option<&NullBuffer>, |
402 | 83 | indices: &PrimitiveArray<I>, |
403 | 83 | ) -> Option<NullBuffer> { |
404 | 83 | match values.filter(|n| n60 .null_count60 () > 0) { |
405 | 60 | Some(n) => { |
406 | 60 | let buffer = take_bits(n.inner(), indices); |
407 | 60 | Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0) |
408 | | } |
409 | 23 | None => indices.nulls().cloned(), |
410 | | } |
411 | 83 | } |
412 | | |
413 | | #[inline(never)] |
414 | 81 | fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>( |
415 | 81 | values: &[T], |
416 | 81 | indices: &PrimitiveArray<I>, |
417 | 81 | ) -> ScalarBuffer<T> { |
418 | 81 | match indices.nulls().filter(|n| n49 .null_count49 () > 0) { |
419 | 49 | Some(n) => indices |
420 | 49 | .values() |
421 | 49 | .iter() |
422 | 49 | .enumerate() |
423 | 235 | .map49 (|(idx, index)| match values.get(index.as_usize()) { |
424 | 232 | Some(v) => *v, |
425 | 3 | None => match n.is_null(idx) { |
426 | 3 | true => T::default(), |
427 | 0 | false => panic!("Out-of-bounds index {index:?}"), |
428 | | }, |
429 | 235 | }) |
430 | 49 | .collect(), |
431 | 32 | None => indices |
432 | 32 | .values() |
433 | 32 | .iter() |
434 | 171 | .map32 (|index| values[index.as_usize()]) |
435 | 32 | .collect(), |
436 | | } |
437 | 81 | } |
438 | | |
439 | | #[inline(never)] |
440 | 67 | fn take_bits<I: ArrowPrimitiveType>( |
441 | 67 | values: &BooleanBuffer, |
442 | 67 | indices: &PrimitiveArray<I>, |
443 | 67 | ) -> BooleanBuffer { |
444 | 67 | let len = indices.len(); |
445 | | |
446 | 67 | match indices.nulls().filter(|n| n49 .null_count49 () > 0) { |
447 | 49 | Some(nulls) => { |
448 | 49 | let mut output_buffer = MutableBuffer::new_null(len); |
449 | 49 | let output_slice = output_buffer.as_slice_mut(); |
450 | 182 | nulls49 .valid_indices49 ().for_each49 (|idx| { |
451 | 182 | if values.value(indices.value(idx).as_usize()) { |
452 | 125 | bit_util::set_bit(output_slice, idx); |
453 | 125 | }57 |
454 | 182 | }); |
455 | 49 | BooleanBuffer::new(output_buffer.into(), 0, len) |
456 | | } |
457 | | None => { |
458 | 120 | BooleanBuffer::collect_bool18 (len18 , |idx: usize| { |
459 | | // SAFETY: idx<indices.len() |
460 | 120 | values.value(unsafe { indices.value_unchecked(idx).as_usize() }) |
461 | 120 | }) |
462 | | } |
463 | | } |
464 | 67 | } |
465 | | |
466 | | /// `take` implementation for boolean arrays |
467 | 7 | fn take_boolean<IndexType: ArrowPrimitiveType>( |
468 | 7 | values: &BooleanArray, |
469 | 7 | indices: &PrimitiveArray<IndexType>, |
470 | 7 | ) -> BooleanArray { |
471 | 7 | let val_buf = take_bits(values.values(), indices); |
472 | 7 | let null_buf = take_nulls(values.nulls(), indices); |
473 | 7 | BooleanArray::new(val_buf, null_buf) |
474 | 7 | } |
475 | | |
476 | | /// `take` implementation for string arrays |
477 | 11 | fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>( |
478 | 11 | array: &GenericByteArray<T>, |
479 | 11 | indices: &PrimitiveArray<IndexType>, |
480 | 11 | ) -> Result<GenericByteArray<T>, ArrowError> { |
481 | 11 | let mut offsets = Vec::with_capacity(indices.len() + 1); |
482 | 11 | offsets.push(T::Offset::default()); |
483 | | |
484 | 11 | let input_offsets = array.value_offsets(); |
485 | 11 | let mut capacity = 0; |
486 | 11 | let nulls = take_nulls(array.nulls(), indices); |
487 | | |
488 | 11 | let (offsets10 , values10 ) = if array.null_count() == 0 && indices.null_count() == 05 { |
489 | 5 | offsets.reserve(indices.len()); |
490 | 82.5M | for index in indices5 .values5 () { |
491 | 82.5M | let index = index.as_usize(); |
492 | 82.5M | capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); |
493 | 82.5M | offsets.push82.5M ( |
494 | 82.5M | T::Offset::from_usize(capacity) |
495 | 82.5M | .ok_or_else(|| ArrowError::OffsetOverflowError(capacity1 ))?1 , |
496 | | ); |
497 | | } |
498 | 4 | let mut values = Vec::with_capacity(capacity); |
499 | | |
500 | 9 | for index in indices4 .values4 () { |
501 | 9 | values.extend_from_slice(array.value(index.as_usize()).as_ref()); |
502 | 9 | } |
503 | 4 | (offsets, values) |
504 | 6 | } else if indices.null_count() == 0 { |
505 | 2 | offsets.reserve(indices.len()); |
506 | 8 | for index in indices2 .values2 () { |
507 | 8 | let index = index.as_usize(); |
508 | 8 | if array.is_valid(index) { |
509 | 5 | capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); |
510 | 5 | }3 |
511 | 8 | offsets.push( |
512 | 8 | T::Offset::from_usize(capacity) |
513 | 8 | .ok_or_else(|| ArrowError::OffsetOverflowError(capacity0 ))?0 , |
514 | | ); |
515 | | } |
516 | 2 | let mut values = Vec::with_capacity(capacity); |
517 | | |
518 | 8 | for index in indices2 .values2 () { |
519 | 8 | let index = index.as_usize(); |
520 | 8 | if array.is_valid(index) { |
521 | 5 | values.extend_from_slice(array.value(index).as_ref()); |
522 | 5 | }3 |
523 | | } |
524 | 2 | (offsets, values) |
525 | 4 | } else if array.null_count() == 0 { |
526 | 0 | offsets.reserve(indices.len()); |
527 | 0 | for (i, index) in indices.values().iter().enumerate() { |
528 | 0 | let index = index.as_usize(); |
529 | 0 | if indices.is_valid(i) { |
530 | 0 | capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); |
531 | 0 | } |
532 | 0 | offsets.push( |
533 | 0 | T::Offset::from_usize(capacity) |
534 | 0 | .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?, |
535 | | ); |
536 | | } |
537 | 0 | let mut values = Vec::with_capacity(capacity); |
538 | | |
539 | 0 | for (i, index) in indices.values().iter().enumerate() { |
540 | 0 | if indices.is_valid(i) { |
541 | 0 | values.extend_from_slice(array.value(index.as_usize()).as_ref()); |
542 | 0 | } |
543 | | } |
544 | 0 | (offsets, values) |
545 | | } else { |
546 | 4 | let nulls = nulls.as_ref().unwrap(); |
547 | 4 | offsets.reserve(indices.len()); |
548 | 18 | for (i, index) in indices.values().iter()4 .enumerate4 () { |
549 | 18 | let index = index.as_usize(); |
550 | 18 | if nulls.is_valid(i) { |
551 | 9 | capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); |
552 | 9 | } |
553 | 18 | offsets.push( |
554 | 18 | T::Offset::from_usize(capacity) |
555 | 18 | .ok_or_else(|| ArrowError::OffsetOverflowError(capacity0 ))?0 , |
556 | | ); |
557 | | } |
558 | 4 | let mut values = Vec::with_capacity(capacity); |
559 | | |
560 | 18 | for (i, index) in indices.values().iter()4 .enumerate4 () { |
561 | | // check index is valid before using index. The value in |
562 | | // NULL index slots may not be within bounds of array |
563 | 18 | let index = index.as_usize(); |
564 | 18 | if nulls.is_valid(i) { |
565 | 9 | values.extend_from_slice(array.value(index).as_ref()); |
566 | 9 | } |
567 | | } |
568 | 4 | (offsets, values) |
569 | | }; |
570 | | |
571 | 10 | T::Offset::from_usize(values.len()) |
572 | 10 | .ok_or_else(|| ArrowError::OffsetOverflowError(values0 .len0 ()))?0 ; |
573 | | |
574 | 10 | let array = unsafe { |
575 | 10 | let offsets = OffsetBuffer::new_unchecked(offsets.into()); |
576 | 10 | GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls) |
577 | | }; |
578 | | |
579 | 10 | Ok(array) |
580 | 11 | } |
581 | | |
582 | | /// `take` implementation for byte view arrays |
583 | 2 | fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>( |
584 | 2 | array: &GenericByteViewArray<T>, |
585 | 2 | indices: &PrimitiveArray<IndexType>, |
586 | 2 | ) -> Result<GenericByteViewArray<T>, ArrowError> { |
587 | 2 | let new_views = take_native(array.views(), indices); |
588 | 2 | let new_nulls = take_nulls(array.nulls(), indices); |
589 | | // Safety: array.views was valid, and take_native copies only valid values, and verifies bounds |
590 | 2 | Ok(unsafe { |
591 | 2 | GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls) |
592 | 2 | }) |
593 | 2 | } |
594 | | |
595 | | /// `take` implementation for list arrays |
596 | | /// |
597 | | /// Calculates the index and indexed offset for the inner array, |
598 | | /// applying `take` on the inner array, then reconstructing a list array |
599 | | /// with the indexed offsets |
600 | 8 | fn take_list<IndexType, OffsetType>( |
601 | 8 | values: &GenericListArray<OffsetType::Native>, |
602 | 8 | indices: &PrimitiveArray<IndexType>, |
603 | 8 | ) -> Result<GenericListArray<OffsetType::Native>, ArrowError> |
604 | 8 | where |
605 | 8 | IndexType: ArrowPrimitiveType, |
606 | 8 | OffsetType: ArrowPrimitiveType, |
607 | 8 | OffsetType::Native: OffsetSizeTrait, |
608 | 8 | PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>, |
609 | | { |
610 | | // TODO: Some optimizations can be done here such as if it is |
611 | | // taking the whole list or a contiguous sublist |
612 | 8 | let (list_indices, offsets, null_buf) = |
613 | 8 | take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?0 ; |
614 | | |
615 | 8 | let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?0 ; |
616 | 8 | let value_offsets = Buffer::from_vec(offsets); |
617 | | // create a new list with taken data and computed null information |
618 | 8 | let list_data = ArrayDataBuilder::new(values.data_type().clone()) |
619 | 8 | .len(indices.len()) |
620 | 8 | .null_bit_buffer(Some(null_buf.into())) |
621 | 8 | .offset(0) |
622 | 8 | .add_child_data(taken.into_data()) |
623 | 8 | .add_buffer(value_offsets); |
624 | | |
625 | 8 | let list_data = unsafe { list_data.build_unchecked() }; |
626 | | |
627 | 8 | Ok(GenericListArray::<OffsetType::Native>::from(list_data)) |
628 | 8 | } |
629 | | |
630 | 8 | fn take_list_view<IndexType, OffsetType>( |
631 | 8 | values: &GenericListViewArray<OffsetType::Native>, |
632 | 8 | indices: &PrimitiveArray<IndexType>, |
633 | 8 | ) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError> |
634 | 8 | where |
635 | 8 | IndexType: ArrowPrimitiveType, |
636 | 8 | OffsetType: ArrowPrimitiveType, |
637 | 8 | OffsetType::Native: OffsetSizeTrait, |
638 | | { |
639 | 8 | let taken_offsets = take_native(values.offsets(), indices); |
640 | 8 | let taken_sizes = take_native(values.sizes(), indices); |
641 | 8 | let nulls = take_nulls(values.nulls(), indices); |
642 | | |
643 | 8 | let list_view_data = ArrayDataBuilder::new(values.data_type().clone()) |
644 | 8 | .len(indices.len()) |
645 | 8 | .nulls(nulls) |
646 | 8 | .buffers(vec![taken_offsets.into(), taken_sizes.into()]) |
647 | 8 | .child_data(vec![values.values().to_data()]); |
648 | | |
649 | | // SAFETY: all buffers and child nodes for ListView added in constructor |
650 | 8 | let list_view_data = unsafe { list_view_data.build_unchecked() }; |
651 | | |
652 | 8 | Ok(GenericListViewArray::<OffsetType::Native>::from( |
653 | 8 | list_view_data, |
654 | 8 | )) |
655 | 8 | } |
656 | | |
657 | | /// `take` implementation for `FixedSizeListArray` |
658 | | /// |
659 | | /// Calculates the index and indexed offset for the inner array, |
660 | | /// applying `take` on the inner array, then reconstructing a list array |
661 | | /// with the indexed offsets |
662 | 4 | fn take_fixed_size_list<IndexType: ArrowPrimitiveType>( |
663 | 4 | values: &FixedSizeListArray, |
664 | 4 | indices: &PrimitiveArray<IndexType>, |
665 | 4 | length: <UInt32Type as ArrowPrimitiveType>::Native, |
666 | 4 | ) -> Result<FixedSizeListArray, ArrowError> { |
667 | 4 | let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?0 ; |
668 | 4 | let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?0 ; |
669 | | |
670 | | // determine null count and null buffer, which are a function of `values` and `indices` |
671 | 4 | let num_bytes = bit_util::ceil(indices.len(), 8); |
672 | 4 | let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); |
673 | 4 | let null_slice = null_buf.as_slice_mut(); |
674 | | |
675 | 13 | for i in 0..indices4 .len4 () { |
676 | 13 | let index = indices |
677 | 13 | .value(i) |
678 | 13 | .to_usize() |
679 | 13 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed"0 .to_string0 ()))?0 ; |
680 | 13 | if !indices.is_valid(i) || values12 .is_null12 (index12 ) { |
681 | 3 | bit_util::unset_bit(null_slice, i); |
682 | 10 | } |
683 | | } |
684 | | |
685 | 4 | let list_data = ArrayDataBuilder::new(values.data_type().clone()) |
686 | 4 | .len(indices.len()) |
687 | 4 | .null_bit_buffer(Some(null_buf.into())) |
688 | 4 | .offset(0) |
689 | 4 | .add_child_data(taken.into_data()); |
690 | | |
691 | 4 | let list_data = unsafe { list_data.build_unchecked() }; |
692 | | |
693 | 4 | Ok(FixedSizeListArray::from(list_data)) |
694 | 4 | } |
695 | | |
696 | 0 | fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>( |
697 | 0 | values: &FixedSizeBinaryArray, |
698 | 0 | indices: &PrimitiveArray<IndexType>, |
699 | 0 | size: i32, |
700 | 0 | ) -> Result<FixedSizeBinaryArray, ArrowError> { |
701 | 0 | let nulls = values.nulls(); |
702 | 0 | let array_iter = indices |
703 | 0 | .values() |
704 | 0 | .iter() |
705 | 0 | .map(|idx| { |
706 | 0 | let idx = maybe_usize::<IndexType::Native>(*idx)?; |
707 | 0 | if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) { |
708 | 0 | Ok(Some(values.value(idx))) |
709 | | } else { |
710 | 0 | Ok(None) |
711 | | } |
712 | 0 | }) |
713 | 0 | .collect::<Result<Vec<_>, ArrowError>>()? |
714 | 0 | .into_iter(); |
715 | | |
716 | 0 | FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size) |
717 | 0 | } |
718 | | |
719 | | /// `take` implementation for dictionary arrays |
720 | | /// |
721 | | /// applies `take` to the keys of the dictionary array and returns a new dictionary array |
722 | | /// with the same dictionary values and reordered keys |
723 | 1 | fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>( |
724 | 1 | values: &DictionaryArray<T>, |
725 | 1 | indices: &PrimitiveArray<I>, |
726 | 1 | ) -> Result<DictionaryArray<T>, ArrowError> { |
727 | 1 | let new_keys = take_primitive(values.keys(), indices)?0 ; |
728 | 1 | Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) }) |
729 | 1 | } |
730 | | |
731 | | /// `take` implementation for run arrays |
732 | | /// |
733 | | /// Finds physical indices for the given logical indices and builds output run array |
734 | | /// by taking values in the input run_array.values at the physical indices. |
735 | | /// The output run array will be run encoded on the physical indices and not on output values. |
736 | | /// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]` |
737 | | /// would be converted to `physical_indices=[1,1,3,3]` which will be used to build |
738 | | /// output `RunArray{ run_ends=[2,4], values=[2,2] }`. |
739 | 1 | fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>( |
740 | 1 | run_array: &RunArray<T>, |
741 | 1 | logical_indices: &PrimitiveArray<I>, |
742 | 1 | ) -> Result<RunArray<T>, ArrowError> { |
743 | | // get physical indices for the input logical indices |
744 | 1 | let physical_indices = run_array.get_physical_indices(logical_indices.values())?0 ; |
745 | | |
746 | | // Run encode the physical indices into new_run_ends_builder |
747 | | // Keep track of the physical indices to take in take_value_indices |
748 | | // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`. |
749 | 1 | let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1); |
750 | 1 | let mut take_value_indices = BufferBuilder::<I::Native>::new(1); |
751 | 1 | let mut new_physical_len = 1; |
752 | 6 | for ix in 1..physical_indices1 .len1 () { |
753 | 6 | if physical_indices[ix] != physical_indices[ix - 1] { |
754 | 4 | take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap()); |
755 | 4 | new_run_ends_builder.append(T::Native::from_usize(ix).unwrap()); |
756 | 4 | new_physical_len += 1; |
757 | 4 | }2 |
758 | | } |
759 | 1 | take_value_indices |
760 | 1 | .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap()); |
761 | 1 | new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap()); |
762 | 1 | let new_run_ends = unsafe { |
763 | | // Safety: |
764 | | // The function builds a valid run_ends array and hence need not be validated. |
765 | 1 | ArrayDataBuilder::new(T::DATA_TYPE) |
766 | 1 | .len(new_physical_len) |
767 | 1 | .null_count(0) |
768 | 1 | .add_buffer(new_run_ends_builder.finish()) |
769 | 1 | .build_unchecked() |
770 | | }; |
771 | | |
772 | 1 | let take_value_indices: PrimitiveArray<I> = unsafe { |
773 | | // Safety: |
774 | | // The function builds a valid take_value_indices array and hence need not be validated. |
775 | 1 | ArrayDataBuilder::new(I::DATA_TYPE) |
776 | 1 | .len(new_physical_len) |
777 | 1 | .null_count(0) |
778 | 1 | .add_buffer(take_value_indices.finish()) |
779 | 1 | .build_unchecked() |
780 | 1 | .into() |
781 | | }; |
782 | | |
783 | 1 | let new_values = take(run_array.values(), &take_value_indices, None)?0 ; |
784 | | |
785 | 1 | let builder = ArrayDataBuilder::new(run_array.data_type().clone()) |
786 | 1 | .len(physical_indices.len()) |
787 | 1 | .add_child_data(new_run_ends) |
788 | 1 | .add_child_data(new_values.into_data()); |
789 | 1 | let array_data = unsafe { |
790 | | // Safety: |
791 | | // This function builds a valid run array and hence can skip validation. |
792 | 1 | builder.build_unchecked() |
793 | | }; |
794 | 1 | Ok(array_data.into()) |
795 | 1 | } |
796 | | |
797 | | /// Takes/filters a list array's inner data using the offsets of the list array. |
798 | | /// |
799 | | /// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns |
800 | | /// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2 |
801 | | /// elements) |
802 | | #[allow(clippy::type_complexity)] |
803 | 10 | fn take_value_indices_from_list<IndexType, OffsetType>( |
804 | 10 | list: &GenericListArray<OffsetType::Native>, |
805 | 10 | indices: &PrimitiveArray<IndexType>, |
806 | 10 | ) -> Result< |
807 | 10 | ( |
808 | 10 | PrimitiveArray<OffsetType>, |
809 | 10 | Vec<OffsetType::Native>, |
810 | 10 | MutableBuffer, |
811 | 10 | ), |
812 | 10 | ArrowError, |
813 | 10 | > |
814 | 10 | where |
815 | 10 | IndexType: ArrowPrimitiveType, |
816 | 10 | OffsetType: ArrowPrimitiveType, |
817 | 10 | OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One, |
818 | 10 | PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>, |
819 | | { |
820 | | // TODO: benchmark this function, there might be a faster unsafe alternative |
821 | 10 | let offsets: &[OffsetType::Native] = list.value_offsets(); |
822 | | |
823 | 10 | let mut new_offsets = Vec::with_capacity(indices.len()); |
824 | 10 | let mut values = Vec::new(); |
825 | 10 | let mut current_offset = OffsetType::Native::zero(); |
826 | | // add first offset |
827 | 10 | new_offsets.push(OffsetType::Native::zero()); |
828 | | |
829 | | // Initialize null buffer |
830 | 10 | let num_bytes = bit_util::ceil(indices.len(), 8); |
831 | 10 | let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); |
832 | 10 | let null_slice = null_buf.as_slice_mut(); |
833 | | |
834 | | // compute the value indices, and set offsets accordingly |
835 | 36 | for i in 0..indices10 .len10 () { |
836 | 36 | if indices.is_valid(i) { |
837 | 30 | let ix29 = indices |
838 | 30 | .value(i) |
839 | 30 | .to_usize() |
840 | 30 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed"0 .to_string0 ()))?1 ; |
841 | 29 | let start = offsets[ix]; |
842 | 29 | let end = offsets[ix + 1]; |
843 | 29 | current_offset += end - start; |
844 | 29 | new_offsets.push(current_offset); |
845 | | |
846 | 29 | let mut curr = start; |
847 | | |
848 | | // if start == end, this slot is empty |
849 | 96 | while curr < end { |
850 | 67 | values.push(curr); |
851 | 67 | curr += One::one(); |
852 | 67 | } |
853 | 29 | if !list.is_valid(ix) { |
854 | 2 | bit_util::unset_bit(null_slice, i); |
855 | 27 | } |
856 | 6 | } else { |
857 | 6 | bit_util::unset_bit(null_slice, i); |
858 | 6 | new_offsets.push(current_offset); |
859 | 6 | } |
860 | | } |
861 | | |
862 | 9 | Ok(( |
863 | 9 | PrimitiveArray::<OffsetType>::from(values), |
864 | 9 | new_offsets, |
865 | 9 | null_buf, |
866 | 9 | )) |
867 | 10 | } |
868 | | |
869 | | /// Takes/filters a fixed size list array's inner data using the offsets of the list array. |
870 | 6 | fn take_value_indices_from_fixed_size_list<IndexType>( |
871 | 6 | list: &FixedSizeListArray, |
872 | 6 | indices: &PrimitiveArray<IndexType>, |
873 | 6 | length: <UInt32Type as ArrowPrimitiveType>::Native, |
874 | 6 | ) -> Result<PrimitiveArray<UInt32Type>, ArrowError> |
875 | 6 | where |
876 | 6 | IndexType: ArrowPrimitiveType, |
877 | | { |
878 | 6 | let mut values = UInt32Builder::with_capacity(length as usize * indices.len()); |
879 | | |
880 | 21 | for i in 0..indices6 .len6 () { |
881 | 21 | if indices.is_valid(i) { |
882 | 20 | let index = indices |
883 | 20 | .value(i) |
884 | 20 | .to_usize() |
885 | 20 | .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed"0 .to_string0 ()))?0 ; |
886 | 20 | let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native; |
887 | | |
888 | | // Safety: Range always has known length. |
889 | 20 | unsafe { |
890 | 20 | values.append_trusted_len_iter(start..start + length); |
891 | 20 | } |
892 | 1 | } else { |
893 | 1 | values.append_nulls(length as usize); |
894 | 1 | } |
895 | | } |
896 | | |
897 | 6 | Ok(values.finish()) |
898 | 6 | } |
899 | | |
900 | | /// To avoid generating take implementations for every index type, instead we |
901 | | /// only generate for UInt32 and UInt64 and coerce inputs to these types |
902 | | trait ToIndices { |
903 | | type T: ArrowPrimitiveType; |
904 | | |
905 | | fn to_indices(&self) -> PrimitiveArray<Self::T>; |
906 | | } |
907 | | |
// Implements `ToIndices` for `PrimitiveArray<$t>` by wrapping the array's existing
// values buffer in a `ScalarBuffer` of `$o` natives (no per-element conversion);
// the null buffer is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let length = self.len();
                let raw = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(raw, 0, length);
                let validity = self.nulls().cloned();
                PrimitiveArray::new(reinterpreted, validity)
            }
        }
    };
}
920 | | |
// Implements `ToIndices` for `PrimitiveArray<$t>` where the array already has the
// target index type: the result is simply a clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
932 | | |
// Implements `ToIndices` for `PrimitiveArray<$t>` by converting every value to the
// native type of `$o` with an `as` cast; the null buffer is carried over unchanged.
//
// The associated type follows `$o`, so it always agrees with the type returned by
// `to_indices`, whichever output type the macro is invoked with.
//
// Note: `as` follows Rust's integer cast rules, so a negative value of a signed
// `$t` becomes a very large unsigned value when `$o` is unsigned.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
945 | | |
946 | | to_indices_widening!(UInt8Type, UInt32Type); |
947 | | to_indices_widening!(Int8Type, UInt32Type); |
948 | | |
949 | | to_indices_widening!(UInt16Type, UInt32Type); |
950 | | to_indices_widening!(Int16Type, UInt32Type); |
951 | | |
952 | | to_indices_identity!(UInt32Type); |
953 | | to_indices_reinterpret!(Int32Type, UInt32Type); |
954 | | |
955 | | to_indices_identity!(UInt64Type); |
956 | | to_indices_reinterpret!(Int64Type, UInt64Type); |
957 | | |
958 | | /// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes. |
959 | | /// |
960 | | /// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`]. |
961 | | /// |
962 | | /// # Example |
963 | | /// ``` |
964 | | /// # use std::sync::Arc; |
965 | | /// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch}; |
966 | | /// # use arrow_schema::{DataType, Field, Schema}; |
967 | | /// # use arrow_select::take::take_record_batch; |
968 | | /// |
969 | | /// let schema = Arc::new(Schema::new(vec![ |
970 | | /// Field::new("a", DataType::Int32, true), |
971 | | /// Field::new("b", DataType::Utf8, true), |
972 | | /// ])); |
973 | | /// let batch = RecordBatch::try_new( |
974 | | /// schema.clone(), |
975 | | /// vec![ |
976 | | /// Arc::new(Int32Array::from_iter_values(0..20)), |
977 | | /// Arc::new(StringArray::from_iter_values( |
978 | | /// (0..20).map(|i| format!("str-{}", i)), |
979 | | /// )), |
980 | | /// ], |
981 | | /// ) |
982 | | /// .unwrap(); |
983 | | /// |
984 | | /// let indices = UInt32Array::from(vec![1, 5, 10]); |
985 | | /// let taken = take_record_batch(&batch, &indices).unwrap(); |
986 | | /// |
987 | | /// let expected = RecordBatch::try_new( |
988 | | /// schema, |
989 | | /// vec![ |
990 | | /// Arc::new(Int32Array::from(vec![1, 5, 10])), |
991 | | /// Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])), |
992 | | /// ], |
993 | | /// ) |
994 | | /// .unwrap(); |
995 | | /// assert_eq!(taken, expected); |
996 | | /// ``` |
997 | 0 | pub fn take_record_batch( |
998 | 0 | record_batch: &RecordBatch, |
999 | 0 | indices: &dyn Array, |
1000 | 0 | ) -> Result<RecordBatch, ArrowError> { |
1001 | 0 | let columns = record_batch |
1002 | 0 | .columns() |
1003 | 0 | .iter() |
1004 | 0 | .map(|c| take(c, indices, None)) |
1005 | 0 | .collect::<Result<Vec<_>, _>>()?; |
1006 | 0 | RecordBatch::try_new(record_batch.schema(), columns) |
1007 | 0 | } |
1008 | | |
1009 | | #[cfg(test)] |
1010 | | mod tests { |
1011 | | use super::*; |
1012 | | use arrow_array::builder::*; |
1013 | | use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; |
1014 | | use arrow_data::ArrayData; |
1015 | | use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; |
1016 | | use num_traits::ToPrimitive; |
1017 | | |
1018 | 2 | fn test_take_decimal_arrays( |
1019 | 2 | data: Vec<Option<i128>>, |
1020 | 2 | index: &UInt32Array, |
1021 | 2 | options: Option<TakeOptions>, |
1022 | 2 | expected_data: Vec<Option<i128>>, |
1023 | 2 | precision: &u8, |
1024 | 2 | scale: &i8, |
1025 | 2 | ) -> Result<(), ArrowError> { |
1026 | 2 | let output = data |
1027 | 2 | .into_iter() |
1028 | 2 | .collect::<Decimal128Array>() |
1029 | 2 | .with_precision_and_scale(*precision, *scale) |
1030 | 2 | .unwrap(); |
1031 | | |
1032 | 2 | let expected = expected_data |
1033 | 2 | .into_iter() |
1034 | 2 | .collect::<Decimal128Array>() |
1035 | 2 | .with_precision_and_scale(*precision, *scale) |
1036 | 2 | .unwrap(); |
1037 | | |
1038 | 2 | let expected = Arc::new(expected) as ArrayRef; |
1039 | 2 | let output = take(&output, index, options).unwrap(); |
1040 | 2 | assert_eq!(&output, &expected); |
1041 | 2 | Ok(()) |
1042 | 2 | } |
1043 | | |
1044 | 4 | fn test_take_boolean_arrays( |
1045 | 4 | data: Vec<Option<bool>>, |
1046 | 4 | index: &UInt32Array, |
1047 | 4 | options: Option<TakeOptions>, |
1048 | 4 | expected_data: Vec<Option<bool>>, |
1049 | 4 | ) { |
1050 | 4 | let output = BooleanArray::from(data); |
1051 | 4 | let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef; |
1052 | 4 | let output = take(&output, index, options).unwrap(); |
1053 | 4 | assert_eq!(&output, &expected) |
1054 | 4 | } |
1055 | | |
1056 | 23 | fn test_take_primitive_arrays<T>( |
1057 | 23 | data: Vec<Option<T::Native>>, |
1058 | 23 | index: &UInt32Array, |
1059 | 23 | options: Option<TakeOptions>, |
1060 | 23 | expected_data: Vec<Option<T::Native>>, |
1061 | 23 | ) -> Result<(), ArrowError> |
1062 | 23 | where |
1063 | 23 | T: ArrowPrimitiveType, |
1064 | 23 | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1065 | | { |
1066 | 23 | let output = PrimitiveArray::<T>::from(data); |
1067 | 23 | let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef; |
1068 | 23 | let output22 = take(&output, index, options)?1 ; |
1069 | 22 | assert_eq!(&output, &expected); |
1070 | 21 | Ok(()) |
1071 | 22 | } |
1072 | | |
1073 | 1 | fn test_take_primitive_arrays_non_null<T>( |
1074 | 1 | data: Vec<T::Native>, |
1075 | 1 | index: &UInt32Array, |
1076 | 1 | options: Option<TakeOptions>, |
1077 | 1 | expected_data: Vec<Option<T::Native>>, |
1078 | 1 | ) -> Result<(), ArrowError> |
1079 | 1 | where |
1080 | 1 | T: ArrowPrimitiveType, |
1081 | 1 | PrimitiveArray<T>: From<Vec<T::Native>>, |
1082 | 1 | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1083 | | { |
1084 | 1 | let output = PrimitiveArray::<T>::from(data); |
1085 | 1 | let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef; |
1086 | 1 | let output = take(&output, index, options)?0 ; |
1087 | 1 | assert_eq!(&output, &expected); |
1088 | 1 | Ok(()) |
1089 | 1 | } |
1090 | | |
1091 | 8 | fn test_take_impl_primitive_arrays<T, I>( |
1092 | 8 | data: Vec<Option<T::Native>>, |
1093 | 8 | index: &PrimitiveArray<I>, |
1094 | 8 | options: Option<TakeOptions>, |
1095 | 8 | expected_data: Vec<Option<T::Native>>, |
1096 | 8 | ) where |
1097 | 8 | T: ArrowPrimitiveType, |
1098 | 8 | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1099 | 8 | I: ArrowPrimitiveType, |
1100 | | { |
1101 | 8 | let output = PrimitiveArray::<T>::from(data); |
1102 | 8 | let expected = PrimitiveArray::<T>::from(expected_data); |
1103 | 8 | let output = take(&output, index, options).unwrap(); |
1104 | 8 | let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
1105 | 8 | assert_eq!(output, &expected) |
1106 | 8 | } |
1107 | | |
1108 | | // create a simple struct for testing purposes |
1109 | 5 | fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray { |
1110 | 5 | let mut struct_builder = StructBuilder::new( |
1111 | 5 | Fields::from(vec![ |
1112 | 5 | Field::new("a", DataType::Boolean, true), |
1113 | 5 | Field::new("b", DataType::Int32, true), |
1114 | | ]), |
1115 | 5 | vec![ |
1116 | 5 | Box::new(BooleanBuilder::with_capacity(values.len())), |
1117 | 5 | Box::new(Int32Builder::with_capacity(values.len())), |
1118 | | ], |
1119 | | ); |
1120 | | |
1121 | 32 | for value27 in values { |
1122 | 27 | struct_builder |
1123 | 27 | .field_builder::<BooleanBuilder>(0) |
1124 | 27 | .unwrap() |
1125 | 27 | .append_option(value.and_then(|v| v.0)); |
1126 | 27 | struct_builder |
1127 | 27 | .field_builder::<Int32Builder>(1) |
1128 | 27 | .unwrap() |
1129 | 27 | .append_option(value.and_then(|v| v.1)); |
1130 | 27 | struct_builder.append(value.is_some()); |
1131 | | } |
1132 | 5 | struct_builder.finish() |
1133 | 5 | } |
1134 | | |
1135 | | #[test] |
1136 | 1 | fn test_take_decimal128_non_null_indices() { |
1137 | 1 | let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); |
1138 | 1 | let precision: u8 = 10; |
1139 | 1 | let scale: i8 = 5; |
1140 | 1 | test_take_decimal_arrays( |
1141 | 1 | vec![None, Some(3), Some(5), Some(2), Some(3), None], |
1142 | 1 | &index, |
1143 | 1 | None, |
1144 | 1 | vec![None, None, Some(2), Some(3), Some(3), Some(5)], |
1145 | 1 | &precision, |
1146 | 1 | &scale, |
1147 | | ) |
1148 | 1 | .unwrap(); |
1149 | 1 | } |
1150 | | |
1151 | | #[test] |
1152 | 1 | fn test_take_decimal128() { |
1153 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1154 | 1 | let precision: u8 = 10; |
1155 | 1 | let scale: i8 = 5; |
1156 | 1 | test_take_decimal_arrays( |
1157 | 1 | vec![Some(0), Some(1), Some(2), Some(3), Some(4)], |
1158 | 1 | &index, |
1159 | 1 | None, |
1160 | 1 | vec![Some(3), None, Some(1), Some(3), Some(2)], |
1161 | 1 | &precision, |
1162 | 1 | &scale, |
1163 | | ) |
1164 | 1 | .unwrap(); |
1165 | 1 | } |
1166 | | |
1167 | | #[test] |
1168 | 1 | fn test_take_primitive_non_null_indices() { |
1169 | 1 | let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); |
1170 | 1 | test_take_primitive_arrays::<Int8Type>( |
1171 | 1 | vec![None, Some(3), Some(5), Some(2), Some(3), None], |
1172 | 1 | &index, |
1173 | 1 | None, |
1174 | 1 | vec![None, None, Some(2), Some(3), Some(3), Some(5)], |
1175 | | ) |
1176 | 1 | .unwrap(); |
1177 | 1 | } |
1178 | | |
1179 | | #[test] |
1180 | 1 | fn test_take_primitive_non_null_values() { |
1181 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1182 | 1 | test_take_primitive_arrays::<Int8Type>( |
1183 | 1 | vec![Some(0), Some(1), Some(2), Some(3), Some(4)], |
1184 | 1 | &index, |
1185 | 1 | None, |
1186 | 1 | vec![Some(3), None, Some(1), Some(3), Some(2)], |
1187 | | ) |
1188 | 1 | .unwrap(); |
1189 | 1 | } |
1190 | | |
1191 | | #[test] |
1192 | 1 | fn test_take_primitive_non_null() { |
1193 | 1 | let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); |
1194 | 1 | test_take_primitive_arrays::<Int8Type>( |
1195 | 1 | vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)], |
1196 | 1 | &index, |
1197 | 1 | None, |
1198 | 1 | vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)], |
1199 | | ) |
1200 | 1 | .unwrap(); |
1201 | 1 | } |
1202 | | |
1203 | | #[test] |
1204 | 1 | fn test_take_primitive_nullable_indices_non_null_values_with_offset() { |
1205 | 1 | let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]); |
1206 | 1 | let index = index.slice(2, 4); |
1207 | 1 | let index = index.as_any().downcast_ref::<UInt32Array>().unwrap(); |
1208 | | |
1209 | 1 | assert_eq!( |
1210 | | index, |
1211 | 1 | &UInt32Array::from(vec![Some(2), Some(3), None, None]) |
1212 | | ); |
1213 | | |
1214 | 1 | test_take_primitive_arrays_non_null::<Int64Type>( |
1215 | 1 | vec![0, 10, 20, 30, 40, 50], |
1216 | 1 | index, |
1217 | 1 | None, |
1218 | 1 | vec![Some(20), Some(30), None, None], |
1219 | | ) |
1220 | 1 | .unwrap(); |
1221 | 1 | } |
1222 | | |
1223 | | #[test] |
1224 | 1 | fn test_take_primitive_nullable_indices_nullable_values_with_offset() { |
1225 | 1 | let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]); |
1226 | 1 | let index = index.slice(2, 4); |
1227 | 1 | let index = index.as_any().downcast_ref::<UInt32Array>().unwrap(); |
1228 | | |
1229 | 1 | assert_eq!( |
1230 | | index, |
1231 | 1 | &UInt32Array::from(vec![Some(2), Some(3), None, None]) |
1232 | | ); |
1233 | | |
1234 | 1 | test_take_primitive_arrays::<Int64Type>( |
1235 | 1 | vec![None, None, Some(20), Some(30), Some(40), Some(50)], |
1236 | 1 | index, |
1237 | 1 | None, |
1238 | 1 | vec![Some(20), Some(30), None, None], |
1239 | | ) |
1240 | 1 | .unwrap(); |
1241 | 1 | } |
1242 | | |
1243 | | #[test] |
1244 | 1 | fn test_take_primitive() { |
1245 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1246 | | |
1247 | | // int8 |
1248 | 1 | test_take_primitive_arrays::<Int8Type>( |
1249 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1250 | 1 | &index, |
1251 | 1 | None, |
1252 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1253 | | ) |
1254 | 1 | .unwrap(); |
1255 | | |
1256 | | // int16 |
1257 | 1 | test_take_primitive_arrays::<Int16Type>( |
1258 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1259 | 1 | &index, |
1260 | 1 | None, |
1261 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1262 | | ) |
1263 | 1 | .unwrap(); |
1264 | | |
1265 | | // int32 |
1266 | 1 | test_take_primitive_arrays::<Int32Type>( |
1267 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1268 | 1 | &index, |
1269 | 1 | None, |
1270 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1271 | | ) |
1272 | 1 | .unwrap(); |
1273 | | |
1274 | | // int64 |
1275 | 1 | test_take_primitive_arrays::<Int64Type>( |
1276 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1277 | 1 | &index, |
1278 | 1 | None, |
1279 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1280 | | ) |
1281 | 1 | .unwrap(); |
1282 | | |
1283 | | // uint8 |
1284 | 1 | test_take_primitive_arrays::<UInt8Type>( |
1285 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1286 | 1 | &index, |
1287 | 1 | None, |
1288 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1289 | | ) |
1290 | 1 | .unwrap(); |
1291 | | |
1292 | | // uint16 |
1293 | 1 | test_take_primitive_arrays::<UInt16Type>( |
1294 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1295 | 1 | &index, |
1296 | 1 | None, |
1297 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1298 | | ) |
1299 | 1 | .unwrap(); |
1300 | | |
1301 | | // uint32 |
1302 | 1 | test_take_primitive_arrays::<UInt32Type>( |
1303 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1304 | 1 | &index, |
1305 | 1 | None, |
1306 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1307 | | ) |
1308 | 1 | .unwrap(); |
1309 | | |
1310 | | // int64 |
1311 | 1 | test_take_primitive_arrays::<Int64Type>( |
1312 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1313 | 1 | &index, |
1314 | 1 | None, |
1315 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1316 | | ) |
1317 | 1 | .unwrap(); |
1318 | | |
1319 | | // interval_year_month |
1320 | 1 | test_take_primitive_arrays::<IntervalYearMonthType>( |
1321 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1322 | 1 | &index, |
1323 | 1 | None, |
1324 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1325 | | ) |
1326 | 1 | .unwrap(); |
1327 | | |
1328 | | // interval_day_time |
1329 | 1 | let v1 = IntervalDayTime::new(0, 0); |
1330 | 1 | let v2 = IntervalDayTime::new(2, 0); |
1331 | 1 | let v3 = IntervalDayTime::new(-15, 0); |
1332 | 1 | test_take_primitive_arrays::<IntervalDayTimeType>( |
1333 | 1 | vec![Some(v1), None, Some(v2), Some(v3), None], |
1334 | 1 | &index, |
1335 | 1 | None, |
1336 | 1 | vec![Some(v3), None, None, Some(v3), Some(v2)], |
1337 | | ) |
1338 | 1 | .unwrap(); |
1339 | | |
1340 | | // interval_month_day_nano |
1341 | 1 | let v1 = IntervalMonthDayNano::new(0, 0, 0); |
1342 | 1 | let v2 = IntervalMonthDayNano::new(2, 0, 0); |
1343 | 1 | let v3 = IntervalMonthDayNano::new(-15, 0, 0); |
1344 | 1 | test_take_primitive_arrays::<IntervalMonthDayNanoType>( |
1345 | 1 | vec![Some(v1), None, Some(v2), Some(v3), None], |
1346 | 1 | &index, |
1347 | 1 | None, |
1348 | 1 | vec![Some(v3), None, None, Some(v3), Some(v2)], |
1349 | | ) |
1350 | 1 | .unwrap(); |
1351 | | |
1352 | | // duration_second |
1353 | 1 | test_take_primitive_arrays::<DurationSecondType>( |
1354 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1355 | 1 | &index, |
1356 | 1 | None, |
1357 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1358 | | ) |
1359 | 1 | .unwrap(); |
1360 | | |
1361 | | // duration_millisecond |
1362 | 1 | test_take_primitive_arrays::<DurationMillisecondType>( |
1363 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1364 | 1 | &index, |
1365 | 1 | None, |
1366 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1367 | | ) |
1368 | 1 | .unwrap(); |
1369 | | |
1370 | | // duration_microsecond |
1371 | 1 | test_take_primitive_arrays::<DurationMicrosecondType>( |
1372 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1373 | 1 | &index, |
1374 | 1 | None, |
1375 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1376 | | ) |
1377 | 1 | .unwrap(); |
1378 | | |
1379 | | // duration_nanosecond |
1380 | 1 | test_take_primitive_arrays::<DurationNanosecondType>( |
1381 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1382 | 1 | &index, |
1383 | 1 | None, |
1384 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1385 | | ) |
1386 | 1 | .unwrap(); |
1387 | | |
1388 | | // float32 |
1389 | 1 | test_take_primitive_arrays::<Float32Type>( |
1390 | 1 | vec![Some(0.0), None, Some(2.21), Some(-3.1), None], |
1391 | 1 | &index, |
1392 | 1 | None, |
1393 | 1 | vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], |
1394 | | ) |
1395 | 1 | .unwrap(); |
1396 | | |
1397 | | // float64 |
1398 | 1 | test_take_primitive_arrays::<Float64Type>( |
1399 | 1 | vec![Some(0.0), None, Some(2.21), Some(-3.1), None], |
1400 | 1 | &index, |
1401 | 1 | None, |
1402 | 1 | vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], |
1403 | | ) |
1404 | 1 | .unwrap(); |
1405 | 1 | } |
1406 | | |
1407 | | #[test] |
1408 | 1 | fn test_take_preserve_timezone() { |
1409 | 1 | let index = Int64Array::from(vec![Some(0), None]); |
1410 | | |
1411 | 1 | let input = TimestampNanosecondArray::from(vec![ |
1412 | | 1_639_715_368_000_000_000, |
1413 | | 1_639_715_368_000_000_000, |
1414 | | ]) |
1415 | 1 | .with_timezone("UTC".to_string()); |
1416 | 1 | let result = take(&input, &index, None).unwrap(); |
1417 | 1 | match result.data_type() { |
1418 | 1 | DataType::Timestamp(TimeUnit::Nanosecond, tz) => { |
1419 | 1 | assert_eq!(tz.clone(), Some("UTC".into())) |
1420 | | } |
1421 | 0 | _ => panic!(), |
1422 | | } |
1423 | 1 | } |
1424 | | |
1425 | | #[test] |
1426 | 1 | fn test_take_impl_primitive_with_int64_indices() { |
1427 | 1 | let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1428 | | |
1429 | | // int16 |
1430 | 1 | test_take_impl_primitive_arrays::<Int16Type, Int64Type>( |
1431 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1432 | 1 | &index, |
1433 | 1 | None, |
1434 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1435 | | ); |
1436 | | |
1437 | | // int64 |
1438 | 1 | test_take_impl_primitive_arrays::<Int64Type, Int64Type>( |
1439 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1440 | 1 | &index, |
1441 | 1 | None, |
1442 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1443 | | ); |
1444 | | |
1445 | | // uint64 |
1446 | 1 | test_take_impl_primitive_arrays::<UInt64Type, Int64Type>( |
1447 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1448 | 1 | &index, |
1449 | 1 | None, |
1450 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1451 | | ); |
1452 | | |
1453 | | // duration_millisecond |
1454 | 1 | test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>( |
1455 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1456 | 1 | &index, |
1457 | 1 | None, |
1458 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1459 | | ); |
1460 | | |
1461 | | // float32 |
1462 | 1 | test_take_impl_primitive_arrays::<Float32Type, Int64Type>( |
1463 | 1 | vec![Some(0.0), None, Some(2.21), Some(-3.1), None], |
1464 | 1 | &index, |
1465 | 1 | None, |
1466 | 1 | vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], |
1467 | | ); |
1468 | 1 | } |
1469 | | |
1470 | | #[test] |
1471 | 1 | fn test_take_impl_primitive_with_uint8_indices() { |
1472 | 1 | let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1473 | | |
1474 | | // int16 |
1475 | 1 | test_take_impl_primitive_arrays::<Int16Type, UInt8Type>( |
1476 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
1477 | 1 | &index, |
1478 | 1 | None, |
1479 | 1 | vec![Some(3), None, None, Some(3), Some(2)], |
1480 | | ); |
1481 | | |
1482 | | // duration_millisecond |
1483 | 1 | test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>( |
1484 | 1 | vec![Some(0), None, Some(2), Some(-15), None], |
1485 | 1 | &index, |
1486 | 1 | None, |
1487 | 1 | vec![Some(-15), None, None, Some(-15), Some(2)], |
1488 | | ); |
1489 | | |
1490 | | // float32 |
1491 | 1 | test_take_impl_primitive_arrays::<Float32Type, UInt8Type>( |
1492 | 1 | vec![Some(0.0), None, Some(2.21), Some(-3.1), None], |
1493 | 1 | &index, |
1494 | 1 | None, |
1495 | 1 | vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], |
1496 | | ); |
1497 | 1 | } |
1498 | | |
1499 | | #[test] |
1500 | 1 | fn test_take_bool() { |
1501 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); |
1502 | | // boolean |
1503 | 1 | test_take_boolean_arrays( |
1504 | 1 | vec![Some(false), None, Some(true), Some(false), None], |
1505 | 1 | &index, |
1506 | 1 | None, |
1507 | 1 | vec![Some(false), None, None, Some(false), Some(true)], |
1508 | | ); |
1509 | 1 | } |
1510 | | |
1511 | | #[test] |
1512 | 1 | fn test_take_bool_nullable_index() { |
1513 | | // indices where the masked invalid elements would be out of bounds |
1514 | 1 | let index_data = ArrayData::try_new( |
1515 | 1 | DataType::UInt32, |
1516 | | 6, |
1517 | 1 | Some(Buffer::from_iter(vec![ |
1518 | 1 | false, true, false, true, false, true, |
1519 | 1 | ])), |
1520 | | 0, |
1521 | 1 | vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], |
1522 | 1 | vec![], |
1523 | | ) |
1524 | 1 | .unwrap(); |
1525 | 1 | let index = UInt32Array::from(index_data); |
1526 | 1 | test_take_boolean_arrays( |
1527 | 1 | vec![Some(true), None, Some(false)], |
1528 | 1 | &index, |
1529 | 1 | None, |
1530 | 1 | vec![None, Some(true), None, None, None, Some(false)], |
1531 | | ); |
1532 | 1 | } |
1533 | | |
1534 | | #[test] |
1535 | 1 | fn test_take_bool_nullable_index_nonnull_values() { |
1536 | | // indices where the masked invalid elements would be out of bounds |
1537 | 1 | let index_data = ArrayData::try_new( |
1538 | 1 | DataType::UInt32, |
1539 | | 6, |
1540 | 1 | Some(Buffer::from_iter(vec![ |
1541 | 1 | false, true, false, true, false, true, |
1542 | 1 | ])), |
1543 | | 0, |
1544 | 1 | vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], |
1545 | 1 | vec![], |
1546 | | ) |
1547 | 1 | .unwrap(); |
1548 | 1 | let index = UInt32Array::from(index_data); |
1549 | 1 | test_take_boolean_arrays( |
1550 | 1 | vec![Some(true), Some(true), Some(false)], |
1551 | 1 | &index, |
1552 | 1 | None, |
1553 | 1 | vec![None, Some(true), None, Some(true), None, Some(false)], |
1554 | | ); |
1555 | 1 | } |
1556 | | |
1557 | | #[test] |
1558 | 1 | fn test_take_bool_with_offset() { |
1559 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]); |
1560 | 1 | let index = index.slice(2, 4); |
1561 | 1 | let index = index |
1562 | 1 | .as_any() |
1563 | 1 | .downcast_ref::<PrimitiveArray<UInt32Type>>() |
1564 | 1 | .unwrap(); |
1565 | | |
1566 | | // boolean |
1567 | 1 | test_take_boolean_arrays( |
1568 | 1 | vec![Some(false), None, Some(true), Some(false), None], |
1569 | 1 | index, |
1570 | 1 | None, |
1571 | 1 | vec![None, Some(false), Some(true), None], |
1572 | | ); |
1573 | 1 | } |
1574 | | |
1575 | 2 | fn _test_take_string<'a, K>() |
1576 | 2 | where |
1577 | 2 | K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static, |
1578 | | { |
1579 | 2 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]); |
1580 | | |
1581 | 2 | let array = K::from(vec![ |
1582 | 2 | Some("one"), |
1583 | 2 | None, |
1584 | 2 | Some("three"), |
1585 | 2 | Some("four"), |
1586 | 2 | Some("five"), |
1587 | | ]); |
1588 | 2 | let actual = take(&array, &index, None).unwrap(); |
1589 | 2 | assert_eq!(actual.len(), index.len()); |
1590 | | |
1591 | 2 | let actual = actual.as_any().downcast_ref::<K>().unwrap(); |
1592 | | |
1593 | 2 | let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]); |
1594 | | |
1595 | 2 | assert_eq!(actual, &expected); |
1596 | 2 | } |
1597 | | |
1598 | | #[test] |
1599 | 1 | fn test_take_string() { |
1600 | 1 | _test_take_string::<StringArray>() |
1601 | 1 | } |
1602 | | |
1603 | | #[test] |
1604 | 1 | fn test_take_large_string() { |
1605 | 1 | _test_take_string::<LargeStringArray>() |
1606 | 1 | } |
1607 | | |
1608 | | #[test] |
1609 | 1 | fn test_take_slice_string() { |
1610 | 1 | let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]); |
1611 | 1 | let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]); |
1612 | 1 | let indices_slice = indices.slice(1, 4); |
1613 | 1 | let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]); |
1614 | 1 | let result = take(&strings, &indices_slice, None).unwrap(); |
1615 | 1 | assert_eq!(result.as_ref(), &expected); |
1616 | 1 | } |
1617 | | |
1618 | 2 | fn _test_byte_view<T>() |
1619 | 2 | where |
1620 | 2 | T: ByteViewType, |
1621 | 2 | str: AsRef<T::Native>, |
1622 | 2 | T::Native: PartialEq, |
1623 | | { |
1624 | 2 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]); |
1625 | 2 | let array = { |
1626 | | // ["hello", "world", null, "large payload over 12 bytes", "lulu"] |
1627 | 2 | let mut builder = GenericByteViewBuilder::<T>::new(); |
1628 | 2 | builder.append_value("hello"); |
1629 | 2 | builder.append_value("world"); |
1630 | 2 | builder.append_null(); |
1631 | 2 | builder.append_value("large payload over 12 bytes"); |
1632 | 2 | builder.append_value("lulu"); |
1633 | 2 | builder.finish() |
1634 | | }; |
1635 | | |
1636 | 2 | let actual = take(&array, &index, None).unwrap(); |
1637 | | |
1638 | 2 | assert_eq!(actual.len(), index.len()); |
1639 | | |
1640 | 2 | let expected = { |
1641 | | // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null] |
1642 | 2 | let mut builder = GenericByteViewBuilder::<T>::new(); |
1643 | 2 | builder.append_value("large payload over 12 bytes"); |
1644 | 2 | builder.append_null(); |
1645 | 2 | builder.append_value("world"); |
1646 | 2 | builder.append_value("large payload over 12 bytes"); |
1647 | 2 | builder.append_value("lulu"); |
1648 | 2 | builder.append_null(); |
1649 | 2 | builder.finish() |
1650 | | }; |
1651 | | |
1652 | 2 | assert_eq!(actual.as_ref(), &expected); |
1653 | 2 | } |
1654 | | |
1655 | | #[test] |
1656 | 1 | fn test_take_string_view() { |
1657 | 1 | _test_byte_view::<StringViewType>() |
1658 | 1 | } |
1659 | | |
1660 | | #[test] |
1661 | 1 | fn test_take_binary_view() { |
1662 | 1 | _test_byte_view::<BinaryViewType>() |
1663 | 1 | } |
1664 | | |
1665 | | macro_rules! test_take_list { |
1666 | | ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{ |
1667 | | // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]] |
1668 | | let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data(); |
1669 | | // Construct offsets |
1670 | | let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8]; |
1671 | | let value_offsets = Buffer::from_slice_ref(&value_offsets); |
1672 | | // Construct a list array from the above two |
1673 | | let list_data_type = |
1674 | | DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1675 | | let list_data = ArrayData::builder(list_data_type.clone()) |
1676 | | .len(4) |
1677 | | .add_buffer(value_offsets) |
1678 | | .add_child_data(value_data) |
1679 | | .build() |
1680 | | .unwrap(); |
1681 | | let list_array = $list_array_type::from(list_data); |
1682 | | |
1683 | | // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]] |
1684 | | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]); |
1685 | | |
1686 | | let a = take(&list_array, &index, None).unwrap(); |
1687 | | let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap(); |
1688 | | |
1689 | | // construct a value array with expected results: |
1690 | | // [[2,3], null, [-1,-2,-1], [], [0,0,0]] |
1691 | | let expected_data = Int32Array::from(vec![ |
1692 | | Some(2), |
1693 | | Some(3), |
1694 | | Some(-1), |
1695 | | Some(-2), |
1696 | | Some(-1), |
1697 | | Some(0), |
1698 | | Some(0), |
1699 | | Some(0), |
1700 | | ]) |
1701 | | .into_data(); |
1702 | | // construct offsets |
1703 | | let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8]; |
1704 | | let expected_offsets = Buffer::from_slice_ref(&expected_offsets); |
1705 | | // construct list array from the two |
1706 | | let expected_list_data = ArrayData::builder(list_data_type) |
1707 | | .len(5) |
1708 | | // null buffer remains the same as only the indices have nulls |
1709 | | .nulls(index.nulls().cloned()) |
1710 | | .add_buffer(expected_offsets) |
1711 | | .add_child_data(expected_data) |
1712 | | .build() |
1713 | | .unwrap(); |
1714 | | let expected_list_array = $list_array_type::from(expected_list_data); |
1715 | | |
1716 | | assert_eq!(a, &expected_list_array); |
1717 | | }}; |
1718 | | } |
1719 | | |
1720 | | macro_rules! test_take_list_with_value_nulls { |
1721 | | ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{ |
1722 | | // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]] |
1723 | | let value_data = Int32Array::from(vec![ |
1724 | | Some(0), |
1725 | | None, |
1726 | | Some(0), |
1727 | | Some(-1), |
1728 | | Some(-2), |
1729 | | Some(3), |
1730 | | None, |
1731 | | Some(5), |
1732 | | None, |
1733 | | ]) |
1734 | | .into_data(); |
1735 | | // Construct offsets |
1736 | | let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9]; |
1737 | | let value_offsets = Buffer::from_slice_ref(&value_offsets); |
1738 | | // Construct a list array from the above two |
1739 | | let list_data_type = |
1740 | | DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); |
1741 | | let list_data = ArrayData::builder(list_data_type.clone()) |
1742 | | .len(4) |
1743 | | .add_buffer(value_offsets) |
1744 | | .null_bit_buffer(Some(Buffer::from([0b11111111]))) |
1745 | | .add_child_data(value_data) |
1746 | | .build() |
1747 | | .unwrap(); |
1748 | | let list_array = $list_array_type::from(list_data); |
1749 | | |
1750 | | // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]] |
1751 | | let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); |
1752 | | |
1753 | | let a = take(&list_array, &index, None).unwrap(); |
1754 | | let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap(); |
1755 | | |
1756 | | // construct a value array with expected results: |
1757 | | // [[null], null, [-1,-2,3], [5,null], [0,null,0]] |
1758 | | let expected_data = Int32Array::from(vec![ |
1759 | | None, |
1760 | | Some(-1), |
1761 | | Some(-2), |
1762 | | Some(3), |
1763 | | Some(5), |
1764 | | None, |
1765 | | Some(0), |
1766 | | None, |
1767 | | Some(0), |
1768 | | ]) |
1769 | | .into_data(); |
1770 | | // construct offsets |
1771 | | let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9]; |
1772 | | let expected_offsets = Buffer::from_slice_ref(&expected_offsets); |
1773 | | // construct list array from the two |
1774 | | let expected_list_data = ArrayData::builder(list_data_type) |
1775 | | .len(5) |
1776 | | // null buffer remains the same as only the indices have nulls |
1777 | | .nulls(index.nulls().cloned()) |
1778 | | .add_buffer(expected_offsets) |
1779 | | .add_child_data(expected_data) |
1780 | | .build() |
1781 | | .unwrap(); |
1782 | | let expected_list_array = $list_array_type::from(expected_list_data); |
1783 | | |
1784 | | assert_eq!(a, &expected_list_array); |
1785 | | }}; |
1786 | | } |
1787 | | |
1788 | | macro_rules! test_take_list_with_nulls { |
1789 | | ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{ |
1790 | | // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]] |
1791 | | let value_data = Int32Array::from(vec![ |
1792 | | Some(0), |
1793 | | None, |
1794 | | Some(0), |
1795 | | Some(-1), |
1796 | | Some(-2), |
1797 | | Some(3), |
1798 | | Some(5), |
1799 | | None, |
1800 | | ]) |
1801 | | .into_data(); |
1802 | | // Construct offsets |
1803 | | let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8]; |
1804 | | let value_offsets = Buffer::from_slice_ref(&value_offsets); |
1805 | | // Construct a list array from the above two |
1806 | | let list_data_type = |
1807 | | DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true))); |
1808 | | let list_data = ArrayData::builder(list_data_type.clone()) |
1809 | | .len(4) |
1810 | | .add_buffer(value_offsets) |
1811 | | .null_bit_buffer(Some(Buffer::from([0b11111011]))) |
1812 | | .add_child_data(value_data) |
1813 | | .build() |
1814 | | .unwrap(); |
1815 | | let list_array = $list_array_type::from(list_data); |
1816 | | |
1817 | | // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]] |
1818 | | let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); |
1819 | | |
1820 | | let a = take(&list_array, &index, None).unwrap(); |
1821 | | let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap(); |
1822 | | |
1823 | | // construct a value array with expected results: |
1824 | | // [null, null, [-1,-2,3], [5,null], [0,null,0]] |
1825 | | let expected_data = Int32Array::from(vec![ |
1826 | | Some(-1), |
1827 | | Some(-2), |
1828 | | Some(3), |
1829 | | Some(5), |
1830 | | None, |
1831 | | Some(0), |
1832 | | None, |
1833 | | Some(0), |
1834 | | ]) |
1835 | | .into_data(); |
1836 | | // construct offsets |
1837 | | let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8]; |
1838 | | let expected_offsets = Buffer::from_slice_ref(&expected_offsets); |
1839 | | // construct list array from the two |
1840 | | let mut null_bits: [u8; 1] = [0; 1]; |
1841 | | bit_util::set_bit(&mut null_bits, 2); |
1842 | | bit_util::set_bit(&mut null_bits, 3); |
1843 | | bit_util::set_bit(&mut null_bits, 4); |
1844 | | let expected_list_data = ArrayData::builder(list_data_type) |
1845 | | .len(5) |
1846 | | // null buffer must be recalculated as both values and indices have nulls |
1847 | | .null_bit_buffer(Some(Buffer::from(null_bits))) |
1848 | | .add_buffer(expected_offsets) |
1849 | | .add_child_data(expected_data) |
1850 | | .build() |
1851 | | .unwrap(); |
1852 | | let expected_list_array = $list_array_type::from(expected_list_data); |
1853 | | |
1854 | | assert_eq!(a, &expected_list_array); |
1855 | | }}; |
1856 | | } |
1857 | | |
1858 | 8 | fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>( |
1859 | 8 | values: Vec<Option<Vec<Option<ValuesType::Native>>>>, |
1860 | 8 | take_indices: Vec<Option<usize>>, |
1861 | 8 | expected: Vec<Option<Vec<Option<ValuesType::Native>>>>, |
1862 | 8 | mapper: F, |
1863 | 8 | ) where |
1864 | 8 | F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>, |
1865 | | { |
1866 | 8 | let mut list_view_array = |
1867 | 8 | GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new()); |
1868 | | |
1869 | 38 | for value30 in values { |
1870 | 30 | list_view_array.append_option(value); |
1871 | 30 | } |
1872 | 8 | let list_view_array = list_view_array.finish(); |
1873 | 8 | let list_view_array = mapper(list_view_array); |
1874 | | |
1875 | 8 | let mut indices = UInt64Builder::new(); |
1876 | 40 | for idx32 in take_indices { |
1877 | 32 | indices.append_option(idx.map(|i| i22 .to_u6422 ().unwrap22 ())); |
1878 | | } |
1879 | 8 | let indices = indices.finish(); |
1880 | | |
1881 | 8 | let taken = take(&list_view_array, &indices, None) |
1882 | 8 | .unwrap() |
1883 | 8 | .as_list_view() |
1884 | 8 | .clone(); |
1885 | | |
1886 | 8 | let mut expected_array = |
1887 | 8 | GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new()); |
1888 | 40 | for value32 in expected { |
1889 | 32 | expected_array.append_option(value); |
1890 | 32 | } |
1891 | 8 | let expected_array = expected_array.finish(); |
1892 | | |
1893 | 8 | assert_eq!(taken, expected_array); |
1894 | 8 | } |
1895 | | |
1896 | | macro_rules! list_view_test_case { |
1897 | | (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{ |
1898 | | test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x); |
1899 | | test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x); |
1900 | | }}; |
1901 | | (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{ |
1902 | | test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn); |
1903 | | test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn); |
1904 | | }}; |
1905 | | } |
1906 | | |
1907 | 3 | fn do_take_fixed_size_list_test<T>( |
1908 | 3 | length: <Int32Type as ArrowPrimitiveType>::Native, |
1909 | 3 | input_data: Vec<Option<Vec<Option<T::Native>>>>, |
1910 | 3 | indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>, |
1911 | 3 | expected_data: Vec<Option<Vec<Option<T::Native>>>>, |
1912 | 3 | ) where |
1913 | 3 | T: ArrowPrimitiveType, |
1914 | 3 | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
1915 | | { |
1916 | 3 | let indices = UInt32Array::from(indices); |
1917 | | |
1918 | 3 | let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length); |
1919 | | |
1920 | 3 | let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap(); |
1921 | | |
1922 | 3 | let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length); |
1923 | | |
1924 | 3 | assert_eq!(&output, &expected) |
1925 | 3 | } |
1926 | | |
1927 | | #[test] |
1928 | 1 | fn test_take_list() { |
1929 | 1 | test_take_list!(i32, List, ListArray); |
1930 | 1 | } |
1931 | | |
1932 | | #[test] |
1933 | 1 | fn test_take_large_list() { |
1934 | 1 | test_take_list!(i64, LargeList, LargeListArray); |
1935 | 1 | } |
1936 | | |
1937 | | #[test] |
1938 | 1 | fn test_take_list_with_value_nulls() { |
1939 | 1 | test_take_list_with_value_nulls!(i32, List, ListArray); |
1940 | 1 | } |
1941 | | |
1942 | | #[test] |
1943 | 1 | fn test_take_large_list_with_value_nulls() { |
1944 | 1 | test_take_list_with_value_nulls!(i64, LargeList, LargeListArray); |
1945 | 1 | } |
1946 | | |
1947 | | #[test] |
1948 | 1 | fn test_test_take_list_with_nulls() { |
1949 | 1 | test_take_list_with_nulls!(i32, List, ListArray); |
1950 | 1 | } |
1951 | | |
1952 | | #[test] |
1953 | 1 | fn test_test_take_large_list_with_nulls() { |
1954 | 1 | test_take_list_with_nulls!(i64, LargeList, LargeListArray); |
1955 | 1 | } |
1956 | | |
1957 | | #[test] |
1958 | 1 | fn test_test_take_list_view_reversed() { |
1959 | | // Take reversed indices |
1960 | 1 | list_view_test_case! { |
1961 | 1 | values: vec![ |
1962 | 1 | Some(vec![Some(1), None, Some(3)]), |
1963 | 1 | None, |
1964 | 1 | Some(vec![Some(7), Some(8), None]), |
1965 | | ], |
1966 | 1 | indices: vec![Some(2), Some(1), Some(0)], |
1967 | 1 | expected: vec![ |
1968 | 1 | Some(vec![Some(7), Some(8), None]), |
1969 | 1 | None, |
1970 | 1 | Some(vec![Some(1), None, Some(3)]), |
1971 | | ] |
1972 | | } |
1973 | 1 | } |
1974 | | |
1975 | | #[test] |
1976 | 1 | fn test_take_list_view_null_indices() { |
1977 | | // Take with null indices |
1978 | 1 | list_view_test_case! { |
1979 | 1 | values: vec![ |
1980 | 1 | Some(vec![Some(1), None, Some(3)]), |
1981 | 1 | None, |
1982 | 1 | Some(vec![Some(7), Some(8), None]), |
1983 | | ], |
1984 | 1 | indices: vec![None, Some(0), None], |
1985 | 1 | expected: vec![None, Some(vec![Some(1), None, Some(3)]), None] |
1986 | | } |
1987 | 1 | } |
1988 | | |
1989 | | #[test] |
1990 | 1 | fn test_take_list_view_null_values() { |
1991 | | // Take at null values |
1992 | 1 | list_view_test_case! { |
1993 | 1 | values: vec![ |
1994 | 1 | Some(vec![Some(1), None, Some(3)]), |
1995 | 1 | None, |
1996 | 1 | Some(vec![Some(7), Some(8), None]), |
1997 | | ], |
1998 | 1 | indices: vec![Some(1), Some(1), Some(1), None, None], |
1999 | 1 | expected: vec![None; 5] |
2000 | | } |
2001 | 1 | } |
2002 | | |
2003 | | #[test] |
2004 | 1 | fn test_take_list_view_sliced() { |
2005 | | // Take null indices/values, with slicing. |
2006 | 1 | list_view_test_case! { |
2007 | 1 | values: vec![ |
2008 | 1 | Some(vec![Some(1)]), |
2009 | 1 | None, |
2010 | 1 | None, |
2011 | 1 | Some(vec![Some(2), Some(3)]), |
2012 | 1 | Some(vec![Some(4), Some(5)]), |
2013 | 1 | None, |
2014 | | ], |
2015 | 2 | transform: |l| l.slice(2, 4), |
2016 | 1 | indices: vec![Some(0), Some(3), None, Some(1), Some(2)], |
2017 | 1 | expected: vec![ |
2018 | 1 | None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)]) |
2019 | | ] |
2020 | | } |
2021 | 1 | } |
2022 | | |
2023 | | #[test] |
2024 | 1 | fn test_take_fixed_size_list() { |
2025 | 1 | do_take_fixed_size_list_test::<Int32Type>( |
2026 | | 3, |
2027 | 1 | vec![ |
2028 | 1 | Some(vec![None, Some(1), Some(2)]), |
2029 | 1 | Some(vec![Some(3), Some(4), None]), |
2030 | 1 | Some(vec![Some(6), Some(7), Some(8)]), |
2031 | | ], |
2032 | 1 | vec![2, 1, 0], |
2033 | 1 | vec![ |
2034 | 1 | Some(vec![Some(6), Some(7), Some(8)]), |
2035 | 1 | Some(vec![Some(3), Some(4), None]), |
2036 | 1 | Some(vec![None, Some(1), Some(2)]), |
2037 | | ], |
2038 | | ); |
2039 | | |
2040 | 1 | do_take_fixed_size_list_test::<UInt8Type>( |
2041 | | 1, |
2042 | 1 | vec![ |
2043 | 1 | Some(vec![Some(1)]), |
2044 | 1 | Some(vec![Some(2)]), |
2045 | 1 | Some(vec![Some(3)]), |
2046 | 1 | Some(vec![Some(4)]), |
2047 | 1 | Some(vec![Some(5)]), |
2048 | 1 | Some(vec![Some(6)]), |
2049 | 1 | Some(vec![Some(7)]), |
2050 | 1 | Some(vec![Some(8)]), |
2051 | | ], |
2052 | 1 | vec![2, 7, 0], |
2053 | 1 | vec![ |
2054 | 1 | Some(vec![Some(3)]), |
2055 | 1 | Some(vec![Some(8)]), |
2056 | 1 | Some(vec![Some(1)]), |
2057 | | ], |
2058 | | ); |
2059 | | |
2060 | 1 | do_take_fixed_size_list_test::<UInt64Type>( |
2061 | | 3, |
2062 | 1 | vec![ |
2063 | 1 | Some(vec![Some(10), Some(11), Some(12)]), |
2064 | 1 | Some(vec![Some(13), Some(14), Some(15)]), |
2065 | 1 | None, |
2066 | 1 | Some(vec![Some(16), Some(17), Some(18)]), |
2067 | | ], |
2068 | 1 | vec![3, 2, 1, 2, 0], |
2069 | 1 | vec![ |
2070 | 1 | Some(vec![Some(16), Some(17), Some(18)]), |
2071 | 1 | None, |
2072 | 1 | Some(vec![Some(13), Some(14), Some(15)]), |
2073 | 1 | None, |
2074 | 1 | Some(vec![Some(10), Some(11), Some(12)]), |
2075 | | ], |
2076 | | ); |
2077 | 1 | } |
2078 | | |
2079 | | #[test] |
2080 | | #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")] |
2081 | 1 | fn test_take_list_out_of_bounds() { |
2082 | | // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]] |
2083 | 1 | let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data(); |
2084 | | // Construct offsets |
2085 | 1 | let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); |
2086 | | // Construct a list array from the above two |
2087 | 1 | let list_data_type = |
2088 | 1 | DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))); |
2089 | 1 | let list_data = ArrayData::builder(list_data_type) |
2090 | 1 | .len(3) |
2091 | 1 | .add_buffer(value_offsets) |
2092 | 1 | .add_child_data(value_data) |
2093 | 1 | .build() |
2094 | 1 | .unwrap(); |
2095 | 1 | let list_array = ListArray::from(list_data); |
2096 | | |
2097 | 1 | let index = UInt32Array::from(vec![1000]); |
2098 | | |
2099 | | // A panic is expected here since we have not supplied the check_bounds |
2100 | | // option. |
2101 | 1 | take(&list_array, &index, None).unwrap(); |
2102 | 1 | } |
2103 | | |
2104 | | #[test] |
2105 | 1 | fn test_take_map() { |
2106 | 1 | let values = Int32Array::from(vec![1, 2, 3, 4]); |
2107 | 1 | let array = |
2108 | 1 | MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4]) |
2109 | 1 | .unwrap(); |
2110 | | |
2111 | 1 | let index = UInt32Array::from(vec![0]); |
2112 | | |
2113 | 1 | let result = take(&array, &index, None).unwrap(); |
2114 | 1 | let expected: ArrayRef = Arc::new( |
2115 | 1 | MapArray::new_from_strings( |
2116 | 1 | vec!["a", "b", "c"].into_iter(), |
2117 | 1 | &values.slice(0, 3), |
2118 | 1 | &[0, 3], |
2119 | 1 | ) |
2120 | 1 | .unwrap(), |
2121 | 1 | ); |
2122 | 1 | assert_eq!(&expected, &result); |
2123 | 1 | } |
2124 | | |
2125 | | #[test] |
2126 | 1 | fn test_take_struct() { |
2127 | 1 | let array = create_test_struct(vec![ |
2128 | 1 | Some((Some(true), Some(42))), |
2129 | 1 | Some((Some(false), Some(28))), |
2130 | 1 | Some((Some(false), Some(19))), |
2131 | 1 | Some((Some(true), Some(31))), |
2132 | 1 | None, |
2133 | | ]); |
2134 | | |
2135 | 1 | let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]); |
2136 | 1 | let actual = take(&array, &index, None).unwrap(); |
2137 | 1 | let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap(); |
2138 | 1 | assert_eq!(index.len(), actual.len()); |
2139 | 1 | assert_eq!(1, actual.null_count()); |
2140 | | |
2141 | 1 | let expected = create_test_struct(vec![ |
2142 | 1 | Some((Some(true), Some(42))), |
2143 | 1 | Some((Some(true), Some(31))), |
2144 | 1 | Some((Some(false), Some(28))), |
2145 | 1 | Some((Some(true), Some(42))), |
2146 | 1 | Some((Some(false), Some(19))), |
2147 | 1 | None, |
2148 | | ]); |
2149 | | |
2150 | 1 | assert_eq!(&expected, actual); |
2151 | | |
2152 | 1 | let nulls = NullBuffer::from(&[false, true, false, true, false, true]); |
2153 | 1 | let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls)); |
2154 | 1 | let index = UInt32Array::from(vec![0, 2, 1, 4]); |
2155 | 1 | let actual = take(&empty_struct_arr, &index, None).unwrap(); |
2156 | | |
2157 | 1 | let expected_nulls = NullBuffer::from(&[false, false, true, false]); |
2158 | 1 | let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls)); |
2159 | 1 | assert_eq!(&expected_struct_arr, actual.as_struct()); |
2160 | 1 | } |
2161 | | |
2162 | | #[test] |
2163 | 1 | fn test_take_struct_with_null_indices() { |
2164 | 1 | let array = create_test_struct(vec![ |
2165 | 1 | Some((Some(true), Some(42))), |
2166 | 1 | Some((Some(false), Some(28))), |
2167 | 1 | Some((Some(false), Some(19))), |
2168 | 1 | Some((Some(true), Some(31))), |
2169 | 1 | None, |
2170 | | ]); |
2171 | | |
2172 | 1 | let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]); |
2173 | 1 | let actual = take(&array, &index, None).unwrap(); |
2174 | 1 | let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap(); |
2175 | 1 | assert_eq!(index.len(), actual.len()); |
2176 | 1 | assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array |
2177 | | |
2178 | 1 | let expected = create_test_struct(vec![ |
2179 | 1 | None, |
2180 | 1 | Some((Some(true), Some(31))), |
2181 | 1 | Some((Some(false), Some(28))), |
2182 | 1 | None, |
2183 | 1 | Some((Some(true), Some(42))), |
2184 | 1 | None, |
2185 | | ]); |
2186 | | |
2187 | 1 | assert_eq!(&expected, actual); |
2188 | 1 | } |
2189 | | |
2190 | | #[test] |
2191 | 1 | fn test_take_out_of_bounds() { |
2192 | 1 | let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]); |
2193 | 1 | let take_opt = TakeOptions { check_bounds: true }; |
2194 | | |
2195 | | // int64 |
2196 | 1 | let result = test_take_primitive_arrays::<Int64Type>( |
2197 | 1 | vec![Some(0), None, Some(2), Some(3), None], |
2198 | 1 | &index, |
2199 | 1 | Some(take_opt), |
2200 | 1 | vec![None], |
2201 | | ); |
2202 | 1 | assert!(result.is_err()); |
2203 | 1 | } |
2204 | | |
2205 | | #[test] |
2206 | | #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")] |
2207 | 1 | fn test_take_out_of_bounds_panic() { |
2208 | 1 | let index = UInt32Array::from(vec![Some(1000)]); |
2209 | | |
2210 | 1 | test_take_primitive_arrays::<Int64Type>( |
2211 | 1 | vec![Some(0), Some(1), Some(2), Some(3)], |
2212 | 1 | &index, |
2213 | 1 | None, |
2214 | 1 | vec![None], |
2215 | | ) |
2216 | 1 | .unwrap(); |
2217 | 1 | } |
2218 | | |
2219 | | #[test] |
2220 | 1 | fn test_null_array_smaller_than_indices() { |
2221 | 1 | let values = NullArray::new(2); |
2222 | 1 | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2223 | | |
2224 | 1 | let result = take(&values, &indices, None).unwrap(); |
2225 | 1 | let expected: ArrayRef = Arc::new(NullArray::new(3)); |
2226 | 1 | assert_eq!(&result, &expected); |
2227 | 1 | } |
2228 | | |
2229 | | #[test] |
2230 | 1 | fn test_null_array_larger_than_indices() { |
2231 | 1 | let values = NullArray::new(5); |
2232 | 1 | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2233 | | |
2234 | 1 | let result = take(&values, &indices, None).unwrap(); |
2235 | 1 | let expected: ArrayRef = Arc::new(NullArray::new(3)); |
2236 | 1 | assert_eq!(&result, &expected); |
2237 | 1 | } |
2238 | | |
2239 | | #[test] |
2240 | 1 | fn test_null_array_indices_out_of_bounds() { |
2241 | 1 | let values = NullArray::new(5); |
2242 | 1 | let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); |
2243 | | |
2244 | 1 | let result = take(&values, &indices, Some(TakeOptions { check_bounds: true })); |
2245 | 1 | assert_eq!( |
2246 | 1 | result.unwrap_err().to_string(), |
2247 | | "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries" |
2248 | | ); |
2249 | 1 | } |
2250 | | |
2251 | | #[test] |
2252 | 1 | fn test_take_dict() { |
2253 | 1 | let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new(); |
2254 | | |
2255 | 1 | dict_builder.append("foo").unwrap(); |
2256 | 1 | dict_builder.append("bar").unwrap(); |
2257 | 1 | dict_builder.append("").unwrap(); |
2258 | 1 | dict_builder.append_null(); |
2259 | 1 | dict_builder.append("foo").unwrap(); |
2260 | 1 | dict_builder.append("bar").unwrap(); |
2261 | 1 | dict_builder.append("bar").unwrap(); |
2262 | 1 | dict_builder.append("foo").unwrap(); |
2263 | | |
2264 | 1 | let array = dict_builder.finish(); |
2265 | 1 | let dict_values = array.values().clone(); |
2266 | 1 | let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap(); |
2267 | | |
2268 | 1 | let indices = UInt32Array::from(vec![ |
2269 | 1 | Some(0), // first "foo" |
2270 | 1 | Some(7), // last "foo" |
2271 | 1 | None, // null index should return null |
2272 | 1 | Some(5), // second "bar" |
2273 | 1 | Some(6), // another "bar" |
2274 | 1 | Some(2), // empty string |
2275 | 1 | Some(3), // input is null at this index |
2276 | | ]); |
2277 | | |
2278 | 1 | let result = take(&array, &indices, None).unwrap(); |
2279 | 1 | let result = result |
2280 | 1 | .as_any() |
2281 | 1 | .downcast_ref::<DictionaryArray<Int16Type>>() |
2282 | 1 | .unwrap(); |
2283 | | |
2284 | 1 | let result_values: StringArray = result.values().to_data().into(); |
2285 | | |
2286 | | // dictionary values should stay the same |
2287 | 1 | let expected_values = StringArray::from(vec!["foo", "bar", ""]); |
2288 | 1 | assert_eq!(&expected_values, dict_values); |
2289 | 1 | assert_eq!(&expected_values, &result_values); |
2290 | | |
2291 | 1 | let expected_keys = Int16Array::from(vec![ |
2292 | 1 | Some(0), |
2293 | 1 | Some(0), |
2294 | 1 | None, |
2295 | 1 | Some(1), |
2296 | 1 | Some(1), |
2297 | 1 | Some(2), |
2298 | 1 | None, |
2299 | | ]); |
2300 | 1 | assert_eq!(result.keys(), &expected_keys); |
2301 | 1 | } |
2302 | | |
2303 | 2 | fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S> |
2304 | 2 | where |
2305 | 2 | S: OffsetSizeTrait + 'static, |
2306 | 2 | T: ArrowPrimitiveType, |
2307 | 2 | PrimitiveArray<T>: From<Vec<Option<T::Native>>>, |
2308 | | { |
2309 | 2 | GenericListArray::from_iter_primitive::<T, _, _>( |
2310 | 2 | data.iter() |
2311 | 20 | .map2 (|x| x6 .as_ref6 ().map6 (|x| x.iter()6 .map6 (|x| Some(*x)))), |
2312 | | ) |
2313 | 2 | } |
2314 | | |
2315 | | #[test] |
2316 | 1 | fn test_take_value_index_from_list() { |
2317 | 1 | let list = build_generic_list::<i32, Int32Type>(vec![ |
2318 | 1 | Some(vec![0, 1]), |
2319 | 1 | Some(vec![2, 3, 4]), |
2320 | 1 | Some(vec![5, 6, 7, 8, 9]), |
2321 | | ]); |
2322 | 1 | let indices = UInt32Array::from(vec![2, 0]); |
2323 | | |
2324 | 1 | let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap(); |
2325 | | |
2326 | 1 | assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1])); |
2327 | 1 | assert_eq!(offsets, vec![0, 5, 7]); |
2328 | 1 | assert_eq!(null_buf.as_slice(), &[0b11111111]); |
2329 | 1 | } |
2330 | | |
2331 | | #[test] |
2332 | 1 | fn test_take_value_index_from_large_list() { |
2333 | 1 | let list = build_generic_list::<i64, Int32Type>(vec![ |
2334 | 1 | Some(vec![0, 1]), |
2335 | 1 | Some(vec![2, 3, 4]), |
2336 | 1 | Some(vec![5, 6, 7, 8, 9]), |
2337 | | ]); |
2338 | 1 | let indices = UInt32Array::from(vec![2, 0]); |
2339 | | |
2340 | 1 | let (indexed, offsets, null_buf) = |
2341 | 1 | take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap(); |
2342 | | |
2343 | 1 | assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1])); |
2344 | 1 | assert_eq!(offsets, vec![0, 5, 7]); |
2345 | 1 | assert_eq!(null_buf.as_slice(), &[0b11111111]); |
2346 | 1 | } |
2347 | | |
2348 | | #[test] |
2349 | 1 | fn test_take_runs() { |
2350 | 1 | let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2]; |
2351 | | |
2352 | 1 | let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new(); |
2353 | 1 | builder.extend(logical_array.into_iter().map(Some)); |
2354 | 1 | let run_array = builder.finish(); |
2355 | | |
2356 | 1 | let take_indices: PrimitiveArray<Int32Type> = |
2357 | 1 | vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect(); |
2358 | | |
2359 | 1 | let take_out = take_run(&run_array, &take_indices).unwrap(); |
2360 | | |
2361 | 1 | assert_eq!(take_out.len(), 7); |
2362 | 1 | assert_eq!(take_out.run_ends().len(), 7); |
2363 | 1 | assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]); |
2364 | | |
2365 | 1 | let take_out_values = take_out.values().as_primitive::<Int32Type>(); |
2366 | 1 | assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]); |
2367 | 1 | } |
2368 | | |
2369 | | #[test] |
2370 | 1 | fn test_take_value_index_from_fixed_list() { |
2371 | 1 | let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>( |
2372 | 1 | vec![ |
2373 | 1 | Some(vec![Some(1), Some(2), None]), |
2374 | 1 | Some(vec![Some(4), None, Some(6)]), |
2375 | 1 | None, |
2376 | 1 | Some(vec![None, Some(8), Some(9)]), |
2377 | | ], |
2378 | | 3, |
2379 | | ); |
2380 | | |
2381 | 1 | let indices = UInt32Array::from(vec![2, 1, 0]); |
2382 | 1 | let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); |
2383 | | |
2384 | 1 | assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2])); |
2385 | | |
2386 | 1 | let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]); |
2387 | 1 | let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap(); |
2388 | | |
2389 | 1 | assert_eq!( |
2390 | | indexed, |
2391 | 1 | UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2]) |
2392 | | ); |
2393 | 1 | } |
2394 | | |
2395 | | #[test] |
2396 | 1 | fn test_take_null_indices() { |
2397 | | // Build indices with values that are out of bounds, but masked by null mask |
2398 | 1 | let indices = Int32Array::new( |
2399 | 1 | vec![1, 2, 400, 400].into(), |
2400 | 1 | Some(NullBuffer::from(vec![true, true, false, false])), |
2401 | | ); |
2402 | 1 | let values = Int32Array::from(vec![1, 23, 4, 5]); |
2403 | 1 | let r = take(&values, &indices, None).unwrap(); |
2404 | 1 | let values = r |
2405 | 1 | .as_primitive::<Int32Type>() |
2406 | 1 | .into_iter() |
2407 | 1 | .collect::<Vec<_>>(); |
2408 | 1 | assert_eq!(&values, &[Some(23), Some(4), None, None]) |
2409 | 1 | } |
2410 | | |
#[test]
fn test_take_fixed_size_list_null_indices() {
    // Two lists of width 2 laid over the child values [0, 1, 2, 3].
    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    let lists = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

    // Select the first list, followed by a null index.
    let indices = Int32Array::from_iter([Some(0), None]);

    let taken = take(&lists, &indices, None).unwrap();
    let child_values: Vec<_> = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .into_iter()
        .collect();

    // The null index contributes two null child slots to the output.
    assert_eq!(child_values, &[Some(0), Some(1), None, None]);
}
2427 | | |
#[test]
fn test_take_bytes_null_indices() {
    // Indices 2 and 3 carry an out-of-bounds value (400) but are masked as null.
    let validity = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![0, 1, 400, 400].into(), Some(validity));
    let source = StringArray::from(vec![Some("foo"), None]);

    let taken = take(&source, &indices, None).unwrap();
    let actual: Vec<_> = taken.as_string::<i32>().iter().collect();

    // Slot 1 is null because the source value is null; slots 2 and 3 are null
    // because the indices themselves are null.
    assert_eq!(&actual, &[Some("foo"), None, None, None]);
}
2439 | | |
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
    // All five slots use type id 1, which is the id given to the string field below.
    let type_ids = ScalarBuffer::<i8>::from(vec![1; 5]);

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", structs.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", strings.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();
    let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
    let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken_union = taken.as_any().downcast_ref::<UnionArray>().unwrap();
    let taken_strings = taken_union
        .child(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // The string child must be reordered by the same indices as the union itself.
    let actual = taken_strings.iter().collect::<Vec<_>>();
    let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
    assert_eq!(expected, actual);
}
2478 | | |
#[test]
fn test_take_union_dense() {
    // Source union: slot `i` reads child `type_ids[i]` at position `offsets[i]`.
    let type_ids = <ScalarBuffer<i8>>::from(vec![0, 1, 1, 0, 0, 1, 0]);
    let offsets = <ScalarBuffer<i32>>::from(vec![0, 0, 1, 1, 2, 2, 3]);
    let ints = UInt32Array::from(vec![10, 20, 30, 40]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);

    // Expected result: the children hold only the selected values, with
    // offsets renumbered to point into those compacted children.
    let expected_type_ids = vec![0, 0, 1, 0, 1, 0];
    let expected_offsets = vec![0, 1, 0, 2, 1, 3];
    let expected_ints = vec![10, 20, 10, 30];
    let expected_strings = vec![Some("a"), None];

    let union_fields = [
        (
            0,
            Arc::new(Field::new("f1", ints.data_type().clone(), true)),
        ),
        (
            1,
            Arc::new(Field::new("f2", strings.data_type().clone(), true)),
        ),
    ]
    .into_iter()
    .collect();

    let array = UnionArray::try_new(
        union_fields,
        type_ids,
        Some(offsets),
        vec![Arc::new(ints), Arc::new(strings)],
    )
    .unwrap();

    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(taken.offsets(), Some(&ScalarBuffer::from(expected_offsets)));
    assert_eq!(taken.type_ids(), &ScalarBuffer::from(expected_type_ids));
    assert_eq!(
        UInt32Array::from(taken.child(0).to_data()),
        UInt32Array::from(expected_ints)
    );
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        StringArray::from(expected_strings)
    );
}
2535 | | |
#[test]
fn test_take_union_dense_using_builder() {
    // Source union, in order: a=1, b=3.0, a=4, a=5, b=2.0
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("a", 1).unwrap();
    source.append::<Float64Type>("b", 3.0).unwrap();
    source.append::<Int32Type>("a", 4).unwrap();
    source.append::<Int32Type>("a", 5).unwrap();
    source.append::<Float64Type>("b", 2.0).unwrap();
    let union = source.build().unwrap();

    let indices = UInt32Array::from(vec![2, 0, 1, 2]);

    // Expected union: source slots 2, 0, 1, 2 appended in that order.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("a", 4).unwrap();
    expected.append::<Int32Type>("a", 1).unwrap();
    expected.append::<Float64Type>("b", 3.0).unwrap();
    expected.append::<Int32Type>("a", 4).unwrap();
    let expected = expected.build().unwrap();

    let actual = take(&union, &indices, None).unwrap();
    assert_eq!(expected.to_data(), actual.to_data());
}
2564 | | |
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    // Dense union with a single child: every slot has type id 0 and offsets 0..5.
    let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
    let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
    let type_ids = ScalarBuffer::from(vec![0_i8; 5]);
    let offsets = ScalarBuffer::from_iter(0_i32..5);

    let array = UnionArray::try_new(fields, type_ids, Some(offsets), vec![ints]).unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&array, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);
}
2582 | | |
#[test]
fn test_take_bytes_offset_overflow() {
    // Take the same 26-byte string `i32::MAX >> 4` (134_217_727) times. The
    // output would need about 3.49e9 bytes of string data, which is more than
    // an `i32` offset can address, so `take` must report an overflow error.
    let indices = Int32Array::from(vec![0; (i32::MAX >> 4) as usize]);
    let text = ('a'..='z').collect::<String>();
    let values = StringArray::from(vec![Some(text)]);
    assert!(matches!(
        take(&values, &indices, None),
        Err(ArrowError::OffsetOverflowError(_))
    ));
}
2593 | | } |