/Users/andrewlamb/Software/arrow-rs/arrow-row/src/lib.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! A comparable row-oriented representation of a collection of [`Array`]. |
19 | | //! |
20 | | //! [`Row`]s are [normalized for sorting], and can therefore be very efficiently [compared], |
21 | | //! using [`memcmp`] under the hood, or used in [non-comparison sorts] such as [radix sort]. |
22 | | //! This makes the row format ideal for implementing efficient multi-column sorting, |
23 | | //! grouping, aggregation, windowing and more, as described in more detail |
24 | | //! [in this blog post](https://arrow.apache.org/blog/2022/11/07/multi-column-sorts-in-arrow-rust-part-1/). |
25 | | //! |
26 | | //! For example, given three input [`Array`], [`RowConverter`] creates byte |
27 | | //! sequences that [compare] the same as when using [`lexsort`]. |
28 | | //! |
29 | | //! ```text |
30 | | //! ┌─────┐ ┌─────┐ ┌─────┐ |
31 | | //! │ │ │ │ │ │ |
32 | | //! ├─────┤ ┌ ┼─────┼ ─ ┼─────┼ ┐ ┏━━━━━━━━━━━━━┓ |
33 | | //! │ │ │ │ │ │ ─────────────▶┃ ┃ |
34 | | //! ├─────┤ └ ┼─────┼ ─ ┼─────┼ ┘ ┗━━━━━━━━━━━━━┛ |
35 | | //! │ │ │ │ │ │ |
36 | | //! └─────┘ └─────┘ └─────┘ |
37 | | //! ... |
38 | | //! ┌─────┐ ┌ ┬─────┬ ─ ┬─────┬ ┐ ┏━━━━━━━━┓ |
39 | | //! │ │ │ │ │ │ ─────────────▶┃ ┃ |
40 | | //! └─────┘ └ ┴─────┴ ─ ┴─────┴ ┘ ┗━━━━━━━━┛ |
41 | | //! UInt64 Utf8 F64 |
42 | | //! |
43 | | //! Input Arrays Row Format |
44 | | //! (Columns) |
45 | | //! ``` |
46 | | //! |
47 | | //! _[`Rows`] must be generated by the same [`RowConverter`] for the comparison |
48 | | //! to be meaningful._ |
49 | | //! |
50 | | //! # Basic Example |
51 | | //! ``` |
52 | | //! # use std::sync::Arc; |
53 | | //! # use arrow_row::{RowConverter, SortField}; |
54 | | //! # use arrow_array::{ArrayRef, Int32Array, StringArray}; |
55 | | //! # use arrow_array::cast::{AsArray, as_string_array}; |
56 | | //! # use arrow_array::types::Int32Type; |
57 | | //! # use arrow_schema::DataType; |
58 | | //! |
59 | | //! let a1 = Arc::new(Int32Array::from_iter_values([-1, -1, 0, 3, 3])) as ArrayRef; |
60 | | //! let a2 = Arc::new(StringArray::from_iter_values(["a", "b", "c", "d", "d"])) as ArrayRef; |
61 | | //! let arrays = vec![a1, a2]; |
62 | | //! |
63 | | //! // Convert arrays to rows |
64 | | //! let converter = RowConverter::new(vec![ |
65 | | //! SortField::new(DataType::Int32), |
66 | | //! SortField::new(DataType::Utf8), |
67 | | //! ]).unwrap(); |
68 | | //! let rows = converter.convert_columns(&arrays).unwrap(); |
69 | | //! |
70 | | //! // Compare rows |
71 | | //! for i in 0..4 { |
72 | | //! assert!(rows.row(i) <= rows.row(i + 1)); |
73 | | //! } |
74 | | //! assert_eq!(rows.row(3), rows.row(4)); |
75 | | //! |
76 | | //! // Convert rows back to arrays |
77 | | //! let converted = converter.convert_rows(&rows).unwrap(); |
78 | | //! assert_eq!(arrays, converted); |
79 | | //! |
80 | | //! // Compare rows from different arrays |
81 | | //! let a1 = Arc::new(Int32Array::from_iter_values([3, 4])) as ArrayRef; |
82 | | //! let a2 = Arc::new(StringArray::from_iter_values(["e", "f"])) as ArrayRef; |
83 | | //! let arrays = vec![a1, a2]; |
84 | | //! let rows2 = converter.convert_columns(&arrays).unwrap(); |
85 | | //! |
86 | | //! assert!(rows.row(4) < rows2.row(0)); |
87 | | //! assert!(rows.row(4) < rows2.row(1)); |
88 | | //! |
89 | | //! // Convert selection of rows back to arrays |
90 | | //! let selection = [rows.row(0), rows2.row(1), rows.row(2), rows2.row(0)]; |
91 | | //! let converted = converter.convert_rows(selection).unwrap(); |
92 | | //! let c1 = converted[0].as_primitive::<Int32Type>(); |
93 | | //! assert_eq!(c1.values(), &[-1, 4, 0, 3]); |
94 | | //! |
95 | | //! let c2 = converted[1].as_string::<i32>(); |
96 | | //! let c2_values: Vec<_> = c2.iter().flatten().collect(); |
97 | | //! assert_eq!(&c2_values, &["a", "f", "c", "e"]); |
98 | | //! ``` |
99 | | //! |
100 | | //! # Lexicographic Sorts (lexsort) |
101 | | //! |
102 | | //! The row format can also be used to implement a fast multi-column / lexicographic sort |
103 | | //! |
104 | | //! ``` |
105 | | //! # use arrow_row::{RowConverter, SortField}; |
106 | | //! # use arrow_array::{ArrayRef, UInt32Array}; |
107 | | //! fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array { |
108 | | //! let fields = arrays |
109 | | //! .iter() |
110 | | //! .map(|a| SortField::new(a.data_type().clone())) |
111 | | //! .collect(); |
112 | | //! let converter = RowConverter::new(fields).unwrap(); |
113 | | //! let rows = converter.convert_columns(arrays).unwrap(); |
114 | | //! let mut sort: Vec<_> = rows.iter().enumerate().collect(); |
115 | | //! sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); |
116 | | //! UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32)) |
117 | | //! } |
118 | | //! ``` |
119 | | //! |
120 | | //! # Flattening Dictionaries |
121 | | //! |
122 | | //! For performance reasons, dictionary arrays are flattened ("hydrated") to their |
123 | | //! underlying values during row conversion. See [the issue] for more details. |
124 | | //! |
125 | | //! This means that the arrays that come out of [`RowConverter::convert_rows`] |
126 | | //! may not have the same data types as the input arrays. For example, encoding |
127 | | //! a `Dictionary<Int8, Utf8>` array will come out as a `Utf8` array. |
128 | | //! |
129 | | //! ``` |
130 | | //! # use arrow_array::{Array, ArrayRef, DictionaryArray}; |
131 | | //! # use arrow_array::types::Int8Type; |
132 | | //! # use arrow_row::{RowConverter, SortField}; |
133 | | //! # use arrow_schema::DataType; |
134 | | //! # use std::sync::Arc; |
135 | | //! // Input is a Dictionary array |
136 | | //! let dict: DictionaryArray::<Int8Type> = ["a", "b", "c", "a", "b"].into_iter().collect(); |
137 | | //! let sort_fields = vec![SortField::new(dict.data_type().clone())]; |
138 | | //! let arrays = vec![Arc::new(dict) as ArrayRef]; |
139 | | //! let converter = RowConverter::new(sort_fields).unwrap(); |
140 | | //! // Convert to rows |
141 | | //! let rows = converter.convert_columns(&arrays).unwrap(); |
142 | | //! let converted = converter.convert_rows(&rows).unwrap(); |
143 | | //! // result was a Utf8 array, not a Dictionary array |
144 | | //! assert_eq!(converted[0].data_type(), &DataType::Utf8); |
145 | | //! ``` |
146 | | //! |
147 | | //! [non-comparison sorts]: https://en.wikipedia.org/wiki/Sorting_algorithm#Non-comparison_sorts |
148 | | //! [radix sort]: https://en.wikipedia.org/wiki/Radix_sort |
149 | | //! [normalized for sorting]: http://wwwlgis.informatik.uni-kl.de/archiv/wwwdvs.informatik.uni-kl.de/courses/DBSREAL/SS2005/Vorlesungsunterlagen/Implementing_Sorting.pdf |
150 | | //! [`memcmp`]: https://www.man7.org/linux/man-pages/man3/memcmp.3.html |
151 | | //! [`lexsort`]: https://docs.rs/arrow-ord/latest/arrow_ord/sort/fn.lexsort.html |
152 | | //! [compared]: PartialOrd |
153 | | //! [compare]: PartialOrd |
154 | | //! [the issue]: https://github.com/apache/arrow-rs/issues/4811 |
155 | | |
156 | | #![doc( |
157 | | html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", |
158 | | html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" |
159 | | )] |
160 | | #![cfg_attr(docsrs, feature(doc_auto_cfg))] |
161 | | #![warn(missing_docs)] |
162 | | use std::cmp::Ordering; |
163 | | use std::hash::{Hash, Hasher}; |
164 | | use std::sync::Arc; |
165 | | |
166 | | use arrow_array::cast::*; |
167 | | use arrow_array::types::ArrowDictionaryKeyType; |
168 | | use arrow_array::*; |
169 | | use arrow_buffer::{ArrowNativeType, Buffer, OffsetBuffer, ScalarBuffer}; |
170 | | use arrow_data::{ArrayData, ArrayDataBuilder}; |
171 | | use arrow_schema::*; |
172 | | use variable::{decode_binary_view, decode_string_view}; |
173 | | |
174 | | use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive}; |
175 | | use crate::list::{compute_lengths_fixed_size_list, encode_fixed_size_list}; |
176 | | use crate::variable::{decode_binary, decode_string}; |
177 | | use arrow_array::types::{Int16Type, Int32Type, Int64Type}; |
178 | | |
179 | | mod fixed; |
180 | | mod list; |
181 | | mod run; |
182 | | mod variable; |
183 | | |
184 | | /// Converts [`ArrayRef`] columns into a [row-oriented](self) format. |
185 | | /// |
186 | | /// *Note: The encoding of the row format may change from release to release.* |
187 | | /// |
188 | | /// ## Overview |
189 | | /// |
190 | | /// The row format is a variable length byte sequence created by |
191 | | /// concatenating the encoded form of each column. The encoding for |
192 | | /// each column depends on its datatype (and sort options). |
193 | | /// |
194 | | /// The encoding is carefully designed in such a way that escaping is |
195 | | /// unnecessary: it is never ambiguous as to whether a byte is part of |
196 | | /// a sentinel (e.g. null) or a value. |
197 | | /// |
198 | | /// ## Unsigned Integer Encoding |
199 | | /// |
200 | | /// A null integer is encoded as a `0_u8`, followed by a zero-ed number of bytes corresponding |
201 | | /// to the integer's length. |
202 | | /// |
203 | | /// A valid integer is encoded as `1_u8`, followed by the big-endian representation of the |
204 | | /// integer. |
205 | | /// |
206 | | /// ```text |
207 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
208 | | /// 3 │03│00│00│00│ │01│00│00│00│03│ |
209 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
210 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
211 | | /// 258 │02│01│00│00│ │01│00│00│01│02│ |
212 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
213 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
214 | | /// 23423 │7F│5B│00│00│ │01│00│00│5B│7F│ |
215 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
216 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
217 | | /// NULL │??│??│??│??│ │00│00│00│00│00│ |
218 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
219 | | /// |
220 | | /// 32-bit (4 bytes) Row Format |
221 | | /// Value Little Endian |
222 | | /// ``` |
223 | | /// |
224 | | /// ## Signed Integer Encoding |
225 | | /// |
226 | | /// Signed integers have their most significant sign bit flipped, and are then encoded in the |
227 | | /// same manner as an unsigned integer. |
228 | | /// |
229 | | /// ```text |
230 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
231 | | /// 5 │05│00│00│00│ │05│00│00│80│ │01│80│00│00│05│ |
232 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
233 | | /// ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┐ ┌──┬──┬──┬──┬──┐ |
234 | | /// -5 │FB│FF│FF│FF│ │FB│FF│FF│7F│ │01│7F│FF│FF│FB│ |
235 | | /// └──┴──┴──┴──┘ └──┴──┴──┴──┘ └──┴──┴──┴──┴──┘ |
236 | | /// |
237 | | /// Value 32-bit (4 bytes) High bit flipped Row Format |
238 | | /// Little Endian |
239 | | /// ``` |
240 | | /// |
241 | | /// ## Float Encoding |
242 | | /// |
243 | | /// Floats are converted from IEEE 754 representation to a signed integer representation |
244 | | /// by flipping all bar the sign bit if they are negative. |
245 | | /// |
246 | | /// They are then encoded in the same manner as a signed integer. |
247 | | /// |
248 | | /// ## Fixed Length Bytes Encoding |
249 | | /// |
250 | | /// Fixed length bytes are encoded in the same fashion as primitive types above. |
251 | | /// |
252 | | /// For a fixed length array of length `n`: |
253 | | /// |
254 | | /// A null is encoded as `0_u8` null sentinel followed by `n` `0_u8` bytes |
255 | | /// |
256 | | /// A valid value is encoded as `1_u8` followed by the value bytes |
257 | | /// |
258 | | /// ## Variable Length Bytes (including Strings) Encoding |
259 | | /// |
260 | | /// A null is encoded as a `0_u8`. |
261 | | /// |
262 | | /// An empty byte array is encoded as `1_u8`. |
263 | | /// |
264 | | /// A non-null, non-empty byte array is encoded as `2_u8` followed by the byte array |
265 | | /// encoded using a block based scheme described below. |
266 | | /// |
267 | | /// The byte array is broken up into fixed-width blocks, each block is written in turn |
268 | | /// to the output, followed by `0xFF_u8`. The final block is padded to 32-bytes |
269 | | /// with `0_u8` and written to the output, followed by the un-padded length in bytes |
270 | | /// of this final block as a `u8`. The first 4 blocks have a length of 8, with subsequent |
271 | | /// blocks using a length of 32, this is to reduce space amplification for small strings. |
272 | | /// |
273 | | /// Note the following example encodings use a block size of 4 bytes for brevity: |
274 | | /// |
275 | | /// ```text |
276 | | /// ┌───┬───┬───┬───┬───┬───┐ |
277 | | /// "MEEP" │02 │'M'│'E'│'E'│'P'│04 │ |
278 | | /// └───┴───┴───┴───┴───┴───┘ |
279 | | /// |
280 | | /// ┌───┐ |
281 | | /// "" │01 | |
282 | | /// └───┘ |
283 | | /// |
284 | | /// NULL ┌───┐ |
285 | | /// │00 │ |
286 | | /// └───┘ |
287 | | /// |
288 | | /// "Defenestration" ┌───┬───┬───┬───┬───┬───┐ |
289 | | /// │02 │'D'│'e'│'f'│'e'│FF │ |
290 | | /// └───┼───┼───┼───┼───┼───┤ |
291 | | /// │'n'│'e'│'s'│'t'│FF │ |
292 | | /// ├───┼───┼───┼───┼───┤ |
293 | | /// │'r'│'a'│'t'│'r'│FF │ |
294 | | /// ├───┼───┼───┼───┼───┤ |
295 | | /// │'a'│'t'│'i'│'o'│FF │ |
296 | | /// ├───┼───┼───┼───┼───┤ |
297 | | /// │'n'│00 │00 │00 │01 │ |
298 | | /// └───┴───┴───┴───┴───┘ |
299 | | /// ``` |
300 | | /// |
301 | | /// This approach is loosely inspired by [COBS] encoding, and chosen over more traditional |
302 | | /// [byte stuffing] as it is more amenable to vectorisation, in particular AVX-256. |
303 | | /// |
304 | | /// ## Dictionary Encoding |
305 | | /// |
306 | | /// Dictionary encoded arrays are hydrated to their underlying values |
307 | | /// |
308 | | /// ## REE Encoding |
309 | | /// |
310 | | /// REE (Run End Encoding) arrays, a form of Run Length Encoding, are hydrated to their underlying values. |
311 | | /// |
312 | | /// ## Struct Encoding |
313 | | /// |
314 | | /// A null is encoded as a `0_u8`. |
315 | | /// |
316 | | /// A valid value is encoded as `1_u8` followed by the row encoding of each child. |
317 | | /// |
318 | | /// This encoding effectively flattens the schema in a depth-first fashion. |
319 | | /// |
320 | | /// For example |
321 | | /// |
322 | | /// ```text |
323 | | /// ┌───────┬────────────────────────┬───────┐ |
324 | | /// │ Int32 │ Struct[Int32, Float32] │ Int32 │ |
325 | | /// └───────┴────────────────────────┴───────┘ |
326 | | /// ``` |
327 | | /// |
328 | | /// Is encoded as |
329 | | /// |
330 | | /// ```text |
331 | | /// ┌───────┬───────────────┬───────┬─────────┬───────┐ |
332 | | /// │ Int32 │ Null Sentinel │ Int32 │ Float32 │ Int32 │ |
333 | | /// └───────┴───────────────┴───────┴─────────┴───────┘ |
334 | | /// ``` |
335 | | /// |
336 | | /// ## List Encoding |
337 | | /// |
338 | | /// Lists are encoded by first encoding all child elements to the row format. |
339 | | /// |
340 | | /// A list value is then encoded as the concatenation of each of the child elements, |
341 | | /// separately encoded using the variable length encoding described above, followed |
342 | | /// by the variable length encoding of an empty byte array. |
343 | | /// |
344 | | /// For example given: |
345 | | /// |
346 | | /// ```text |
347 | | /// [1_u8, 2_u8, 3_u8] |
348 | | /// [1_u8, null] |
349 | | /// [] |
350 | | /// null |
351 | | /// ``` |
352 | | /// |
353 | | /// The elements would be converted to: |
354 | | /// |
355 | | /// ```text |
356 | | /// ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ |
357 | | /// 1 │01│01│ 2 │01│02│ 3 │01│03│ 1 │01│01│ null │00│00│ |
358 | | /// └──┴──┘ └──┴──┘ └──┴──┘ └──┴──┘ └──┴──┘ |
359 | | ///``` |
360 | | /// |
361 | | /// Which would be encoded as |
362 | | /// |
363 | | /// ```text |
364 | | /// ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐ |
365 | | /// [1_u8, 2_u8, 3_u8] │02│01│01│00│00│02│02│01│02│00│00│02│02│01│03│00│00│02│01│ |
366 | | /// └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘ |
367 | | /// └──── 1_u8 ────┘ └──── 2_u8 ────┘ └──── 3_u8 ────┘ |
368 | | /// |
369 | | /// ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐ |
370 | | /// [1_u8, null] │02│01│01│00│00│02│02│00│00│00│00│02│01│ |
371 | | /// └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘ |
372 | | /// └──── 1_u8 ────┘ └──── null ────┘ |
373 | | /// |
374 | | ///``` |
375 | | /// |
376 | | /// With `[]` represented by an empty byte array, and `null` a null byte array. |
377 | | /// |
378 | | /// ## Fixed Size List Encoding |
379 | | /// |
380 | | /// Fixed Size Lists are encoded by first encoding all child elements to the row format. |
381 | | /// |
382 | | /// A non-null list value is then encoded as 0x01 followed by the concatenation of each |
383 | | /// of the child elements. A null list value is encoded as a null marker. |
384 | | /// |
385 | | /// For example given: |
386 | | /// |
387 | | /// ```text |
388 | | /// [1_u8, 2_u8] |
389 | | /// [3_u8, null] |
390 | | /// null |
391 | | /// ``` |
392 | | /// |
393 | | /// The elements would be converted to: |
394 | | /// |
395 | | /// ```text |
396 | | /// ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ |
397 | | /// 1 │01│01│ 2 │01│02│ 3 │01│03│ null │00│00│ |
398 | | /// └──┴──┘ └──┴──┘ └──┴──┘ └──┴──┘ |
399 | | ///``` |
400 | | /// |
401 | | /// Which would be encoded as |
402 | | /// |
403 | | /// ```text |
404 | | /// ┌──┬──┬──┬──┬──┐ |
405 | | /// [1_u8, 2_u8] │01│01│01│01│02│ |
406 | | /// └──┴──┴──┴──┴──┘ |
407 | | /// └ 1 ┘ └ 2 ┘ |
408 | | /// ┌──┬──┬──┬──┬──┐ |
409 | | /// [3_u8, null] │01│01│03│00│00│ |
410 | | /// └──┴──┴──┴──┴──┘ |
411 | | /// └ 1 ┘ └null┘ |
412 | | /// ┌──┐ |
413 | | /// null │00│ |
414 | | /// └──┘ |
415 | | /// |
416 | | ///``` |
417 | | /// |
418 | | /// # Ordering |
419 | | /// |
420 | | /// ## Float Ordering |
421 | | /// |
422 | | /// Floats are totally ordered in accordance to the `totalOrder` predicate as defined |
423 | | /// in the IEEE 754 (2008 revision) floating point standard. |
424 | | /// |
425 | | /// The ordering established by this does not always agree with the |
426 | | /// [`PartialOrd`] and [`PartialEq`] implementations of `f32`. For example, |
427 | | /// they consider negative and positive zero equal, while this does not |
428 | | /// |
429 | | /// ## Null Ordering |
430 | | /// |
431 | | /// The encoding described above will order nulls first, this can be inverted by representing |
432 | | /// nulls as `0xFF_u8` instead of `0_u8` |
433 | | /// |
434 | | /// ## Reverse Column Ordering |
435 | | /// |
436 | | /// The order of a given column can be reversed by negating the encoded bytes of non-null values |
437 | | /// |
438 | | /// [COBS]: https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing |
439 | | /// [byte stuffing]: https://en.wikipedia.org/wiki/High-Level_Data_Link_Control#Asynchronous_framing |
440 | | #[derive(Debug)] |
441 | | pub struct RowConverter { |
442 | | fields: Arc<[SortField]>, |
443 | | /// State for codecs |
444 | | codecs: Vec<Codec>, |
445 | | } |
446 | | |
447 | | #[derive(Debug)] |
448 | | enum Codec { |
449 | | /// No additional codec state is necessary |
450 | | Stateless, |
451 | | /// A row converter for the dictionary values |
452 | | /// and the encoding of a row containing only nulls |
453 | | Dictionary(RowConverter, OwnedRow), |
454 | | /// A row converter for the child fields |
455 | | /// and the encoding of a row containing only nulls |
456 | | Struct(RowConverter, OwnedRow), |
457 | | /// A row converter for the child field |
458 | | List(RowConverter), |
459 | | /// A row converter for the values array of a run-end encoded array |
460 | | RunEndEncoded(RowConverter), |
461 | | } |
462 | | |
463 | | impl Codec { |
464 | 0 | fn new(sort_field: &SortField) -> Result<Self, ArrowError> { |
465 | 0 | match &sort_field.data_type { |
466 | 0 | DataType::Dictionary(_, values) => { |
467 | 0 | let sort_field = |
468 | 0 | SortField::new_with_options(values.as_ref().clone(), sort_field.options); |
469 | | |
470 | 0 | let converter = RowConverter::new(vec![sort_field])?; |
471 | 0 | let null_array = new_null_array(values.as_ref(), 1); |
472 | 0 | let nulls = converter.convert_columns(&[null_array])?; |
473 | | |
474 | 0 | let owned = OwnedRow { |
475 | 0 | data: nulls.buffer.into(), |
476 | 0 | config: nulls.config, |
477 | 0 | }; |
478 | 0 | Ok(Self::Dictionary(converter, owned)) |
479 | | } |
480 | 0 | DataType::RunEndEncoded(_, values) => { |
481 | | // Similar to List implementation |
482 | 0 | let options = SortOptions { |
483 | 0 | descending: false, |
484 | 0 | nulls_first: sort_field.options.nulls_first != sort_field.options.descending, |
485 | 0 | }; |
486 | | |
487 | 0 | let field = SortField::new_with_options(values.data_type().clone(), options); |
488 | 0 | let converter = RowConverter::new(vec![field])?; |
489 | 0 | Ok(Self::RunEndEncoded(converter)) |
490 | | } |
491 | 0 | d if !d.is_nested() => Ok(Self::Stateless), |
492 | 0 | DataType::List(f) | DataType::LargeList(f) => { |
493 | | // The encoded contents will be inverted if descending is set to true |
494 | | // As such we set `descending` to false and negate nulls first if it |
495 | | // is set to true |
496 | 0 | let options = SortOptions { |
497 | 0 | descending: false, |
498 | 0 | nulls_first: sort_field.options.nulls_first != sort_field.options.descending, |
499 | 0 | }; |
500 | | |
501 | 0 | let field = SortField::new_with_options(f.data_type().clone(), options); |
502 | 0 | let converter = RowConverter::new(vec![field])?; |
503 | 0 | Ok(Self::List(converter)) |
504 | | } |
505 | 0 | DataType::FixedSizeList(f, _) => { |
506 | 0 | let field = SortField::new_with_options(f.data_type().clone(), sort_field.options); |
507 | 0 | let converter = RowConverter::new(vec![field])?; |
508 | 0 | Ok(Self::List(converter)) |
509 | | } |
510 | 0 | DataType::Struct(f) => { |
511 | 0 | let sort_fields = f |
512 | 0 | .iter() |
513 | 0 | .map(|x| SortField::new_with_options(x.data_type().clone(), sort_field.options)) |
514 | 0 | .collect(); |
515 | | |
516 | 0 | let converter = RowConverter::new(sort_fields)?; |
517 | 0 | let nulls: Vec<_> = f.iter().map(|x| new_null_array(x.data_type(), 1)).collect(); |
518 | | |
519 | 0 | let nulls = converter.convert_columns(&nulls)?; |
520 | 0 | let owned = OwnedRow { |
521 | 0 | data: nulls.buffer.into(), |
522 | 0 | config: nulls.config, |
523 | 0 | }; |
524 | | |
525 | 0 | Ok(Self::Struct(converter, owned)) |
526 | | } |
527 | 0 | _ => Err(ArrowError::NotYetImplemented(format!( |
528 | 0 | "not yet implemented: {:?}", |
529 | 0 | sort_field.data_type |
530 | 0 | ))), |
531 | | } |
532 | 0 | } |
533 | | |
534 | 0 | fn encoder(&self, array: &dyn Array) -> Result<Encoder<'_>, ArrowError> { |
535 | 0 | match self { |
536 | 0 | Codec::Stateless => Ok(Encoder::Stateless), |
537 | 0 | Codec::Dictionary(converter, nulls) => { |
538 | 0 | let values = array.as_any_dictionary().values().clone(); |
539 | 0 | let rows = converter.convert_columns(&[values])?; |
540 | 0 | Ok(Encoder::Dictionary(rows, nulls.row())) |
541 | | } |
542 | 0 | Codec::Struct(converter, null) => { |
543 | 0 | let v = as_struct_array(array); |
544 | 0 | let rows = converter.convert_columns(v.columns())?; |
545 | 0 | Ok(Encoder::Struct(rows, null.row())) |
546 | | } |
547 | 0 | Codec::List(converter) => { |
548 | 0 | let values = match array.data_type() { |
549 | | DataType::List(_) => { |
550 | 0 | let list_array = as_list_array(array); |
551 | 0 | let first_offset = list_array.offsets()[0] as usize; |
552 | 0 | let last_offset = |
553 | 0 | list_array.offsets()[list_array.offsets().len() - 1] as usize; |
554 | | |
555 | | // values can include more data than referenced in the ListArray, only encode |
556 | | // the referenced values. |
557 | 0 | list_array |
558 | 0 | .values() |
559 | 0 | .slice(first_offset, last_offset - first_offset) |
560 | | } |
561 | | DataType::LargeList(_) => { |
562 | 0 | let list_array = as_large_list_array(array); |
563 | | |
564 | 0 | let first_offset = list_array.offsets()[0] as usize; |
565 | 0 | let last_offset = |
566 | 0 | list_array.offsets()[list_array.offsets().len() - 1] as usize; |
567 | | |
568 | | // values can include more data than referenced in the LargeListArray, only encode |
569 | | // the referenced values. |
570 | 0 | list_array |
571 | 0 | .values() |
572 | 0 | .slice(first_offset, last_offset - first_offset) |
573 | | } |
574 | | DataType::FixedSizeList(_, _) => { |
575 | 0 | as_fixed_size_list_array(array).values().clone() |
576 | | } |
577 | 0 | _ => unreachable!(), |
578 | | }; |
579 | 0 | let rows = converter.convert_columns(&[values])?; |
580 | 0 | Ok(Encoder::List(rows)) |
581 | | } |
582 | 0 | Codec::RunEndEncoded(converter) => { |
583 | 0 | let values = match array.data_type() { |
584 | 0 | DataType::RunEndEncoded(r, _) => match r.data_type() { |
585 | 0 | DataType::Int16 => array.as_run::<Int16Type>().values(), |
586 | 0 | DataType::Int32 => array.as_run::<Int32Type>().values(), |
587 | 0 | DataType::Int64 => array.as_run::<Int64Type>().values(), |
588 | 0 | _ => unreachable!("Unsupported run end index type: {r:?}"), |
589 | | }, |
590 | 0 | _ => unreachable!(), |
591 | | }; |
592 | 0 | let rows = converter.convert_columns(std::slice::from_ref(values))?; |
593 | 0 | Ok(Encoder::RunEndEncoded(rows)) |
594 | | } |
595 | | } |
596 | 0 | } |
597 | | |
598 | 0 | fn size(&self) -> usize { |
599 | 0 | match self { |
600 | 0 | Codec::Stateless => 0, |
601 | 0 | Codec::Dictionary(converter, nulls) => converter.size() + nulls.data.len(), |
602 | 0 | Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(), |
603 | 0 | Codec::List(converter) => converter.size(), |
604 | 0 | Codec::RunEndEncoded(converter) => converter.size(), |
605 | | } |
606 | 0 | } |
607 | | } |
608 | | |
609 | | #[derive(Debug)] |
610 | | enum Encoder<'a> { |
611 | | /// No additional encoder state is necessary |
612 | | Stateless, |
613 | | /// The encoding of the child array and the encoding of a null row |
614 | | Dictionary(Rows, Row<'a>), |
615 | | /// The row encoding of the child arrays and the encoding of a null row |
616 | | /// |
617 | | /// It is necessary to encode to a temporary [`Rows`] to avoid serializing |
618 | | /// values that are masked by a null in the parent StructArray, otherwise |
619 | | /// this would establish an ordering between semantically null values |
620 | | Struct(Rows, Row<'a>), |
621 | | /// The row encoding of the child array |
622 | | List(Rows), |
623 | | /// The row encoding of the values array |
624 | | RunEndEncoded(Rows), |
625 | | } |
626 | | |
627 | | /// Configure the data type and sort order for a given column |
628 | | #[derive(Debug, Clone, PartialEq, Eq)] |
629 | | pub struct SortField { |
630 | | /// Sort options |
631 | | options: SortOptions, |
632 | | /// Data type |
633 | | data_type: DataType, |
634 | | } |
635 | | |
636 | | impl SortField { |
637 | | /// Create a new column with the given data type |
638 | 0 | pub fn new(data_type: DataType) -> Self { |
639 | 0 | Self::new_with_options(data_type, Default::default()) |
640 | 0 | } |
641 | | |
642 | | /// Create a new column with the given data type and [`SortOptions`] |
643 | 0 | pub fn new_with_options(data_type: DataType, options: SortOptions) -> Self { |
644 | 0 | Self { options, data_type } |
645 | 0 | } |
646 | | |
647 | | /// Return size of this instance in bytes. |
648 | | /// |
649 | | /// Includes the size of `Self`. |
650 | 0 | pub fn size(&self) -> usize { |
651 | 0 | self.data_type.size() + std::mem::size_of::<Self>() - std::mem::size_of::<DataType>() |
652 | 0 | } |
653 | | } |
654 | | |
655 | | impl RowConverter { |
656 | | /// Create a new [`RowConverter`] with the provided schema |
657 | 0 | pub fn new(fields: Vec<SortField>) -> Result<Self, ArrowError> { |
658 | 0 | if !Self::supports_fields(&fields) { |
659 | 0 | return Err(ArrowError::NotYetImplemented(format!( |
660 | 0 | "Row format support not yet implemented for: {fields:?}" |
661 | 0 | ))); |
662 | 0 | } |
663 | | |
664 | 0 | let codecs = fields.iter().map(Codec::new).collect::<Result<_, _>>()?; |
665 | 0 | Ok(Self { |
666 | 0 | fields: fields.into(), |
667 | 0 | codecs, |
668 | 0 | }) |
669 | 0 | } |
670 | | |
671 | | /// Check if the given fields are supported by the row format. |
672 | 0 | pub fn supports_fields(fields: &[SortField]) -> bool { |
673 | 0 | fields.iter().all(|x| Self::supports_datatype(&x.data_type)) |
674 | 0 | } |
675 | | |
676 | 0 | fn supports_datatype(d: &DataType) -> bool { |
677 | 0 | match d { |
678 | 0 | _ if !d.is_nested() => true, |
679 | 0 | DataType::List(f) | DataType::LargeList(f) | DataType::FixedSizeList(f, _) => { |
680 | 0 | Self::supports_datatype(f.data_type()) |
681 | | } |
682 | 0 | DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())), |
683 | 0 | DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()), |
684 | 0 | _ => false, |
685 | | } |
686 | 0 | } |
687 | | |
688 | | /// Convert [`ArrayRef`] columns into [`Rows`] |
689 | | /// |
690 | | /// See [`Row`] for information on when [`Row`] can be compared |
691 | | /// |
692 | | /// See [`Self::convert_rows`] for converting [`Rows`] back into [`ArrayRef`] |
693 | | /// |
694 | | /// # Panics |
695 | | /// |
696 | | /// Panics if the schema of `columns` does not match that provided to [`RowConverter::new`] |
697 | 0 | pub fn convert_columns(&self, columns: &[ArrayRef]) -> Result<Rows, ArrowError> { |
698 | 0 | let num_rows = columns.first().map(|x| x.len()).unwrap_or(0); |
699 | 0 | let mut rows = self.empty_rows(num_rows, 0); |
700 | 0 | self.append(&mut rows, columns)?; |
701 | 0 | Ok(rows) |
702 | 0 | } |
703 | | |
704 | | /// Convert [`ArrayRef`] columns appending to an existing [`Rows`] |
705 | | /// |
706 | | /// See [`Row`] for information on when [`Row`] can be compared |
707 | | /// |
708 | | /// # Panics |
709 | | /// |
710 | | /// Panics if |
711 | | /// * The schema of `columns` does not match that provided to [`RowConverter::new`] |
712 | | /// * The provided [`Rows`] were not created by this [`RowConverter`] |
713 | | /// |
714 | | /// ``` |
715 | | /// # use std::sync::Arc; |
716 | | /// # use std::collections::HashSet; |
717 | | /// # use arrow_array::cast::AsArray; |
718 | | /// # use arrow_array::StringArray; |
719 | | /// # use arrow_row::{Row, RowConverter, SortField}; |
720 | | /// # use arrow_schema::DataType; |
721 | | /// # |
722 | | /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
723 | | /// let a1 = StringArray::from(vec!["hello", "world"]); |
724 | | /// let a2 = StringArray::from(vec!["a", "a", "hello"]); |
725 | | /// |
726 | | /// let mut rows = converter.empty_rows(5, 128); |
727 | | /// converter.append(&mut rows, &[Arc::new(a1)]).unwrap(); |
728 | | /// converter.append(&mut rows, &[Arc::new(a2)]).unwrap(); |
729 | | /// |
730 | | /// let back = converter.convert_rows(&rows).unwrap(); |
731 | | /// let values: Vec<_> = back[0].as_string::<i32>().iter().map(Option::unwrap).collect(); |
732 | | /// assert_eq!(&values, &["hello", "world", "a", "a", "hello"]); |
733 | | /// ``` |
    pub fn append(&self, rows: &mut Rows, columns: &[ArrayRef]) -> Result<(), ArrowError> {
        // Provenance check: `rows` must share this converter's field config
        // (pointer equality on the Arc of fields)
        assert!(
            Arc::ptr_eq(&rows.config.fields, &self.fields),
            "rows were not produced by this RowConverter"
        );

        // The number of columns must match the converter's schema
        if columns.len() != self.fields.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Incorrect number of arrays provided to RowConverter, expected {} got {}",
                self.fields.len(),
                columns.len()
            )));
        }
        // All columns must have the same number of rows
        for colum in columns.iter().skip(1) {
            if colum.len() != columns[0].len() {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "RowConverter columns must all have the same length, expected {} got {}",
                    columns[0].len(),
                    colum.len()
                )));
            }
        }

        // Build a per-column encoder, checking each column's data type against
        // the schema provided at construction
        let encoders = columns
            .iter()
            .zip(&self.codecs)
            .zip(self.fields.iter())
            .map(|((column, codec), field)| {
                if !column.data_type().equals_datatype(&field.data_type) {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "RowConverter column schema mismatch, expected {} got {}",
                        field.data_type,
                        column.data_type()
                    )));
                }
                codec.encoder(column.as_ref())
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Compute the encoded length of each new row, extend `offsets` with
        // (shifted) positions for the new rows, and size the buffer to fit
        let write_offset = rows.num_rows();
        let lengths = row_lengths(columns, &encoders);
        let total = lengths.extend_offsets(rows.offsets[write_offset], &mut rows.offsets);
        rows.buffer.resize(total, 0);

        for ((column, field), encoder) in columns.iter().zip(self.fields.iter()).zip(encoders) {
            // We encode a column at a time to minimise dispatch overheads
            encode_column(
                &mut rows.buffer,
                &mut rows.offsets[write_offset..],
                column.as_ref(),
                field.options,
                &encoder,
            )
        }

        // In debug builds, verify the final offsets are consistent with the
        // buffer length and are monotonically non-decreasing
        if cfg!(debug_assertions) {
            assert_eq!(*rows.offsets.last().unwrap(), rows.buffer.len());
            rows.offsets
                .windows(2)
                .for_each(|w| assert!(w[0] <= w[1], "offsets should be monotonic"));
        }

        Ok(())
    }
798 | | |
799 | | /// Convert [`Rows`] columns into [`ArrayRef`] |
800 | | /// |
801 | | /// See [`Self::convert_columns`] for converting [`ArrayRef`] into [`Rows`] |
802 | | /// |
803 | | /// # Panics |
804 | | /// |
805 | | /// Panics if the rows were not produced by this [`RowConverter`] |
806 | | pub fn convert_rows<'a, I>(&self, rows: I) -> Result<Vec<ArrayRef>, ArrowError> |
807 | | where |
808 | | I: IntoIterator<Item = Row<'a>>, |
809 | | { |
810 | | let mut validate_utf8 = false; |
811 | | let mut rows: Vec<_> = rows |
812 | | .into_iter() |
813 | | .map(|row| { |
814 | | assert!( |
815 | | Arc::ptr_eq(&row.config.fields, &self.fields), |
816 | | "rows were not produced by this RowConverter" |
817 | | ); |
818 | | validate_utf8 |= row.config.validate_utf8; |
819 | | row.data |
820 | | }) |
821 | | .collect(); |
822 | | |
823 | | // SAFETY |
824 | | // We have validated that the rows came from this [`RowConverter`] |
825 | | // and therefore must be valid |
826 | | let result = unsafe { self.convert_raw(&mut rows, validate_utf8) }?; |
827 | | |
828 | | if cfg!(test) { |
829 | | for (i, row) in rows.iter().enumerate() { |
830 | | if !row.is_empty() { |
831 | | return Err(ArrowError::InvalidArgumentError(format!( |
832 | | "Codecs {codecs:?} did not consume all bytes for row {i}, remaining bytes: {row:?}", |
833 | | codecs = &self.codecs |
834 | | ))); |
835 | | } |
836 | | } |
837 | | } |
838 | | |
839 | | Ok(result) |
840 | | } |
841 | | |
842 | | /// Returns an empty [`Rows`] with capacity for `row_capacity` rows with |
843 | | /// a total length of `data_capacity` |
844 | | /// |
845 | | /// This can be used to buffer a selection of [`Row`] |
846 | | /// |
847 | | /// ``` |
848 | | /// # use std::sync::Arc; |
849 | | /// # use std::collections::HashSet; |
850 | | /// # use arrow_array::cast::AsArray; |
851 | | /// # use arrow_array::StringArray; |
852 | | /// # use arrow_row::{Row, RowConverter, SortField}; |
853 | | /// # use arrow_schema::DataType; |
854 | | /// # |
855 | | /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
856 | | /// let array = StringArray::from(vec!["hello", "world", "a", "a", "hello"]); |
857 | | /// |
858 | | /// // Convert to row format and deduplicate |
859 | | /// let converted = converter.convert_columns(&[Arc::new(array)]).unwrap(); |
860 | | /// let mut distinct_rows = converter.empty_rows(3, 100); |
861 | | /// let mut dedup: HashSet<Row> = HashSet::with_capacity(3); |
862 | | /// converted.iter().filter(|row| dedup.insert(*row)).for_each(|row| distinct_rows.push(row)); |
863 | | /// |
864 | | /// // Note: we could skip buffering and feed the filtered iterator directly |
865 | | /// // into convert_rows, this is done for demonstration purposes only |
866 | | /// let distinct = converter.convert_rows(&distinct_rows).unwrap(); |
867 | | /// let values: Vec<_> = distinct[0].as_string::<i32>().iter().map(Option::unwrap).collect(); |
868 | | /// assert_eq!(&values, &["hello", "world", "a"]); |
869 | | /// ``` |
870 | 0 | pub fn empty_rows(&self, row_capacity: usize, data_capacity: usize) -> Rows { |
871 | 0 | let mut offsets = Vec::with_capacity(row_capacity.saturating_add(1)); |
872 | 0 | offsets.push(0); |
873 | | |
874 | 0 | Rows { |
875 | 0 | offsets, |
876 | 0 | buffer: Vec::with_capacity(data_capacity), |
877 | 0 | config: RowConfig { |
878 | 0 | fields: self.fields.clone(), |
879 | 0 | validate_utf8: false, |
880 | 0 | }, |
881 | 0 | } |
882 | 0 | } |
883 | | |
884 | | /// Create a new [Rows] instance from the given binary data. |
885 | | /// |
886 | | /// ``` |
887 | | /// # use std::sync::Arc; |
888 | | /// # use std::collections::HashSet; |
889 | | /// # use arrow_array::cast::AsArray; |
890 | | /// # use arrow_array::StringArray; |
891 | | /// # use arrow_row::{OwnedRow, Row, RowConverter, RowParser, SortField}; |
892 | | /// # use arrow_schema::DataType; |
893 | | /// # |
894 | | /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
895 | | /// let array = StringArray::from(vec!["hello", "world", "a", "a", "hello"]); |
896 | | /// let rows = converter.convert_columns(&[Arc::new(array)]).unwrap(); |
897 | | /// |
898 | | /// // We can convert rows into binary format and back in batch. |
899 | | /// let values: Vec<OwnedRow> = rows.iter().map(|r| r.owned()).collect(); |
900 | | /// let binary = rows.try_into_binary().expect("known-small array"); |
901 | | /// let converted = converter.from_binary(binary.clone()); |
902 | | /// assert!(converted.iter().eq(values.iter().map(|r| r.row()))); |
903 | | /// ``` |
904 | | /// |
905 | | /// # Panics |
906 | | /// |
907 | | /// This function expects the passed [BinaryArray] to contain valid row data as produced by this |
908 | | /// [RowConverter]. It will panic if any rows are null. Operations on the returned [Rows] may |
909 | | /// panic if the data is malformed. |
910 | 0 | pub fn from_binary(&self, array: BinaryArray) -> Rows { |
911 | 0 | assert_eq!( |
912 | 0 | array.null_count(), |
913 | | 0, |
914 | 0 | "can't construct Rows instance from array with nulls" |
915 | | ); |
916 | | Rows { |
917 | 0 | buffer: array.values().to_vec(), |
918 | 0 | offsets: array.offsets().iter().map(|&i| i.as_usize()).collect(), |
919 | 0 | config: RowConfig { |
920 | 0 | fields: Arc::clone(&self.fields), |
921 | 0 | validate_utf8: true, |
922 | 0 | }, |
923 | | } |
924 | 0 | } |
925 | | |
926 | | /// Convert raw bytes into [`ArrayRef`] |
927 | | /// |
928 | | /// # Safety |
929 | | /// |
930 | | /// `rows` must contain valid data for this [`RowConverter`] |
931 | 0 | unsafe fn convert_raw( |
932 | 0 | &self, |
933 | 0 | rows: &mut [&[u8]], |
934 | 0 | validate_utf8: bool, |
935 | 0 | ) -> Result<Vec<ArrayRef>, ArrowError> { |
936 | 0 | self.fields |
937 | 0 | .iter() |
938 | 0 | .zip(&self.codecs) |
939 | 0 | .map(|(field, codec)| decode_column(field, rows, codec, validate_utf8)) |
940 | 0 | .collect() |
941 | 0 | } |
942 | | |
943 | | /// Returns a [`RowParser`] that can be used to parse [`Row`] from bytes |
944 | 0 | pub fn parser(&self) -> RowParser { |
945 | 0 | RowParser::new(Arc::clone(&self.fields)) |
946 | 0 | } |
947 | | |
948 | | /// Returns the size of this instance in bytes |
949 | | /// |
950 | | /// Includes the size of `Self`. |
951 | 0 | pub fn size(&self) -> usize { |
952 | 0 | std::mem::size_of::<Self>() |
953 | 0 | + self.fields.iter().map(|x| x.size()).sum::<usize>() |
954 | 0 | + self.codecs.capacity() * std::mem::size_of::<Codec>() |
955 | 0 | + self.codecs.iter().map(Codec::size).sum::<usize>() |
956 | 0 | } |
957 | | } |
958 | | |
/// A [`RowParser`] can be created from a [`RowConverter`] and used to parse bytes to [`Row`]
#[derive(Debug)]
pub struct RowParser {
    // Field schema shared with the originating RowConverter; validate_utf8 is
    // always true for parsed bytes (see RowParser::new)
    config: RowConfig,
}
964 | | |
965 | | impl RowParser { |
966 | 0 | fn new(fields: Arc<[SortField]>) -> Self { |
967 | 0 | Self { |
968 | 0 | config: RowConfig { |
969 | 0 | fields, |
970 | 0 | validate_utf8: true, |
971 | 0 | }, |
972 | 0 | } |
973 | 0 | } |
974 | | |
975 | | /// Creates a [`Row`] from the provided `bytes`. |
976 | | /// |
977 | | /// `bytes` must be a [`Row`] produced by the [`RowConverter`] associated with |
978 | | /// this [`RowParser`], otherwise subsequent operations with the produced [`Row`] may panic |
979 | 0 | pub fn parse<'a>(&'a self, bytes: &'a [u8]) -> Row<'a> { |
980 | 0 | Row { |
981 | 0 | data: bytes, |
982 | 0 | config: &self.config, |
983 | 0 | } |
984 | 0 | } |
985 | | } |
986 | | |
/// The config of a given set of [`Row`]
///
/// Shared between a [`RowConverter`] and every [`Rows`]/[`Row`] it produces;
/// `Arc::ptr_eq` on `fields` is used to verify a row's provenance.
#[derive(Debug, Clone)]
struct RowConfig {
    /// The schema for these rows
    fields: Arc<[SortField]>,
    /// Whether to run UTF-8 validation when converting to arrow arrays
    validate_utf8: bool,
}
995 | | |
/// A row-oriented representation of arrow data, that is normalized for comparison.
///
/// See the [module level documentation](self) and [`RowConverter`] for more details.
#[derive(Debug)]
pub struct Rows {
    /// Underlying row bytes
    buffer: Vec<u8>,
    /// Row `i` has data `&buffer[offsets[i]..offsets[i+1]]`
    ///
    /// Invariant: always contains at least one entry (a leading 0), so
    /// `offsets.len() == num_rows + 1`
    offsets: Vec<usize>,
    /// The config for these rows
    config: RowConfig,
}
1008 | | |
impl Rows {
    /// Append a [`Row`] to this [`Rows`]
    ///
    /// # Panics
    ///
    /// Panics if `row` was not produced by a converter sharing this [`Rows`]'
    /// field configuration
    pub fn push(&mut self, row: Row<'_>) {
        assert!(
            Arc::ptr_eq(&row.config.fields, &self.config.fields),
            "row was not produced by this RowConverter"
        );
        // If the appended row requires UTF-8 validation on decode, so do we
        self.config.validate_utf8 |= row.config.validate_utf8;
        self.buffer.extend_from_slice(row.data);
        self.offsets.push(self.buffer.len())
    }

    /// Returns the row at index `row`
    ///
    /// # Panics
    ///
    /// Panics if `row` is not less than [`Self::num_rows`]
    pub fn row(&self, row: usize) -> Row<'_> {
        // offsets.len() == num_rows + 1, so this checks row < num_rows
        assert!(row + 1 < self.offsets.len());
        unsafe { self.row_unchecked(row) }
    }

    /// Returns the row at `index` without bounds checking
    ///
    /// # Safety
    /// Caller must ensure that `index` is less than the number of offsets (#rows + 1)
    pub unsafe fn row_unchecked(&self, index: usize) -> Row<'_> {
        // SAFETY (per fn contract): index + 1 is a valid offset index, and
        // offsets are valid monotonic indexes into buffer
        let end = unsafe { self.offsets.get_unchecked(index + 1) };
        let start = unsafe { self.offsets.get_unchecked(index) };
        let data = unsafe { self.buffer.get_unchecked(*start..*end) };
        Row {
            data,
            config: &self.config,
        }
    }

    /// Sets the length of this [`Rows`] to 0
    pub fn clear(&mut self) {
        // Keep the leading 0 so the offsets.len() == num_rows + 1 invariant holds
        self.offsets.truncate(1);
        self.buffer.clear();
    }

    /// Returns the number of [`Row`] in this [`Rows`]
    pub fn num_rows(&self) -> usize {
        self.offsets.len() - 1
    }

    /// Returns an iterator over the [`Row`] in this [`Rows`]
    pub fn iter(&self) -> RowsIter<'_> {
        self.into_iter()
    }

    /// Returns the size of this instance in bytes
    ///
    /// Includes the size of `Self`.
    pub fn size(&self) -> usize {
        // Size of fields is accounted for as part of RowConverter
        std::mem::size_of::<Self>()
            + self.buffer.len()
            + self.offsets.len() * std::mem::size_of::<usize>()
    }

    /// Create a [BinaryArray] from the [Rows] data without reallocating the
    /// underlying bytes.
    ///
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use std::collections::HashSet;
    /// # use arrow_array::cast::AsArray;
    /// # use arrow_array::StringArray;
    /// # use arrow_row::{OwnedRow, Row, RowConverter, RowParser, SortField};
    /// # use arrow_schema::DataType;
    /// #
    /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
    /// let array = StringArray::from(vec!["hello", "world", "a", "a", "hello"]);
    /// let rows = converter.convert_columns(&[Arc::new(array)]).unwrap();
    ///
    /// // We can convert rows into binary format and back.
    /// let values: Vec<OwnedRow> = rows.iter().map(|r| r.owned()).collect();
    /// let binary = rows.try_into_binary().expect("known-small array");
    /// let parser = converter.parser();
    /// let parsed: Vec<OwnedRow> =
    ///     binary.iter().flatten().map(|b| parser.parse(b).owned()).collect();
    /// assert_eq!(values, parsed);
    /// ```
    ///
    /// # Errors
    ///
    /// This function will return an error if there is more data than can be stored in
    /// a [BinaryArray] -- i.e. if the total data size is more than 2GiB.
    pub fn try_into_binary(self) -> Result<BinaryArray, ArrowError> {
        if self.buffer.len() > i32::MAX as usize {
            return Err(ArrowError::InvalidArgumentError(format!(
                "{}-byte rows buffer too long to convert into a i32-indexed BinaryArray",
                self.buffer.len()
            )));
        }
        // We've checked that the buffer length fits in an i32; so all offsets into that buffer should fit as well.
        let offsets_scalar = ScalarBuffer::from_iter(self.offsets.into_iter().map(i32::usize_as));
        // SAFETY: offsets buffer is nonempty, monotonically increasing, and all represent valid indexes into buffer.
        let array = unsafe {
            BinaryArray::new_unchecked(
                OffsetBuffer::new_unchecked(offsets_scalar),
                Buffer::from_vec(self.buffer),
                None,
            )
        };
        Ok(array)
    }
}
1116 | | |
1117 | | impl<'a> IntoIterator for &'a Rows { |
1118 | | type Item = Row<'a>; |
1119 | | type IntoIter = RowsIter<'a>; |
1120 | | |
1121 | 0 | fn into_iter(self) -> Self::IntoIter { |
1122 | 0 | RowsIter { |
1123 | 0 | rows: self, |
1124 | 0 | start: 0, |
1125 | 0 | end: self.num_rows(), |
1126 | 0 | } |
1127 | 0 | } |
1128 | | } |
1129 | | |
/// An iterator over [`Rows`]
#[derive(Debug)]
pub struct RowsIter<'a> {
    /// The rows being iterated over
    rows: &'a Rows,
    /// Index of the next row to yield from the front (inclusive)
    start: usize,
    /// End of the remaining range (exclusive bound)
    end: usize,
}
1137 | | |
1138 | | impl<'a> Iterator for RowsIter<'a> { |
1139 | | type Item = Row<'a>; |
1140 | | |
1141 | 0 | fn next(&mut self) -> Option<Self::Item> { |
1142 | 0 | if self.end == self.start { |
1143 | 0 | return None; |
1144 | 0 | } |
1145 | | |
1146 | | // SAFETY: We have checked that `start` is less than `end` |
1147 | 0 | let row = unsafe { self.rows.row_unchecked(self.start) }; |
1148 | 0 | self.start += 1; |
1149 | 0 | Some(row) |
1150 | 0 | } |
1151 | | |
1152 | 0 | fn size_hint(&self) -> (usize, Option<usize>) { |
1153 | 0 | let len = self.len(); |
1154 | 0 | (len, Some(len)) |
1155 | 0 | } |
1156 | | } |
1157 | | |
1158 | | impl ExactSizeIterator for RowsIter<'_> { |
1159 | 0 | fn len(&self) -> usize { |
1160 | 0 | self.end - self.start |
1161 | 0 | } |
1162 | | } |
1163 | | |
1164 | | impl DoubleEndedIterator for RowsIter<'_> { |
1165 | 0 | fn next_back(&mut self) -> Option<Self::Item> { |
1166 | 0 | if self.end == self.start { |
1167 | 0 | return None; |
1168 | 0 | } |
1169 | | // Safety: We have checked that `start` is less than `end` |
1170 | 0 | let row = unsafe { self.rows.row_unchecked(self.end) }; |
1171 | 0 | self.end -= 1; |
1172 | 0 | Some(row) |
1173 | 0 | } |
1174 | | } |
1175 | | |
/// A comparable representation of a row.
///
/// See the [module level documentation](self) for more details.
///
/// Two [`Row`] can only be compared if they both belong to [`Rows`]
/// returned by calls to [`RowConverter::convert_columns`] on the same
/// [`RowConverter`]. If different [`RowConverter`]s are used, any
/// ordering established by comparing the [`Row`] is arbitrary.
#[derive(Debug, Copy, Clone)]
pub struct Row<'a> {
    /// The encoded bytes of this row, borrowed from the owning [`Rows`]
    data: &'a [u8],
    /// The configuration of the converter that produced this row
    config: &'a RowConfig,
}
1189 | | |
1190 | | impl<'a> Row<'a> { |
1191 | | /// Create owned version of the row to detach it from the shared [`Rows`]. |
1192 | 0 | pub fn owned(&self) -> OwnedRow { |
1193 | 0 | OwnedRow { |
1194 | 0 | data: self.data.into(), |
1195 | 0 | config: self.config.clone(), |
1196 | 0 | } |
1197 | 0 | } |
1198 | | |
1199 | | /// The row's bytes, with the lifetime of the underlying data. |
1200 | 0 | pub fn data(&self) -> &'a [u8] { |
1201 | 0 | self.data |
1202 | 0 | } |
1203 | | } |
1204 | | |
1205 | | // Manually derive these as don't wish to include `fields` |
1206 | | |
1207 | | impl PartialEq for Row<'_> { |
1208 | | #[inline] |
1209 | | fn eq(&self, other: &Self) -> bool { |
1210 | | self.data.eq(other.data) |
1211 | | } |
1212 | | } |
1213 | | |
1214 | | impl Eq for Row<'_> {} |
1215 | | |
1216 | | impl PartialOrd for Row<'_> { |
1217 | | #[inline] |
1218 | | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
1219 | | Some(self.cmp(other)) |
1220 | | } |
1221 | | } |
1222 | | |
1223 | | impl Ord for Row<'_> { |
1224 | | #[inline] |
1225 | | fn cmp(&self, other: &Self) -> Ordering { |
1226 | | self.data.cmp(other.data) |
1227 | | } |
1228 | | } |
1229 | | |
1230 | | impl Hash for Row<'_> { |
1231 | | #[inline] |
1232 | | fn hash<H: Hasher>(&self, state: &mut H) { |
1233 | | self.data.hash(state) |
1234 | | } |
1235 | | } |
1236 | | |
1237 | | impl AsRef<[u8]> for Row<'_> { |
1238 | | #[inline] |
1239 | 0 | fn as_ref(&self) -> &[u8] { |
1240 | 0 | self.data |
1241 | 0 | } |
1242 | | } |
1243 | | |
/// Owned version of a [`Row`] that can be moved/cloned freely.
///
/// This contains the data for the one specific row (not the entire buffer of all rows).
#[derive(Debug, Clone)]
pub struct OwnedRow {
    /// The encoded bytes of this single row
    data: Box<[u8]>,
    /// A clone of the configuration of the converter that produced the row
    config: RowConfig,
}
1252 | | |
1253 | | impl OwnedRow { |
1254 | | /// Get borrowed [`Row`] from owned version. |
1255 | | /// |
1256 | | /// This is helpful if you want to compare an [`OwnedRow`] with a [`Row`]. |
1257 | 0 | pub fn row(&self) -> Row<'_> { |
1258 | 0 | Row { |
1259 | 0 | data: &self.data, |
1260 | 0 | config: &self.config, |
1261 | 0 | } |
1262 | 0 | } |
1263 | | } |
1264 | | |
1265 | | // Manually derive these as don't wish to include `fields`. Also we just want to use the same `Row` implementations here. |
1266 | | |
1267 | | impl PartialEq for OwnedRow { |
1268 | | #[inline] |
1269 | | fn eq(&self, other: &Self) -> bool { |
1270 | | self.row().eq(&other.row()) |
1271 | | } |
1272 | | } |
1273 | | |
1274 | | impl Eq for OwnedRow {} |
1275 | | |
1276 | | impl PartialOrd for OwnedRow { |
1277 | | #[inline] |
1278 | | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
1279 | | Some(self.cmp(other)) |
1280 | | } |
1281 | | } |
1282 | | |
1283 | | impl Ord for OwnedRow { |
1284 | | #[inline] |
1285 | | fn cmp(&self, other: &Self) -> Ordering { |
1286 | | self.row().cmp(&other.row()) |
1287 | | } |
1288 | | } |
1289 | | |
1290 | | impl Hash for OwnedRow { |
1291 | | #[inline] |
1292 | | fn hash<H: Hasher>(&self, state: &mut H) { |
1293 | | self.row().hash(state) |
1294 | | } |
1295 | | } |
1296 | | |
1297 | | impl AsRef<[u8]> for OwnedRow { |
1298 | | #[inline] |
1299 | | fn as_ref(&self) -> &[u8] { |
1300 | | &self.data |
1301 | | } |
1302 | | } |
1303 | | |
1304 | | /// Returns the null sentinel, negated if `invert` is true |
1305 | | #[inline] |
1306 | 0 | fn null_sentinel(options: SortOptions) -> u8 { |
1307 | 0 | match options.nulls_first { |
1308 | 0 | true => 0, |
1309 | 0 | false => 0xFF, |
1310 | | } |
1311 | 0 | } |
1312 | | |
/// Stores the lengths of the rows. Lazily materializes lengths for columns with fixed-size types.
enum LengthTracker {
    /// Fixed state: All rows have length `length`
    Fixed { length: usize, num_rows: usize },
    /// Variable state: The length of row `i` is `lengths[i] + fixed_length`
    Variable {
        /// Encoded width contributed by fixed-size columns (shared by all rows)
        fixed_length: usize,
        /// Per-row length contributed by variable-size columns
        lengths: Vec<usize>,
    },
}
1323 | | |
impl LengthTracker {
    /// Creates a tracker for `num_rows` rows, all initially of length 0
    fn new(num_rows: usize) -> Self {
        Self::Fixed {
            length: 0,
            num_rows,
        }
    }

    /// Adds a column of fixed-length elements, each of size `new_length` to the LengthTracker
    fn push_fixed(&mut self, new_length: usize) {
        match self {
            LengthTracker::Fixed { length, .. } => *length += new_length,
            LengthTracker::Variable { fixed_length, .. } => *fixed_length += new_length,
        }
    }

    /// Adds a column of possibly variable-length elements, element `i` has length `new_lengths.nth(i)`
    fn push_variable(&mut self, new_lengths: impl ExactSizeIterator<Item = usize>) {
        match self {
            // First variable-length column: switch to the Variable representation
            LengthTracker::Fixed { length, .. } => {
                *self = LengthTracker::Variable {
                    fixed_length: *length,
                    lengths: new_lengths.collect(),
                }
            }
            LengthTracker::Variable { lengths, .. } => {
                assert_eq!(lengths.len(), new_lengths.len());
                lengths
                    .iter_mut()
                    .zip(new_lengths)
                    .for_each(|(length, new_length)| *length += new_length);
            }
        }
    }

    /// Returns the tracked row lengths as a slice
    fn materialized(&mut self) -> &mut [usize] {
        // Force the Variable representation so per-row lengths exist to mutate
        if let LengthTracker::Fixed { length, num_rows } = *self {
            *self = LengthTracker::Variable {
                fixed_length: length,
                lengths: vec![0; num_rows],
            };
        }

        match self {
            LengthTracker::Variable { lengths, .. } => lengths,
            // Unreachable: the branch above just converted Fixed to Variable
            LengthTracker::Fixed { .. } => unreachable!(),
        }
    }

    /// Initializes the offsets using the tracked lengths. Returns the sum of the
    /// lengths of the rows added.
    ///
    /// We initialize the offsets shifted down by one row index.
    ///
    /// As the rows are appended to the offsets will be incremented to match
    ///
    /// For example, consider the case of 3 rows of length 3, 4, and 6 respectively.
    /// The offsets would be initialized to `0, 0, 3, 7`
    ///
    /// Writing the first row entirely would yield `0, 3, 3, 7`
    /// The second, `0, 3, 7, 7`
    /// The third, `0, 3, 7, 13`
    ///
    /// This would be the final offsets for reading
    ///
    /// In this way offsets tracks the position during writing whilst eventually serving
    /// as the offsets used for reading
    fn extend_offsets(&self, initial_offset: usize, offsets: &mut Vec<usize>) -> usize {
        match self {
            LengthTracker::Fixed { length, num_rows } => {
                // All rows share one length: offsets form an arithmetic sequence
                offsets.extend((0..*num_rows).map(|i| initial_offset + i * length));

                initial_offset + num_rows * length
            }
            LengthTracker::Variable {
                fixed_length,
                lengths,
            } => {
                // Running total: each row contributes its variable length plus
                // the shared fixed length
                let mut acc = initial_offset;

                offsets.extend(lengths.iter().map(|length| {
                    let current = acc;
                    acc += length + fixed_length;
                    current
                }));

                acc
            }
        }
    }
}
1415 | | |
/// Computes the encoded length of every row across `cols`, returning the
/// result as a [`LengthTracker`]
fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
    use fixed::FixedLengthEncoding;

    // Row count comes from the first column; no columns means no rows
    let num_rows = cols.first().map(|x| x.len()).unwrap_or(0);
    let mut tracker = LengthTracker::new(num_rows);

    for (array, encoder) in cols.iter().zip(encoders) {
        match encoder {
            // Stateless encodings: length depends only on the value itself
            Encoder::Stateless => {
                downcast_primitive_array! {
                    array => tracker.push_fixed(fixed::encoded_len(array)),
                    DataType::Null => {},
                    DataType::Boolean => tracker.push_fixed(bool::ENCODED_LEN),
                    DataType::Binary => tracker.push_variable(
                        as_generic_binary_array::<i32>(array)
                            .iter()
                            .map(|slice| variable::encoded_len(slice))
                    ),
                    DataType::LargeBinary => tracker.push_variable(
                        as_generic_binary_array::<i64>(array)
                            .iter()
                            .map(|slice| variable::encoded_len(slice))
                    ),
                    DataType::BinaryView => tracker.push_variable(
                        array.as_binary_view()
                            .iter()
                            .map(|slice| variable::encoded_len(slice))
                    ),
                    DataType::Utf8 => tracker.push_variable(
                        array.as_string::<i32>()
                            .iter()
                            .map(|slice| variable::encoded_len(slice.map(|x| x.as_bytes())))
                    ),
                    DataType::LargeUtf8 => tracker.push_variable(
                        array.as_string::<i64>()
                            .iter()
                            .map(|slice| variable::encoded_len(slice.map(|x| x.as_bytes())))
                    ),
                    DataType::Utf8View => tracker.push_variable(
                        array.as_string_view()
                            .iter()
                            .map(|slice| variable::encoded_len(slice.map(|x| x.as_bytes())))
                    ),
                    DataType::FixedSizeBinary(len) => {
                        // 1 extra byte for the null/valid sentinel
                        let len = len.to_usize().unwrap();
                        tracker.push_fixed(1 + len)
                    }
                    _ => unimplemented!("unsupported data type: {}", array.data_type()),
                }
            }
            // Dictionary: each key's length is the length of its encoded value
            // (or of the encoded null)
            Encoder::Dictionary(values, null) => {
                downcast_dictionary_array! {
                    array => {
                        tracker.push_variable(
                            array.keys().iter().map(|v| match v {
                                Some(k) => values.row(k.as_usize()).data.len(),
                                None => null.data.len(),
                            })
                        )
                    }
                    _ => unreachable!(),
                }
            }
            // Struct: 1 sentinel byte plus the encoded child row (or encoded null)
            Encoder::Struct(rows, null) => {
                let array = as_struct_array(array);
                tracker.push_variable((0..array.len()).map(|idx| match array.is_valid(idx) {
                    true => 1 + rows.row(idx).as_ref().len(),
                    false => 1 + null.data.len(),
                }));
            }
            // Lists delegate to the list module, which mutates the
            // materialized per-row lengths directly
            Encoder::List(rows) => match array.data_type() {
                DataType::List(_) => {
                    list::compute_lengths(tracker.materialized(), rows, as_list_array(array))
                }
                DataType::LargeList(_) => {
                    list::compute_lengths(tracker.materialized(), rows, as_large_list_array(array))
                }
                DataType::FixedSizeList(_, _) => compute_lengths_fixed_size_list(
                    &mut tracker,
                    rows,
                    as_fixed_size_list_array(array),
                ),
                _ => unreachable!(),
            },
            // Run-end encoded arrays: dispatch on the run-end index type
            Encoder::RunEndEncoded(rows) => match array.data_type() {
                DataType::RunEndEncoded(r, _) => match r.data_type() {
                    DataType::Int16 => run::compute_lengths(
                        tracker.materialized(),
                        rows,
                        array.as_run::<Int16Type>(),
                    ),
                    DataType::Int32 => run::compute_lengths(
                        tracker.materialized(),
                        rows,
                        array.as_run::<Int32Type>(),
                    ),
                    DataType::Int64 => run::compute_lengths(
                        tracker.materialized(),
                        rows,
                        array.as_run::<Int64Type>(),
                    ),
                    _ => unreachable!("Unsupported run end index type: {r:?}"),
                },
                _ => unreachable!(),
            },
        }
    }

    tracker
}
1527 | | |
1528 | | /// Encodes a column to the provided [`Rows`] incrementing the offsets as it progresses |
1529 | 0 | fn encode_column( |
1530 | 0 | data: &mut [u8], |
1531 | 0 | offsets: &mut [usize], |
1532 | 0 | column: &dyn Array, |
1533 | 0 | opts: SortOptions, |
1534 | 0 | encoder: &Encoder<'_>, |
1535 | 0 | ) { |
1536 | 0 | match encoder { |
1537 | | Encoder::Stateless => { |
1538 | 0 | downcast_primitive_array! { |
1539 | | column => { |
1540 | 0 | if let Some(nulls) = column.nulls().filter(|n| n.null_count() > 0){ |
1541 | 0 | fixed::encode(data, offsets, column.values(), nulls, opts) |
1542 | | } else { |
1543 | 0 | fixed::encode_not_null(data, offsets, column.values(), opts) |
1544 | | } |
1545 | | } |
1546 | 0 | DataType::Null => {} |
1547 | | DataType::Boolean => { |
1548 | 0 | if let Some(nulls) = column.nulls().filter(|n| n.null_count() > 0){ |
1549 | 0 | fixed::encode_boolean(data, offsets, column.as_boolean().values(), nulls, opts) |
1550 | | } else { |
1551 | 0 | fixed::encode_boolean_not_null(data, offsets, column.as_boolean().values(), opts) |
1552 | | } |
1553 | | } |
1554 | | DataType::Binary => { |
1555 | 0 | variable::encode(data, offsets, as_generic_binary_array::<i32>(column).iter(), opts) |
1556 | | } |
1557 | | DataType::BinaryView => { |
1558 | 0 | variable::encode(data, offsets, column.as_binary_view().iter(), opts) |
1559 | | } |
1560 | | DataType::LargeBinary => { |
1561 | 0 | variable::encode(data, offsets, as_generic_binary_array::<i64>(column).iter(), opts) |
1562 | | } |
1563 | 0 | DataType::Utf8 => variable::encode( |
1564 | 0 | data, offsets, |
1565 | 0 | column.as_string::<i32>().iter().map(|x| x.map(|x| x.as_bytes())), |
1566 | 0 | opts, |
1567 | | ), |
1568 | 0 | DataType::LargeUtf8 => variable::encode( |
1569 | 0 | data, offsets, |
1570 | 0 | column.as_string::<i64>() |
1571 | 0 | .iter() |
1572 | 0 | .map(|x| x.map(|x| x.as_bytes())), |
1573 | 0 | opts, |
1574 | | ), |
1575 | 0 | DataType::Utf8View => variable::encode( |
1576 | 0 | data, offsets, |
1577 | 0 | column.as_string_view().iter().map(|x| x.map(|x| x.as_bytes())), |
1578 | 0 | opts, |
1579 | | ), |
1580 | | DataType::FixedSizeBinary(_) => { |
1581 | 0 | let array = column.as_any().downcast_ref().unwrap(); |
1582 | 0 | fixed::encode_fixed_size_binary(data, offsets, array, opts) |
1583 | | } |
1584 | 0 | _ => unimplemented!("unsupported data type: {}", column.data_type()), |
1585 | | } |
1586 | | } |
1587 | 0 | Encoder::Dictionary(values, nulls) => { |
1588 | 0 | downcast_dictionary_array! { |
1589 | 0 | column => encode_dictionary_values(data, offsets, column, values, nulls), |
1590 | 0 | _ => unreachable!() |
1591 | | } |
1592 | | } |
1593 | 0 | Encoder::Struct(rows, null) => { |
1594 | 0 | let array = as_struct_array(column); |
1595 | 0 | let null_sentinel = null_sentinel(opts); |
1596 | 0 | offsets |
1597 | 0 | .iter_mut() |
1598 | 0 | .skip(1) |
1599 | 0 | .enumerate() |
1600 | 0 | .for_each(|(idx, offset)| { |
1601 | 0 | let (row, sentinel) = match array.is_valid(idx) { |
1602 | 0 | true => (rows.row(idx), 0x01), |
1603 | 0 | false => (*null, null_sentinel), |
1604 | | }; |
1605 | 0 | let end_offset = *offset + 1 + row.as_ref().len(); |
1606 | 0 | data[*offset] = sentinel; |
1607 | 0 | data[*offset + 1..end_offset].copy_from_slice(row.as_ref()); |
1608 | 0 | *offset = end_offset; |
1609 | 0 | }) |
1610 | | } |
1611 | 0 | Encoder::List(rows) => match column.data_type() { |
1612 | 0 | DataType::List(_) => list::encode(data, offsets, rows, opts, as_list_array(column)), |
1613 | | DataType::LargeList(_) => { |
1614 | 0 | list::encode(data, offsets, rows, opts, as_large_list_array(column)) |
1615 | | } |
1616 | | DataType::FixedSizeList(_, _) => { |
1617 | 0 | encode_fixed_size_list(data, offsets, rows, opts, as_fixed_size_list_array(column)) |
1618 | | } |
1619 | 0 | _ => unreachable!(), |
1620 | | }, |
1621 | 0 | Encoder::RunEndEncoded(rows) => match column.data_type() { |
1622 | 0 | DataType::RunEndEncoded(r, _) => match r.data_type() { |
1623 | | DataType::Int16 => { |
1624 | 0 | run::encode(data, offsets, rows, opts, column.as_run::<Int16Type>()) |
1625 | | } |
1626 | | DataType::Int32 => { |
1627 | 0 | run::encode(data, offsets, rows, opts, column.as_run::<Int32Type>()) |
1628 | | } |
1629 | | DataType::Int64 => { |
1630 | 0 | run::encode(data, offsets, rows, opts, column.as_run::<Int64Type>()) |
1631 | | } |
1632 | 0 | _ => unreachable!("Unsupported run end index type: {r:?}"), |
1633 | | }, |
1634 | 0 | _ => unreachable!(), |
1635 | | }, |
1636 | | } |
1637 | 0 | } |
1638 | | |
1639 | | /// Encode dictionary values not preserving the dictionary encoding |
1640 | 0 | pub fn encode_dictionary_values<K: ArrowDictionaryKeyType>( |
1641 | 0 | data: &mut [u8], |
1642 | 0 | offsets: &mut [usize], |
1643 | 0 | column: &DictionaryArray<K>, |
1644 | 0 | values: &Rows, |
1645 | 0 | null: &Row<'_>, |
1646 | 0 | ) { |
1647 | 0 | for (offset, k) in offsets.iter_mut().skip(1).zip(column.keys()) { |
1648 | 0 | let row = match k { |
1649 | 0 | Some(k) => values.row(k.as_usize()).data, |
1650 | 0 | None => null.data, |
1651 | | }; |
1652 | 0 | let end_offset = *offset + row.len(); |
1653 | 0 | data[*offset..end_offset].copy_from_slice(row); |
1654 | 0 | *offset = end_offset; |
1655 | | } |
1656 | 0 | } |
1657 | | |
/// Helper macro expanded by `downcast_primitive!` in [`decode_column`]
///
/// Decodes a primitive column of Arrow type `$t` from `$rows` and wraps the
/// result in an `Arc`. The argument shape matches what `downcast_primitive!`
/// passes to its helper.
macro_rules! decode_primitive_helper {
    ($t:ty, $rows:ident, $data_type:ident, $options:ident) => {
        Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
    };
}
1663 | | |
/// Decodes the provided `field` from `rows`
///
/// Each slice in `rows` is advanced past the bytes consumed for this column
/// as it is decoded (e.g. the struct branch strips its leading sentinel byte),
/// leaving the slices positioned for the next column.
///
/// # Safety
///
/// Rows must contain valid data for the provided field
unsafe fn decode_column(
    field: &SortField,
    rows: &mut [&[u8]],
    codec: &Codec,
    validate_utf8: bool,
) -> Result<ArrayRef, ArrowError> {
    let options = field.options;

    let array: ArrayRef = match codec {
        Codec::Stateless => {
            let data_type = field.data_type.clone();
            // Primitive types dispatch through `decode_primitive_helper`;
            // the remaining stateless types are listed explicitly
            downcast_primitive! {
                data_type => (decode_primitive_helper, rows, data_type, options),
                DataType::Null => Arc::new(NullArray::new(rows.len())),
                DataType::Boolean => Arc::new(decode_bool(rows, options)),
                DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
                DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
                DataType::BinaryView => Arc::new(decode_binary_view(rows, options)),
                DataType::FixedSizeBinary(size) => Arc::new(decode_fixed_size_binary(rows, size, options)),
                DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
                DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
                DataType::Utf8View => Arc::new(decode_string_view(rows, options, validate_utf8)),
                _ => return Err(ArrowError::NotYetImplemented(format!("unsupported data type: {data_type}" )))
            }
        }
        Codec::Dictionary(converter, _) => {
            // Dictionaries were encoded as their flattened values; decode via
            // the nested converter and take its single output column
            let cols = converter.convert_raw(rows, validate_utf8)?;
            cols.into_iter().next().unwrap()
        }
        Codec::Struct(converter, _) => {
            // The first byte of each row is the struct-level null sentinel
            let (null_count, nulls) = fixed::decode_nulls(rows);
            // Advance every row past the sentinel before decoding children
            rows.iter_mut().for_each(|row| *row = &row[1..]);
            let children = converter.convert_raw(rows, validate_utf8)?;

            let child_data: Vec<ArrayData> = children.iter().map(|c| c.to_data()).collect();
            // Since RowConverter flattens certain data types (i.e. Dictionary),
            // we need to use updated data type instead of original field
            let corrected_fields: Vec<Field> = match &field.data_type {
                DataType::Struct(struct_fields) => struct_fields
                    .iter()
                    .zip(child_data.iter())
                    .map(|(orig_field, child_array)| {
                        orig_field
                            .as_ref()
                            .clone()
                            .with_data_type(child_array.data_type().clone())
                    })
                    .collect(),
                _ => unreachable!("Only Struct types should be corrected here"),
            };
            let corrected_struct_type = DataType::Struct(corrected_fields.into());
            let builder = ArrayDataBuilder::new(corrected_struct_type)
                .len(rows.len())
                .null_count(null_count)
                .null_bit_buffer(Some(nulls))
                .child_data(child_data);

            // SAFETY: caller guarantees `rows` encode valid data for this field
            Arc::new(StructArray::from(builder.build_unchecked()))
        }
        Codec::List(converter) => match &field.data_type {
            DataType::List(_) => {
                Arc::new(list::decode::<i32>(converter, rows, field, validate_utf8)?)
            }
            DataType::LargeList(_) => {
                Arc::new(list::decode::<i64>(converter, rows, field, validate_utf8)?)
            }
            DataType::FixedSizeList(_, value_length) => Arc::new(list::decode_fixed_size_list(
                converter,
                rows,
                field,
                validate_utf8,
                value_length.as_usize(),
            )?),
            _ => unreachable!(),
        },
        // Dispatch on the run-end index width declared by the field's type
        Codec::RunEndEncoded(converter) => match &field.data_type {
            DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
                DataType::Int16 => Arc::new(run::decode::<Int16Type>(
                    converter,
                    rows,
                    field,
                    validate_utf8,
                )?),
                DataType::Int32 => Arc::new(run::decode::<Int32Type>(
                    converter,
                    rows,
                    field,
                    validate_utf8,
                )?),
                DataType::Int64 => Arc::new(run::decode::<Int64Type>(
                    converter,
                    rows,
                    field,
                    validate_utf8,
                )?),
                _ => unreachable!(),
            },
            _ => unreachable!(),
        },
    };
    Ok(array)
}
1771 | | |
1772 | | #[cfg(test)] |
1773 | | mod tests { |
1774 | | use rand::distr::uniform::SampleUniform; |
1775 | | use rand::distr::{Distribution, StandardUniform}; |
1776 | | use rand::{rng, Rng}; |
1777 | | |
1778 | | use arrow_array::builder::*; |
1779 | | use arrow_array::types::*; |
1780 | | use arrow_array::*; |
1781 | | use arrow_buffer::{i256, NullBuffer}; |
1782 | | use arrow_buffer::{Buffer, OffsetBuffer}; |
1783 | | use arrow_cast::display::{ArrayFormatter, FormatOptions}; |
1784 | | use arrow_ord::sort::{LexicographicalComparator, SortColumn}; |
1785 | | |
1786 | | use super::*; |
1787 | | |
    #[test]
    fn test_fixed_width() {
        // Two fixed-width columns, each containing a null at index 2
        let cols = [
            Arc::new(Int16Array::from_iter([
                Some(1),
                Some(2),
                None,
                Some(-5),
                Some(2),
                Some(2),
                Some(0),
            ])) as ArrayRef,
            Arc::new(Float32Array::from_iter([
                Some(1.3),
                Some(2.5),
                None,
                Some(4.),
                Some(0.1),
                Some(-4.),
                Some(-0.),
            ])) as ArrayRef,
        ];

        let converter = RowConverter::new(vec![
            SortField::new(DataType::Int16),
            SortField::new(DataType::Float32),
        ])
        .unwrap();
        let rows = converter.convert_columns(&cols).unwrap();

        // Every row is 8 bytes: 1 sentinel + 2 value bytes for Int16,
        // 1 sentinel + 4 value bytes for Float32
        assert_eq!(rows.offsets, &[0, 8, 16, 24, 32, 40, 48, 56]);
        // Exact encoded bytes; the all-zero row corresponds to the nulls
        // at index 2 (sentinel byte 0 = null)
        assert_eq!(
            rows.buffer,
            &[
                1, 128, 1, //
                1, 191, 166, 102, 102, //
                1, 128, 2, //
                1, 192, 32, 0, 0, //
                0, 0, 0, //
                0, 0, 0, 0, 0, //
                1, 127, 251, //
                1, 192, 128, 0, 0, //
                1, 128, 2, //
                1, 189, 204, 204, 205, //
                1, 128, 2, //
                1, 63, 127, 255, 255, //
                1, 128, 0, //
                1, 127, 255, 255, 255 //
            ]
        );

        // Byte comparison of rows matches lexicographic column ordering
        assert!(rows.row(3) < rows.row(6));
        assert!(rows.row(0) < rows.row(1));
        assert!(rows.row(3) < rows.row(0));
        assert!(rows.row(4) < rows.row(1));
        assert!(rows.row(5) < rows.row(4));

        // Round-trip back to arrays preserves the original columns
        let back = converter.convert_rows(&rows).unwrap();
        for (expected, actual) in cols.iter().zip(&back) {
            assert_eq!(expected, actual);
        }
    }
1850 | | |
1851 | | #[test] |
1852 | | fn test_decimal32() { |
1853 | | let converter = RowConverter::new(vec![SortField::new(DataType::Decimal32( |
1854 | | DECIMAL32_MAX_PRECISION, |
1855 | | 7, |
1856 | | ))]) |
1857 | | .unwrap(); |
1858 | | let col = Arc::new( |
1859 | | Decimal32Array::from_iter([ |
1860 | | None, |
1861 | | Some(i32::MIN), |
1862 | | Some(-13), |
1863 | | Some(46_i32), |
1864 | | Some(5456_i32), |
1865 | | Some(i32::MAX), |
1866 | | ]) |
1867 | | .with_precision_and_scale(9, 7) |
1868 | | .unwrap(), |
1869 | | ) as ArrayRef; |
1870 | | |
1871 | | let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); |
1872 | | for i in 0..rows.num_rows() - 1 { |
1873 | | assert!(rows.row(i) < rows.row(i + 1)); |
1874 | | } |
1875 | | |
1876 | | let back = converter.convert_rows(&rows).unwrap(); |
1877 | | assert_eq!(back.len(), 1); |
1878 | | assert_eq!(col.as_ref(), back[0].as_ref()) |
1879 | | } |
1880 | | |
1881 | | #[test] |
1882 | | fn test_decimal64() { |
1883 | | let converter = RowConverter::new(vec![SortField::new(DataType::Decimal64( |
1884 | | DECIMAL64_MAX_PRECISION, |
1885 | | 7, |
1886 | | ))]) |
1887 | | .unwrap(); |
1888 | | let col = Arc::new( |
1889 | | Decimal64Array::from_iter([ |
1890 | | None, |
1891 | | Some(i64::MIN), |
1892 | | Some(-13), |
1893 | | Some(46_i64), |
1894 | | Some(5456_i64), |
1895 | | Some(i64::MAX), |
1896 | | ]) |
1897 | | .with_precision_and_scale(18, 7) |
1898 | | .unwrap(), |
1899 | | ) as ArrayRef; |
1900 | | |
1901 | | let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); |
1902 | | for i in 0..rows.num_rows() - 1 { |
1903 | | assert!(rows.row(i) < rows.row(i + 1)); |
1904 | | } |
1905 | | |
1906 | | let back = converter.convert_rows(&rows).unwrap(); |
1907 | | assert_eq!(back.len(), 1); |
1908 | | assert_eq!(col.as_ref(), back[0].as_ref()) |
1909 | | } |
1910 | | |
1911 | | #[test] |
1912 | | fn test_decimal128() { |
1913 | | let converter = RowConverter::new(vec![SortField::new(DataType::Decimal128( |
1914 | | DECIMAL128_MAX_PRECISION, |
1915 | | 7, |
1916 | | ))]) |
1917 | | .unwrap(); |
1918 | | let col = Arc::new( |
1919 | | Decimal128Array::from_iter([ |
1920 | | None, |
1921 | | Some(i128::MIN), |
1922 | | Some(-13), |
1923 | | Some(46_i128), |
1924 | | Some(5456_i128), |
1925 | | Some(i128::MAX), |
1926 | | ]) |
1927 | | .with_precision_and_scale(38, 7) |
1928 | | .unwrap(), |
1929 | | ) as ArrayRef; |
1930 | | |
1931 | | let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); |
1932 | | for i in 0..rows.num_rows() - 1 { |
1933 | | assert!(rows.row(i) < rows.row(i + 1)); |
1934 | | } |
1935 | | |
1936 | | let back = converter.convert_rows(&rows).unwrap(); |
1937 | | assert_eq!(back.len(), 1); |
1938 | | assert_eq!(col.as_ref(), back[0].as_ref()) |
1939 | | } |
1940 | | |
    #[test]
    fn test_decimal256() {
        let converter = RowConverter::new(vec![SortField::new(DataType::Decimal256(
            DECIMAL256_MAX_PRECISION,
            7,
        ))])
        .unwrap();
        // Values chosen to straddle the high/low i128 halves of i256,
        // listed in ascending order with a leading null
        let col = Arc::new(
            Decimal256Array::from_iter([
                None,
                Some(i256::MIN),
                Some(i256::from_parts(0, -1)),
                Some(i256::from_parts(u128::MAX, -1)),
                Some(i256::from_parts(u128::MAX, 0)),
                Some(i256::from_parts(0, 46_i128)),
                Some(i256::from_parts(5, 46_i128)),
                Some(i256::MAX),
            ])
            .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 7)
            .unwrap(),
        ) as ArrayRef;

        // Row encoding must preserve the ascending order
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
        for i in 0..rows.num_rows() - 1 {
            assert!(rows.row(i) < rows.row(i + 1));
        }

        // Round-trip back to the original column
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(col.as_ref(), back[0].as_ref())
    }
1972 | | |
1973 | | #[test] |
1974 | | fn test_bool() { |
1975 | | let converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap(); |
1976 | | |
1977 | | let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)])) as ArrayRef; |
1978 | | |
1979 | | let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); |
1980 | | assert!(rows.row(2) > rows.row(1)); |
1981 | | assert!(rows.row(2) > rows.row(0)); |
1982 | | assert!(rows.row(1) > rows.row(0)); |
1983 | | |
1984 | | let cols = converter.convert_rows(&rows).unwrap(); |
1985 | | assert_eq!(&cols[0], &col); |
1986 | | |
1987 | | let converter = RowConverter::new(vec![SortField::new_with_options( |
1988 | | DataType::Boolean, |
1989 | | SortOptions::default().desc().with_nulls_first(false), |
1990 | | )]) |
1991 | | .unwrap(); |
1992 | | |
1993 | | let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); |
1994 | | assert!(rows.row(2) < rows.row(1)); |
1995 | | assert!(rows.row(2) < rows.row(0)); |
1996 | | assert!(rows.row(1) < rows.row(0)); |
1997 | | let cols = converter.convert_rows(&rows).unwrap(); |
1998 | | assert_eq!(&cols[0], &col); |
1999 | | } |
2000 | | |
    #[test]
    fn test_timezone() {
        // Timestamp-with-timezone round-trips its full data type
        let a =
            TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]).with_timezone("+01:00".to_string());
        let d = a.data_type().clone();

        let converter = RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
        let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(back[0].data_type(), &d);

        // Test dictionary
        let mut a = PrimitiveDictionaryBuilder::<Int32Type, TimestampNanosecondType>::new();
        a.append(34).unwrap();
        a.append_null();
        a.append(345).unwrap();

        // Construct dictionary with a timezone
        let dict = a.finish();
        let values = TimestampNanosecondArray::from(dict.values().to_data());
        let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00")));
        let v = DataType::Timestamp(TimeUnit::Nanosecond, Some("+02:00".into()));
        let d = DataType::Dictionary(Box::new(DataType::Int32), Box::new(v.clone()));

        assert_eq!(dict_with_tz.data_type(), &d);
        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
        let rows = converter
            .convert_columns(&[Arc::new(dict_with_tz) as _])
            .unwrap();
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        // The dictionary is flattened on round-trip, so the decoded type is
        // the timezone-carrying value type, not the dictionary type
        assert_eq!(back[0].data_type(), &v);
    }
2035 | | |
2036 | | #[test] |
2037 | | fn test_null_encoding() { |
2038 | | let col = Arc::new(NullArray::new(10)); |
2039 | | let converter = RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap(); |
2040 | | let rows = converter.convert_columns(&[col]).unwrap(); |
2041 | | assert_eq!(rows.num_rows(), 10); |
2042 | | assert_eq!(rows.row(1).data.len(), 0); |
2043 | | } |
2044 | | |
    #[test]
    fn test_variable_width() {
        // Strings: nulls sort before the empty string, prefixes before
        // longer strings
        let col = Arc::new(StringArray::from_iter([
            Some("hello"),
            Some("he"),
            None,
            Some("foo"),
            Some(""),
        ])) as ArrayRef;

        let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        assert!(rows.row(1) < rows.row(0)); // "he" < "hello"
        assert!(rows.row(2) < rows.row(4)); // null < ""
        assert!(rows.row(3) < rows.row(0)); // "foo" < "hello"
        assert!(rows.row(3) < rows.row(1)); // "foo" < "he"

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);

        // Binary values sized around the encoder's block boundaries to
        // exercise the block-based variable-length encoding; listed in
        // ascending order
        let col = Arc::new(BinaryArray::from_iter([
            None,
            Some(vec![0_u8; 0]),
            Some(vec![0_u8; 6]),
            Some(vec![0_u8; variable::MINI_BLOCK_SIZE]),
            Some(vec![0_u8; variable::MINI_BLOCK_SIZE + 1]),
            Some(vec![0_u8; variable::BLOCK_SIZE]),
            Some(vec![0_u8; variable::BLOCK_SIZE + 1]),
            Some(vec![1_u8; 6]),
            Some(vec![1_u8; variable::MINI_BLOCK_SIZE]),
            Some(vec![1_u8; variable::MINI_BLOCK_SIZE + 1]),
            Some(vec![1_u8; variable::BLOCK_SIZE]),
            Some(vec![1_u8; variable::BLOCK_SIZE + 1]),
            Some(vec![0xFF_u8; 6]),
            Some(vec![0xFF_u8; variable::MINI_BLOCK_SIZE]),
            Some(vec![0xFF_u8; variable::MINI_BLOCK_SIZE + 1]),
            Some(vec![0xFF_u8; variable::BLOCK_SIZE]),
            Some(vec![0xFF_u8; variable::BLOCK_SIZE + 1]),
        ])) as ArrayRef;

        let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        // Every pair of rows must compare in strictly ascending order
        for i in 0..rows.num_rows() {
            for j in i + 1..rows.num_rows() {
                assert!(
                    rows.row(i) < rows.row(j),
                    "{} < {} - {:?} < {:?}",
                    i,
                    j,
                    rows.row(i),
                    rows.row(j)
                );
            }
        }

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);

        // Descending with nulls last reverses the pairwise ordering
        let converter = RowConverter::new(vec![SortField::new_with_options(
            DataType::Binary,
            SortOptions::default().desc().with_nulls_first(false),
        )])
        .unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        for i in 0..rows.num_rows() {
            for j in i + 1..rows.num_rows() {
                assert!(
                    rows.row(i) > rows.row(j),
                    "{} > {} - {:?} > {:?}",
                    i,
                    j,
                    rows.row(i),
                    rows.row(j)
                );
            }
        }

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);
    }
2128 | | |
2129 | | /// If `exact` is false performs a logical comparison between a and dictionary-encoded b |
2130 | | fn dictionary_eq(a: &dyn Array, b: &dyn Array) { |
2131 | | match b.data_type() { |
2132 | | DataType::Dictionary(_, v) => { |
2133 | | assert_eq!(a.data_type(), v.as_ref()); |
2134 | | let b = arrow_cast::cast(b, v).unwrap(); |
2135 | | assert_eq!(a, b.as_ref()) |
2136 | | } |
2137 | | _ => assert_eq!(a, b), |
2138 | | } |
2139 | | } |
2140 | | |
    #[test]
    fn test_string_dictionary() {
        let a = Arc::new(DictionaryArray::<Int32Type>::from_iter([
            Some("foo"),
            Some("hello"),
            Some("he"),
            None,
            Some("hello"),
            Some(""),
            Some("hello"),
            Some("hello"),
        ])) as ArrayRef;

        let field = SortField::new(a.data_type().clone());
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows_a = converter.convert_columns(&[Arc::clone(&a)]).unwrap();

        // Ordering follows the logical (dictionary-decoded) values
        assert!(rows_a.row(3) < rows_a.row(5)); // null < ""
        assert!(rows_a.row(2) < rows_a.row(1)); // "he" < "hello"
        assert!(rows_a.row(0) < rows_a.row(1)); // "foo" < "hello"
        assert!(rows_a.row(3) < rows_a.row(0)); // null < "foo"

        // Equal logical values produce identical row bytes
        assert_eq!(rows_a.row(1), rows_a.row(4));
        assert_eq!(rows_a.row(1), rows_a.row(6));
        assert_eq!(rows_a.row(1), rows_a.row(7));

        let cols = converter.convert_rows(&rows_a).unwrap();
        dictionary_eq(&cols[0], &a);

        // Rows from a different dictionary array compare by logical value
        let b = Arc::new(DictionaryArray::<Int32Type>::from_iter([
            Some("hello"),
            None,
            Some("cupcakes"),
        ])) as ArrayRef;

        let rows_b = converter.convert_columns(&[Arc::clone(&b)]).unwrap();
        assert_eq!(rows_a.row(1), rows_b.row(0)); // "hello" == "hello"
        assert_eq!(rows_a.row(3), rows_b.row(1)); // null == null
        assert!(rows_b.row(2) < rows_a.row(0)); // "cupcakes" < "foo"

        let cols = converter.convert_rows(&rows_b).unwrap();
        dictionary_eq(&cols[0], &b);

        // Descending with nulls last inverts all four comparisons
        let converter = RowConverter::new(vec![SortField::new_with_options(
            a.data_type().clone(),
            SortOptions::default().desc().with_nulls_first(false),
        )])
        .unwrap();

        let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
        assert!(rows_c.row(3) > rows_c.row(5));
        assert!(rows_c.row(2) > rows_c.row(1));
        assert!(rows_c.row(0) > rows_c.row(1));
        assert!(rows_c.row(3) > rows_c.row(0));

        let cols = converter.convert_rows(&rows_c).unwrap();
        dictionary_eq(&cols[0], &a);

        // Descending with nulls first: values invert, but nulls still sort
        // before all values
        let converter = RowConverter::new(vec![SortField::new_with_options(
            a.data_type().clone(),
            SortOptions::default().desc().with_nulls_first(true),
        )])
        .unwrap();

        let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
        assert!(rows_c.row(3) < rows_c.row(5));
        assert!(rows_c.row(2) > rows_c.row(1));
        assert!(rows_c.row(0) > rows_c.row(1));
        assert!(rows_c.row(3) < rows_c.row(0));

        let cols = converter.convert_rows(&rows_c).unwrap();
        dictionary_eq(&cols[0], &a);
    }
2214 | | |
    #[test]
    fn test_struct() {
        // Test basic
        let a = Arc::new(Int32Array::from(vec![1, 1, 2, 2])) as ArrayRef;
        let a_f = Arc::new(Field::new("int", DataType::Int32, false));
        let u = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])) as ArrayRef;
        let u_f = Arc::new(Field::new("s", DataType::Utf8, false));
        let s1 = Arc::new(StructArray::from(vec![(a_f, a), (u_f, u)])) as ArrayRef;

        let sort_fields = vec![SortField::new(s1.data_type().clone())];
        let converter = RowConverter::new(sort_fields).unwrap();
        let r1 = converter.convert_columns(&[Arc::clone(&s1)]).unwrap();

        // Input rows are strictly ascending, so encoded rows must be too
        for (a, b) in r1.iter().zip(r1.iter().skip(1)) {
            assert!(a < b);
        }

        let back = converter.convert_rows(&r1).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(&back[0], &s1);

        // Test struct nullability
        // Validity bitmap 0b00001010 marks indices 1 and 3 valid,
        // indices 0 and 2 null
        let data = s1
            .to_data()
            .into_builder()
            .null_bit_buffer(Some(Buffer::from_slice_ref([0b00001010])))
            .null_count(2)
            .build()
            .unwrap();

        let s2 = Arc::new(StructArray::from(data)) as ArrayRef;
        let r2 = converter.convert_columns(&[Arc::clone(&s2)]).unwrap();
        assert_eq!(r2.row(0), r2.row(2)); // Nulls equal
        assert!(r2.row(0) < r2.row(1)); // Nulls first
        assert_ne!(r1.row(0), r2.row(0)); // Value does not equal null
        assert_eq!(r1.row(1), r2.row(1)); // Values equal

        let back = converter.convert_rows(&r2).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(&back[0], &s2);

        // The decoded array must also be internally valid
        back[0].to_data().validate_full().unwrap();
    }
2258 | | |
    #[test]
    fn test_dictionary_in_struct() {
        // Build Struct(foo: Dictionary(Int32, Utf8)) with a nullable
        // dictionary field
        let builder = StringDictionaryBuilder::<Int32Type>::new();
        let mut struct_builder = StructBuilder::new(
            vec![Field::new_dictionary(
                "foo",
                DataType::Int32,
                DataType::Utf8,
                true,
            )],
            vec![Box::new(builder)],
        );

        let dict_builder = struct_builder
            .field_builder::<StringDictionaryBuilder<Int32Type>>(0)
            .unwrap();

        // Flattened: ["a", null, "a", "b"]
        dict_builder.append_value("a");
        dict_builder.append_null();
        dict_builder.append_value("a");
        dict_builder.append_value("b");

        // All four struct-level entries are valid
        for _ in 0..4 {
            struct_builder.append(true);
        }

        let s = Arc::new(struct_builder.finish()) as ArrayRef;
        let sort_fields = vec![SortField::new(s.data_type().clone())];
        let converter = RowConverter::new(sort_fields).unwrap();
        let r = converter.convert_columns(&[Arc::clone(&s)]).unwrap();

        let back = converter.convert_rows(&r).unwrap();
        let [s2] = back.try_into().unwrap();

        // RowConverter flattens Dictionary
        // s.ty = Struct(foo Dictionary(Int32, Utf8)), s2.ty = Struct(foo Utf8)
        assert_ne!(&s.data_type(), &s2.data_type());
        s2.to_data().validate_full().unwrap();

        // Check if the logical data remains the same
        // Keys: [0, null, 0, 1]
        // Values: ["a", "b"]
        let s1_struct = s.as_struct();
        let s1_0 = s1_struct.column(0);
        let s1_idx_0 = s1_0.as_dictionary::<Int32Type>();
        let keys = s1_idx_0.keys();
        let values = s1_idx_0.values().as_string::<i32>();
        // Flattened: ["a", null, "a", "b"]
        let s2_struct = s2.as_struct();
        let s2_0 = s2_struct.column(0);
        let s2_idx_0 = s2_0.as_string::<i32>();

        // Every dictionary entry must match the corresponding flattened value
        for i in 0..keys.len() {
            if keys.is_null(i) {
                assert!(s2_idx_0.is_null(i));
            } else {
                let dict_index = keys.value(i) as usize;
                assert_eq!(values.value(dict_index), s2_idx_0.value(i));
            }
        }
    }
2321 | | |
2322 | | #[test] |
2323 | | fn test_dictionary_in_struct_empty() { |
2324 | | let ty = DataType::Struct( |
2325 | | vec![Field::new_dictionary( |
2326 | | "foo", |
2327 | | DataType::Int32, |
2328 | | DataType::Int32, |
2329 | | false, |
2330 | | )] |
2331 | | .into(), |
2332 | | ); |
2333 | | let s = arrow_array::new_empty_array(&ty); |
2334 | | |
2335 | | let sort_fields = vec![SortField::new(s.data_type().clone())]; |
2336 | | let converter = RowConverter::new(sort_fields).unwrap(); |
2337 | | let r = converter.convert_columns(&[Arc::clone(&s)]).unwrap(); |
2338 | | |
2339 | | let back = converter.convert_rows(&r).unwrap(); |
2340 | | let [s2] = back.try_into().unwrap(); |
2341 | | |
2342 | | // RowConverter flattens Dictionary |
2343 | | // s.ty = Struct(foo Dictionary(Int32, Int32)), s2.ty = Struct(foo Int32) |
2344 | | assert_ne!(&s.data_type(), &s2.data_type()); |
2345 | | s2.to_data().validate_full().unwrap(); |
2346 | | assert_eq!(s.len(), 0); |
2347 | | assert_eq!(s2.len(), 0); |
2348 | | } |
2349 | | |
    #[test]
    fn test_list_of_string_dictionary() {
        // Build List(Dictionary(Int32, Utf8)) with three lists
        let mut builder = ListBuilder::<StringDictionaryBuilder<Int32Type>>::default();
        // List[0] = ["a", "b", "zero", null, "c", "b", "d" (dict)]
        builder.values().append("a").unwrap();
        builder.values().append("b").unwrap();
        builder.values().append("zero").unwrap();
        builder.values().append_null();
        builder.values().append("c").unwrap();
        builder.values().append("b").unwrap();
        builder.values().append("d").unwrap();
        builder.append(true);
        // List[1] = null
        builder.append(false);
        // List[2] = ["e", "zero", "a" (dict)]
        builder.values().append("e").unwrap();
        builder.values().append("zero").unwrap();
        builder.values().append("a").unwrap();
        builder.append(true);

        let a = Arc::new(builder.finish()) as ArrayRef;
        let data_type = a.data_type().clone();

        let field = SortField::new(data_type.clone());
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&a)]).unwrap();

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        let [a2] = back.try_into().unwrap();

        // RowConverter flattens Dictionary
        // a.ty: List(Dictionary(Int32, Utf8)), a2.ty: List(Utf8)
        assert_ne!(&a.data_type(), &a2.data_type());

        a2.to_data().validate_full().unwrap();

        let a2_list = a2.as_list::<i32>();
        let a1_list = a.as_list::<i32>();

        // Check if the logical data remains the same
        // List[0] = ["a", "b", "zero", null, "c", "b", "d" (dict)]
        let a1_0 = a1_list.value(0);
        let a1_idx_0 = a1_0.as_dictionary::<Int32Type>();
        let keys = a1_idx_0.keys();
        let values = a1_idx_0.values().as_string::<i32>();
        let a2_0 = a2_list.value(0);
        let a2_idx_0 = a2_0.as_string::<i32>();

        // Every dictionary entry of List[0] matches the flattened value
        for i in 0..keys.len() {
            if keys.is_null(i) {
                assert!(a2_idx_0.is_null(i));
            } else {
                let dict_index = keys.value(i) as usize;
                assert_eq!(values.value(dict_index), a2_idx_0.value(i));
            }
        }

        // List[1] = null
        assert!(a1_list.is_null(1));
        assert!(a2_list.is_null(1));

        // List[2] = ["e", "zero", "a" (dict)]
        let a1_2 = a1_list.value(2);
        let a1_idx_2 = a1_2.as_dictionary::<Int32Type>();
        let keys = a1_idx_2.keys();
        let values = a1_idx_2.values().as_string::<i32>();
        let a2_2 = a2_list.value(2);
        let a2_idx_2 = a2_2.as_string::<i32>();

        // Every dictionary entry of List[2] matches the flattened value
        for i in 0..keys.len() {
            if keys.is_null(i) {
                assert!(a2_idx_2.is_null(i));
            } else {
                let dict_index = keys.value(i) as usize;
                assert_eq!(values.value(dict_index), a2_idx_2.value(i));
            }
        }
    }
2429 | | |
2430 | | #[test] |
2431 | | fn test_primitive_dictionary() { |
2432 | | let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new(); |
2433 | | builder.append(2).unwrap(); |
2434 | | builder.append(3).unwrap(); |
2435 | | builder.append(0).unwrap(); |
2436 | | builder.append_null(); |
2437 | | builder.append(5).unwrap(); |
2438 | | builder.append(3).unwrap(); |
2439 | | builder.append(-1).unwrap(); |
2440 | | |
2441 | | let a = builder.finish(); |
2442 | | let data_type = a.data_type().clone(); |
2443 | | let columns = [Arc::new(a) as ArrayRef]; |
2444 | | |
2445 | | let field = SortField::new(data_type.clone()); |
2446 | | let converter = RowConverter::new(vec![field]).unwrap(); |
2447 | | let rows = converter.convert_columns(&columns).unwrap(); |
2448 | | assert!(rows.row(0) < rows.row(1)); |
2449 | | assert!(rows.row(2) < rows.row(0)); |
2450 | | assert!(rows.row(3) < rows.row(2)); |
2451 | | assert!(rows.row(6) < rows.row(2)); |
2452 | | assert!(rows.row(3) < rows.row(6)); |
2453 | | |
2454 | | let back = converter.convert_rows(&rows).unwrap(); |
2455 | | assert_eq!(back.len(), 1); |
2456 | | back[0].to_data().validate_full().unwrap(); |
2457 | | } |
2458 | | |
2459 | | #[test] |
2460 | | fn test_dictionary_nulls() { |
2461 | | let values = Int32Array::from_iter([Some(1), Some(-1), None, Some(4), None]).into_data(); |
2462 | | let keys = |
2463 | | Int32Array::from_iter([Some(0), Some(0), Some(1), Some(2), Some(4), None]).into_data(); |
2464 | | |
2465 | | let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32)); |
2466 | | let data = keys |
2467 | | .into_builder() |
2468 | | .data_type(data_type.clone()) |
2469 | | .child_data(vec![values]) |
2470 | | .build() |
2471 | | .unwrap(); |
2472 | | |
2473 | | let columns = [Arc::new(DictionaryArray::<Int32Type>::from(data)) as ArrayRef]; |
2474 | | let field = SortField::new(data_type.clone()); |
2475 | | let converter = RowConverter::new(vec![field]).unwrap(); |
2476 | | let rows = converter.convert_columns(&columns).unwrap(); |
2477 | | |
2478 | | assert_eq!(rows.row(0), rows.row(1)); |
2479 | | assert_eq!(rows.row(3), rows.row(4)); |
2480 | | assert_eq!(rows.row(4), rows.row(5)); |
2481 | | assert!(rows.row(3) < rows.row(0)); |
2482 | | } |
2483 | | |
2484 | | #[test] |
2485 | | #[should_panic(expected = "Encountered non UTF-8 data")] |
2486 | | fn test_invalid_utf8() { |
2487 | | let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); |
2488 | | let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _; |
2489 | | let rows = converter.convert_columns(&[array]).unwrap(); |
2490 | | let binary_row = rows.row(0); |
2491 | | |
2492 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2493 | | let parser = converter.parser(); |
2494 | | let utf8_row = parser.parse(binary_row.as_ref()); |
2495 | | |
2496 | | converter.convert_rows(std::iter::once(utf8_row)).unwrap(); |
2497 | | } |
2498 | | |
2499 | | #[test] |
2500 | | #[should_panic(expected = "Encountered non UTF-8 data")] |
2501 | | fn test_invalid_utf8_array() { |
2502 | | let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); |
2503 | | let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _; |
2504 | | let rows = converter.convert_columns(&[array]).unwrap(); |
2505 | | let binary_rows = rows.try_into_binary().expect("known-small rows"); |
2506 | | |
2507 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2508 | | let parsed = converter.from_binary(binary_rows); |
2509 | | |
2510 | | converter.convert_rows(parsed.iter()).unwrap(); |
2511 | | } |
2512 | | |
2513 | | #[test] |
2514 | | #[should_panic(expected = "index out of bounds")] |
2515 | | fn test_invalid_empty() { |
2516 | | let binary_row: &[u8] = &[]; |
2517 | | |
2518 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2519 | | let parser = converter.parser(); |
2520 | | let utf8_row = parser.parse(binary_row.as_ref()); |
2521 | | |
2522 | | converter.convert_rows(std::iter::once(utf8_row)).unwrap(); |
2523 | | } |
2524 | | |
2525 | | #[test] |
2526 | | #[should_panic(expected = "index out of bounds")] |
2527 | | fn test_invalid_empty_array() { |
2528 | | let row: &[u8] = &[]; |
2529 | | let binary_rows = BinaryArray::from(vec![row]); |
2530 | | |
2531 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2532 | | let parsed = converter.from_binary(binary_rows); |
2533 | | |
2534 | | converter.convert_rows(parsed.iter()).unwrap(); |
2535 | | } |
2536 | | |
2537 | | #[test] |
2538 | | #[should_panic(expected = "index out of bounds")] |
2539 | | fn test_invalid_truncated() { |
2540 | | let binary_row: &[u8] = &[0x02]; |
2541 | | |
2542 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2543 | | let parser = converter.parser(); |
2544 | | let utf8_row = parser.parse(binary_row.as_ref()); |
2545 | | |
2546 | | converter.convert_rows(std::iter::once(utf8_row)).unwrap(); |
2547 | | } |
2548 | | |
2549 | | #[test] |
2550 | | #[should_panic(expected = "index out of bounds")] |
2551 | | fn test_invalid_truncated_array() { |
2552 | | let row: &[u8] = &[0x02]; |
2553 | | let binary_rows = BinaryArray::from(vec![row]); |
2554 | | |
2555 | | let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); |
2556 | | let parsed = converter.from_binary(binary_rows); |
2557 | | |
2558 | | converter.convert_rows(parsed.iter()).unwrap(); |
2559 | | } |
2560 | | |
2561 | | #[test] |
2562 | | #[should_panic(expected = "rows were not produced by this RowConverter")] |
2563 | | fn test_different_converter() { |
2564 | | let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)])); |
2565 | | let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); |
2566 | | let rows = converter.convert_columns(&[values]).unwrap(); |
2567 | | |
2568 | | let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); |
2569 | | let _ = converter.convert_rows(&rows); |
2570 | | } |
2571 | | |
    /// Round-trips a single-level `GenericList<O>` of Int32 through the row
    /// format under all four (direction, null ordering) combinations, checking
    /// the relative order of the encoded rows each time.
    ///
    /// Logical contents (list index: value):
    ///   0: [32, 52, 32]   1: [32, 52, 12]   2: [32, 52]   3: null
    ///   4: [32, null]     5: []             6: null
    fn test_single_list<O: OffsetSizeTrait>() {
        let mut builder = GenericListBuilder::<O, _>::new(Int32Builder::new());
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.values().append_value(32);
        builder.append(true);
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.values().append_value(12);
        builder.append(true);
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.append(true);
        // Values appended under a `false` list mask exist in the child array
        // but must not influence the encoding
        builder.values().append_value(32); // MASKED
        builder.values().append_value(52); // MASKED
        builder.append(false);
        builder.values().append_value(32);
        builder.values().append_null();
        builder.append(true);
        builder.append(true);
        builder.values().append_value(17); // MASKED
        builder.values().append_null(); // MASKED
        builder.append(false);

        let list = Arc::new(builder.finish()) as ArrayRef;
        let d = list.data_type().clone();

        // Ascending, nulls first (the defaults)
        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();

        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
        assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
        assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12]
        assert!(rows.row(3) < rows.row(2)); // null < [32, 52]
        assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52]
        assert!(rows.row(5) < rows.row(2)); // [] < [32, 52]
        assert!(rows.row(3) < rows.row(5)); // null < []
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Ascending, nulls last
        let options = SortOptions::default().asc().with_nulls_first(false);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
        assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12]
        assert!(rows.row(3) > rows.row(2)); // null > [32, 52]
        assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52]
        assert!(rows.row(5) < rows.row(2)); // [] < [32, 52]
        assert!(rows.row(3) > rows.row(5)); // null > []
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls last
        let options = SortOptions::default().desc().with_nulls_first(false);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
        assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12]
        assert!(rows.row(3) > rows.row(2)); // null > [32, 52]
        assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52]
        assert!(rows.row(5) > rows.row(2)); // [] > [32, 52]
        assert!(rows.row(3) > rows.row(5)); // null > []
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls first
        let options = SortOptions::default().desc().with_nulls_first(true);
        let field = SortField::new_with_options(d, options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
        assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12]
        assert!(rows.row(3) < rows.row(2)); // null < [32, 52]
        assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52]
        assert!(rows.row(5) > rows.row(2)); // [] > [32, 52]
        assert!(rows.row(3) < rows.row(5)); // null < []
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // A sliced list must encode (and round-trip) from its logical offsets,
        // not the full physical child data
        let sliced_list = list.slice(1, 5);
        let rows_on_sliced_list = converter
            .convert_columns(&[Arc::clone(&sliced_list)])
            .unwrap();

        assert!(rows_on_sliced_list.row(1) > rows_on_sliced_list.row(0)); // [32, 52] > [32, 52, 12]
        assert!(rows_on_sliced_list.row(2) < rows_on_sliced_list.row(1)); // null < [32, 52]
        assert!(rows_on_sliced_list.row(3) < rows_on_sliced_list.row(1)); // [32, null] < [32, 52]
        assert!(rows_on_sliced_list.row(4) > rows_on_sliced_list.row(1)); // [] > [32, 52]
        assert!(rows_on_sliced_list.row(2) < rows_on_sliced_list.row(4)); // null < []

        let back = converter.convert_rows(&rows_on_sliced_list).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &sliced_list);
    }
2685 | | |
    /// Round-trips a two-level nested `GenericList<O>` (list of list of Int32)
    /// through the row format under three (direction, null ordering)
    /// combinations, checking the relative order of the encoded rows.
    fn test_nested_list<O: OffsetSizeTrait>() {
        let mut builder =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(Int32Builder::new()));

        // 0: [[1, 2], [1, null]]
        builder.values().values().append_value(1);
        builder.values().values().append_value(2);
        builder.values().append(true);
        builder.values().values().append_value(1);
        builder.values().values().append_null();
        builder.values().append(true);
        builder.append(true);

        // 1: [[1, null], [1, null]]
        builder.values().values().append_value(1);
        builder.values().values().append_null();
        builder.values().append(true);
        builder.values().values().append_value(1);
        builder.values().values().append_null();
        builder.values().append(true);
        builder.append(true);

        // 2: [[1, null], null], then 3: null
        builder.values().values().append_value(1);
        builder.values().values().append_null();
        builder.values().append(true);
        builder.values().append(false);
        builder.append(true);
        builder.append(false);

        // 4: [[1, 2]]
        builder.values().values().append_value(1);
        builder.values().values().append_value(2);
        builder.values().append(true);
        builder.append(true);

        let list = Arc::new(builder.finish()) as ArrayRef;
        let d = list.data_type().clone();

        // [
        //   [[1, 2], [1, null]],
        //   [[1, null], [1, null]],
        //   [[1, null], null]
        //   null
        //   [[1, 2]]
        // ]
        // Ascending, nulls first
        let options = SortOptions::default().asc().with_nulls_first(true);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) > rows.row(1));
        assert!(rows.row(1) > rows.row(2));
        assert!(rows.row(2) > rows.row(3));
        assert!(rows.row(4) < rows.row(0));
        assert!(rows.row(4) > rows.row(1));

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls first
        let options = SortOptions::default().desc().with_nulls_first(true);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) > rows.row(1));
        assert!(rows.row(1) > rows.row(2));
        assert!(rows.row(2) > rows.row(3));
        assert!(rows.row(4) > rows.row(0));
        assert!(rows.row(4) > rows.row(1));

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls last
        let options = SortOptions::default().desc().with_nulls_first(false);
        let field = SortField::new_with_options(d, options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) < rows.row(1));
        assert!(rows.row(1) < rows.row(2));
        assert!(rows.row(2) < rows.row(3));
        assert!(rows.row(4) > rows.row(0));
        assert!(rows.row(4) < rows.row(1));

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // A sliced nested list must encode from its logical offsets
        let sliced_list = list.slice(1, 3);
        let rows = converter
            .convert_columns(&[Arc::clone(&sliced_list)])
            .unwrap();

        assert!(rows.row(0) < rows.row(1));
        assert!(rows.row(1) < rows.row(2));

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &sliced_list);
    }
2789 | | |
    /// List encoding with 32-bit offsets
    #[test]
    fn test_list() {
        test_single_list::<i32>();
        test_nested_list::<i32>();
    }
2795 | | |
    /// List encoding with 64-bit offsets (`LargeList`)
    #[test]
    fn test_large_list() {
        test_single_list::<i64>();
        test_nested_list::<i64>();
    }
2801 | | |
    /// Round-trips a `FixedSizeList(Int32, 3)` through the row format under all
    /// four (direction, null ordering) combinations.
    ///
    /// Logical contents (row index: value):
    ///   0: [32, 52, 32]        1: [32, 52, 12]      2: [32, 52, null]
    ///   3: null                4: [32, null, null]  5: [null, null, null]
    ///   6: null
    #[test]
    fn test_fixed_size_list() {
        let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3);
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.values().append_value(32);
        builder.append(true);
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.values().append_value(12);
        builder.append(true);
        builder.values().append_value(32);
        builder.values().append_value(52);
        builder.values().append_null();
        builder.append(true);
        // Values appended under a `false` list mask exist in the child array
        // but must not influence the encoding
        builder.values().append_value(32); // MASKED
        builder.values().append_value(52); // MASKED
        builder.values().append_value(13); // MASKED
        builder.append(false);
        builder.values().append_value(32);
        builder.values().append_null();
        builder.values().append_null();
        builder.append(true);
        builder.values().append_null();
        builder.values().append_null();
        builder.values().append_null();
        builder.append(true);
        builder.values().append_value(17); // MASKED
        builder.values().append_null(); // MASKED
        builder.values().append_value(77); // MASKED
        builder.append(false);

        let list = Arc::new(builder.finish()) as ArrayRef;
        let d = list.data_type().clone();

        // Default sorting (ascending, nulls first)
        let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();

        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
        assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
        assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
        assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
        assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
        assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
        assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Ascending, null last
        let options = SortOptions::default().asc().with_nulls_first(false);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
        assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
        assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
        assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
        assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
        assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
        assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls last
        let options = SortOptions::default().desc().with_nulls_first(false);
        let field = SortField::new_with_options(d.clone(), options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
        assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
        assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
        assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
        assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
        assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
        assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // Descending, nulls first
        let options = SortOptions::default().desc().with_nulls_first(true);
        let field = SortField::new_with_options(d, options);
        let converter = RowConverter::new(vec![field]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

        assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
        assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
        assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
        assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
        assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
        assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
        assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &list);

        // A sliced list must encode (and round-trip) from its logical offsets
        let sliced_list = list.slice(1, 5);
        let rows_on_sliced_list = converter
            .convert_columns(&[Arc::clone(&sliced_list)])
            .unwrap();

        assert!(rows_on_sliced_list.row(2) < rows_on_sliced_list.row(1)); // null < [32, 52, null]
        assert!(rows_on_sliced_list.row(3) < rows_on_sliced_list.row(1)); // [32, null, null] < [32, 52, null]
        assert!(rows_on_sliced_list.row(4) < rows_on_sliced_list.row(1)); // [null, null, null] < [32, 52, null]
        assert!(rows_on_sliced_list.row(2) < rows_on_sliced_list.row(4)); // null < [null, null, null]

        let back = converter.convert_rows(&rows_on_sliced_list).unwrap();
        assert_eq!(back.len(), 1);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &sliced_list);
    }
2924 | | |
    /// Round-trips two single-element `FixedSizeList(UInt8, 1)` columns encoded
    /// together, covering valid values, a list containing null, and a null
    /// list entry (with a masked child value).
    #[test]
    fn test_two_fixed_size_lists() {
        let mut first = FixedSizeListBuilder::new(UInt8Builder::new(), 1);
        // 0: [100]
        first.values().append_value(100);
        first.append(true);
        // 1: [101]
        first.values().append_value(101);
        first.append(true);
        // 2: [102]
        first.values().append_value(102);
        first.append(true);
        // 3: [null]
        first.values().append_null();
        first.append(true);
        // 4: null
        first.values().append_null(); // MASKED
        first.append(false);
        let first = Arc::new(first.finish()) as ArrayRef;
        let first_type = first.data_type().clone();

        let mut second = FixedSizeListBuilder::new(UInt8Builder::new(), 1);
        // 0: [200]
        second.values().append_value(200);
        second.append(true);
        // 1: [201]
        second.values().append_value(201);
        second.append(true);
        // 2: [202]
        second.values().append_value(202);
        second.append(true);
        // 3: [null]
        second.values().append_null();
        second.append(true);
        // 4: null
        second.values().append_null(); // MASKED
        second.append(false);
        let second = Arc::new(second.finish()) as ArrayRef;
        let second_type = second.data_type().clone();

        let converter = RowConverter::new(vec![
            SortField::new(first_type.clone()),
            SortField::new(second_type.clone()),
        ])
        .unwrap();

        let rows = converter
            .convert_columns(&[Arc::clone(&first), Arc::clone(&second)])
            .unwrap();

        // Both columns must round-trip losslessly
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 2);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &first);
        back[1].to_data().validate_full().unwrap();
        assert_eq!(&back[1], &second);
    }
2982 | | |
    /// Round-trips a `FixedSizeList` whose element is a struct containing a
    /// variable-width (Utf8) field, alongside a plain Utf8 column — exercising
    /// variable-length content nested inside a fixed-size list.
    #[test]
    fn test_fixed_size_list_with_variable_width_content() {
        let mut first = FixedSizeListBuilder::new(
            StructBuilder::from_fields(
                vec![
                    Field::new(
                        "timestamp",
                        DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))),
                        false,
                    ),
                    Field::new("offset_minutes", DataType::Int16, false),
                    Field::new("time_zone", DataType::Utf8, false),
                ],
                1,
            ),
            1,
        );
        // 0: null
        first
            .values()
            .field_builder::<TimestampMicrosecondBuilder>(0)
            .unwrap()
            .append_null();
        first
            .values()
            .field_builder::<Int16Builder>(1)
            .unwrap()
            .append_null();
        first
            .values()
            .field_builder::<StringBuilder>(2)
            .unwrap()
            .append_null();
        first.values().append(false);
        first.append(false);
        // 1: [null]
        first
            .values()
            .field_builder::<TimestampMicrosecondBuilder>(0)
            .unwrap()
            .append_null();
        first
            .values()
            .field_builder::<Int16Builder>(1)
            .unwrap()
            .append_null();
        first
            .values()
            .field_builder::<StringBuilder>(2)
            .unwrap()
            .append_null();
        first.values().append(false);
        first.append(true);
        // 2: [1970-01-01 00:00:00.000000 UTC]
        first
            .values()
            .field_builder::<TimestampMicrosecondBuilder>(0)
            .unwrap()
            .append_value(0);
        first
            .values()
            .field_builder::<Int16Builder>(1)
            .unwrap()
            .append_value(0);
        first
            .values()
            .field_builder::<StringBuilder>(2)
            .unwrap()
            .append_value("UTC");
        first.values().append(true);
        first.append(true);
        // 3: [2005-09-10 13:30:00.123456 Europe/Warsaw]
        first
            .values()
            .field_builder::<TimestampMicrosecondBuilder>(0)
            .unwrap()
            .append_value(1126351800123456);
        first
            .values()
            .field_builder::<Int16Builder>(1)
            .unwrap()
            .append_value(120);
        first
            .values()
            .field_builder::<StringBuilder>(2)
            .unwrap()
            .append_value("Europe/Warsaw");
        first.values().append(true);
        first.append(true);
        let first = Arc::new(first.finish()) as ArrayRef;
        let first_type = first.data_type().clone();

        // Second column: ["somewhere near", null, "Greenwich", "Warsaw"]
        let mut second = StringBuilder::new();
        second.append_value("somewhere near");
        second.append_null();
        second.append_value("Greenwich");
        second.append_value("Warsaw");
        let second = Arc::new(second.finish()) as ArrayRef;
        let second_type = second.data_type().clone();

        let converter = RowConverter::new(vec![
            SortField::new(first_type.clone()),
            SortField::new(second_type.clone()),
        ])
        .unwrap();

        let rows = converter
            .convert_columns(&[Arc::clone(&first), Arc::clone(&second)])
            .unwrap();

        // Both columns must round-trip losslessly
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 2);
        back[0].to_data().validate_full().unwrap();
        assert_eq!(&back[0], &first);
        back[1].to_data().validate_full().unwrap();
        assert_eq!(&back[1], &second);
    }
3100 | | |
3101 | | fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K> |
3102 | | where |
3103 | | K: ArrowPrimitiveType, |
3104 | | StandardUniform: Distribution<K::Native>, |
3105 | | { |
3106 | | let mut rng = rng(); |
3107 | | (0..len) |
3108 | | .map(|_| rng.random_bool(valid_percent).then(|| rng.random())) |
3109 | | .collect() |
3110 | | } |
3111 | | |
3112 | | fn generate_strings<O: OffsetSizeTrait>( |
3113 | | len: usize, |
3114 | | valid_percent: f64, |
3115 | | ) -> GenericStringArray<O> { |
3116 | | let mut rng = rng(); |
3117 | | (0..len) |
3118 | | .map(|_| { |
3119 | | rng.random_bool(valid_percent).then(|| { |
3120 | | let len = rng.random_range(0..100); |
3121 | | let bytes = (0..len).map(|_| rng.random_range(0..128)).collect(); |
3122 | | String::from_utf8(bytes).unwrap() |
3123 | | }) |
3124 | | }) |
3125 | | .collect() |
3126 | | } |
3127 | | |
3128 | | fn generate_string_view(len: usize, valid_percent: f64) -> StringViewArray { |
3129 | | let mut rng = rng(); |
3130 | | (0..len) |
3131 | | .map(|_| { |
3132 | | rng.random_bool(valid_percent).then(|| { |
3133 | | let len = rng.random_range(0..100); |
3134 | | let bytes = (0..len).map(|_| rng.random_range(0..128)).collect(); |
3135 | | String::from_utf8(bytes).unwrap() |
3136 | | }) |
3137 | | }) |
3138 | | .collect() |
3139 | | } |
3140 | | |
3141 | | fn generate_byte_view(len: usize, valid_percent: f64) -> BinaryViewArray { |
3142 | | let mut rng = rng(); |
3143 | | (0..len) |
3144 | | .map(|_| { |
3145 | | rng.random_bool(valid_percent).then(|| { |
3146 | | let len = rng.random_range(0..100); |
3147 | | let bytes: Vec<_> = (0..len).map(|_| rng.random_range(0..128)).collect(); |
3148 | | bytes |
3149 | | }) |
3150 | | }) |
3151 | | .collect() |
3152 | | } |
3153 | | |
3154 | | fn generate_fixed_stringview_column(len: usize) -> StringViewArray { |
3155 | | let edge_cases = vec![ |
3156 | | Some("bar".to_string()), |
3157 | | Some("bar\0".to_string()), |
3158 | | Some("LongerThan12Bytes".to_string()), |
3159 | | Some("LongerThan12Bytez".to_string()), |
3160 | | Some("LongerThan12Bytes\0".to_string()), |
3161 | | Some("LongerThan12Byt".to_string()), |
3162 | | Some("backend one".to_string()), |
3163 | | Some("backend two".to_string()), |
3164 | | Some("a".repeat(257)), |
3165 | | Some("a".repeat(300)), |
3166 | | ]; |
3167 | | |
3168 | | // Fill up to `len` by repeating edge cases and trimming |
3169 | | let mut values = Vec::with_capacity(len); |
3170 | | for i in 0..len { |
3171 | | values.push( |
3172 | | edge_cases |
3173 | | .get(i % edge_cases.len()) |
3174 | | .cloned() |
3175 | | .unwrap_or(None), |
3176 | | ); |
3177 | | } |
3178 | | |
3179 | | StringViewArray::from(values) |
3180 | | } |
3181 | | |
3182 | | fn generate_dictionary<K>( |
3183 | | values: ArrayRef, |
3184 | | len: usize, |
3185 | | valid_percent: f64, |
3186 | | ) -> DictionaryArray<K> |
3187 | | where |
3188 | | K: ArrowDictionaryKeyType, |
3189 | | K::Native: SampleUniform, |
3190 | | { |
3191 | | let mut rng = rng(); |
3192 | | let min_key = K::Native::from_usize(0).unwrap(); |
3193 | | let max_key = K::Native::from_usize(values.len()).unwrap(); |
3194 | | let keys: PrimitiveArray<K> = (0..len) |
3195 | | .map(|_| { |
3196 | | rng.random_bool(valid_percent) |
3197 | | .then(|| rng.random_range(min_key..max_key)) |
3198 | | }) |
3199 | | .collect(); |
3200 | | |
3201 | | let data_type = |
3202 | | DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); |
3203 | | |
3204 | | let data = keys |
3205 | | .into_data() |
3206 | | .into_builder() |
3207 | | .data_type(data_type) |
3208 | | .add_child_data(values.to_data()) |
3209 | | .build() |
3210 | | .unwrap(); |
3211 | | |
3212 | | DictionaryArray::from(data) |
3213 | | } |
3214 | | |
3215 | | fn generate_fixed_size_binary(len: usize, valid_percent: f64) -> FixedSizeBinaryArray { |
3216 | | let mut rng = rng(); |
3217 | | let width = rng.random_range(0..20); |
3218 | | let mut builder = FixedSizeBinaryBuilder::new(width); |
3219 | | |
3220 | | let mut b = vec![0; width as usize]; |
3221 | | for _ in 0..len { |
3222 | | match rng.random_bool(valid_percent) { |
3223 | | true => { |
3224 | | b.iter_mut().for_each(|x| *x = rng.random()); |
3225 | | builder.append_value(&b).unwrap(); |
3226 | | } |
3227 | | false => builder.append_null(), |
3228 | | } |
3229 | | } |
3230 | | |
3231 | | builder.finish() |
3232 | | } |
3233 | | |
3234 | | fn generate_struct(len: usize, valid_percent: f64) -> StructArray { |
3235 | | let mut rng = rng(); |
3236 | | let nulls = NullBuffer::from_iter((0..len).map(|_| rng.random_bool(valid_percent))); |
3237 | | let a = generate_primitive_array::<Int32Type>(len, valid_percent); |
3238 | | let b = generate_strings::<i32>(len, valid_percent); |
3239 | | let fields = Fields::from(vec![ |
3240 | | Field::new("a", DataType::Int32, true), |
3241 | | Field::new("b", DataType::Utf8, true), |
3242 | | ]); |
3243 | | let values = vec![Arc::new(a) as _, Arc::new(b) as _]; |
3244 | | StructArray::new(fields, values, Some(nulls)) |
3245 | | } |
3246 | | |
3247 | | fn generate_list<F>(len: usize, valid_percent: f64, values: F) -> ListArray |
3248 | | where |
3249 | | F: FnOnce(usize) -> ArrayRef, |
3250 | | { |
3251 | | let mut rng = rng(); |
3252 | | let offsets = OffsetBuffer::<i32>::from_lengths((0..len).map(|_| rng.random_range(0..10))); |
3253 | | let values_len = offsets.last().unwrap().to_usize().unwrap(); |
3254 | | let values = values(values_len); |
3255 | | let nulls = NullBuffer::from_iter((0..len).map(|_| rng.random_bool(valid_percent))); |
3256 | | let field = Arc::new(Field::new_list_field(values.data_type().clone(), true)); |
3257 | | ListArray::new(field, offsets, values, Some(nulls)) |
3258 | | } |
3259 | | |
/// Generates a random column of `len` values, choosing uniformly among the
/// 18 array-type generators below (all with ~20% nulls where applicable).
fn generate_column(len: usize) -> ArrayRef {
    let mut rng = rng();
    match rng.random_range(0..18) {
        // Nullable primitives
        0 => Arc::new(generate_primitive_array::<Int32Type>(len, 0.8)),
        1 => Arc::new(generate_primitive_array::<UInt32Type>(len, 0.8)),
        2 => Arc::new(generate_primitive_array::<Int64Type>(len, 0.8)),
        3 => Arc::new(generate_primitive_array::<UInt64Type>(len, 0.8)),
        4 => Arc::new(generate_primitive_array::<Float32Type>(len, 0.8)),
        5 => Arc::new(generate_primitive_array::<Float64Type>(len, 0.8)),
        // Variable-length strings
        6 => Arc::new(generate_strings::<i32>(len, 0.8)),
        // Dictionary-encoded strings
        7 => Arc::new(generate_dictionary::<Int64Type>(
            // Cannot test dictionaries containing null values because of #2687
            Arc::new(generate_strings::<i32>(rng.random_range(1..len), 1.0)),
            len,
            0.8,
        )),
        // Dictionary-encoded primitives
        8 => Arc::new(generate_dictionary::<Int64Type>(
            // Cannot test dictionaries containing null values because of #2687
            Arc::new(generate_primitive_array::<Int64Type>(
                rng.random_range(1..len),
                1.0,
            )),
            len,
            0.8,
        )),
        9 => Arc::new(generate_fixed_size_binary(len, 0.8)),
        10 => Arc::new(generate_struct(len, 0.8)),
        // Lists of primitives, strings, and structs
        11 => Arc::new(generate_list(len, 0.8, |values_len| {
            Arc::new(generate_primitive_array::<Int64Type>(values_len, 0.8))
        })),
        12 => Arc::new(generate_list(len, 0.8, |values_len| {
            Arc::new(generate_strings::<i32>(values_len, 0.8))
        })),
        13 => Arc::new(generate_list(len, 0.8, |values_len| {
            Arc::new(generate_struct(values_len, 0.8))
        })),
        // View arrays
        14 => Arc::new(generate_string_view(len, 0.8)),
        15 => Arc::new(generate_byte_view(len, 0.8)),
        16 => Arc::new(generate_fixed_stringview_column(len)),
        // A sliced list, to exercise arrays with a non-zero offset
        17 => Arc::new(
            generate_list(len + 1000, 0.8, |values_len| {
                Arc::new(generate_primitive_array::<Int64Type>(values_len, 0.8))
            })
            .slice(500, len),
        ),
        _ => unreachable!(),
    }
}
3308 | | |
3309 | | fn print_row(cols: &[SortColumn], row: usize) -> String { |
3310 | | let t: Vec<_> = cols |
3311 | | .iter() |
3312 | | .map(|x| match x.values.is_valid(row) { |
3313 | | true => { |
3314 | | let opts = FormatOptions::default().with_null("NULL"); |
3315 | | let formatter = ArrayFormatter::try_new(x.values.as_ref(), &opts).unwrap(); |
3316 | | formatter.value(row).to_string() |
3317 | | } |
3318 | | false => "NULL".to_string(), |
3319 | | }) |
3320 | | .collect(); |
3321 | | t.join(",") |
3322 | | } |
3323 | | |
3324 | | fn print_col_types(cols: &[SortColumn]) -> String { |
3325 | | let t: Vec<_> = cols |
3326 | | .iter() |
3327 | | .map(|x| x.values.data_type().to_string()) |
3328 | | .collect(); |
3329 | | t.join(",") |
3330 | | } |
3331 | | |
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test() {
    // Property test: for random columns with random sort options, the row
    // encoding must compare identically to LexicographicalComparator, and
    // rows must round-trip losslessly back into arrays (including via the
    // binary form).
    for _ in 0..100 {
        let mut rng = rng();
        let num_columns = rng.random_range(1..5);
        let len = rng.random_range(5..100);
        let arrays: Vec<_> = (0..num_columns).map(|_| generate_column(len)).collect();

        // Random sort direction / null ordering per column
        let options: Vec<_> = (0..num_columns)
            .map(|_| SortOptions {
                descending: rng.random_bool(0.5),
                nulls_first: rng.random_bool(0.5),
            })
            .collect();

        let sort_columns: Vec<_> = options
            .iter()
            .zip(&arrays)
            .map(|(o, c)| SortColumn {
                values: Arc::clone(c),
                options: Some(*o),
            })
            .collect();

        // Reference implementation the row encoding is checked against
        let comparator = LexicographicalComparator::try_new(&sort_columns).unwrap();

        let columns: Vec<SortField> = options
            .into_iter()
            .zip(&arrays)
            .map(|(o, a)| SortField::new_with_options(a.data_type().clone(), o))
            .collect();

        let converter = RowConverter::new(columns).unwrap();
        let rows = converter.convert_columns(&arrays).unwrap();

        // Exhaustively compare every pair of rows against the reference
        // comparator (includes i == j, which must compare Equal)
        for i in 0..len {
            for j in 0..len {
                let row_i = rows.row(i);
                let row_j = rows.row(j);
                let row_cmp = row_i.cmp(&row_j);
                let lex_cmp = comparator.compare(i, j);
                assert_eq!(
                    row_cmp,
                    lex_cmp,
                    "({:?} vs {:?}) vs ({:?} vs {:?}) for types {}",
                    print_row(&sort_columns, i),
                    print_row(&sort_columns, j),
                    row_i,
                    row_j,
                    print_col_types(&sort_columns)
                );
            }
        }

        // Convert rows produced from convert_columns().
        // Note: validate_utf8 is set to false since Row is initialized through empty_rows()
        let back = converter.convert_rows(&rows).unwrap();
        for (actual, expected) in back.iter().zip(&arrays) {
            actual.to_data().validate_full().unwrap();
            dictionary_eq(actual, expected)
        }

        // Check that we can convert rows into ByteArray and then parse, convert it back to array
        // Note: validate_utf8 is set to true since Row is initialized through RowParser
        let rows = rows.try_into_binary().expect("reasonable size");
        let parser = converter.parser();
        let back = converter
            .convert_rows(rows.iter().map(|b| parser.parse(b.expect("valid bytes"))))
            .unwrap();
        for (actual, expected) in back.iter().zip(&arrays) {
            actual.to_data().validate_full().unwrap();
            dictionary_eq(actual, expected)
        }

        // Finally, rehydrate Rows from the binary form and round-trip again
        let rows = converter.from_binary(rows);
        let back = converter.convert_rows(&rows).unwrap();
        for (actual, expected) in back.iter().zip(&arrays) {
            actual.to_data().validate_full().unwrap();
            dictionary_eq(actual, expected)
        }
    }
}
3415 | | |
#[test]
fn test_clear() {
    // Reusing the same `Rows` across arrays: clear() must fully reset its
    // state so each append/convert round-trips independently.
    let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
    let mut rows = converter.empty_rows(3, 128);

    let first = Int32Array::from(vec![None, Some(2), Some(4)]);
    let second = Int32Array::from(vec![Some(2), None, Some(4)]);
    let arrays = [Arc::new(first) as ArrayRef, Arc::new(second) as ArrayRef];

    for array in arrays.iter() {
        rows.clear();
        converter
            .append(&mut rows, std::slice::from_ref(array))
            .unwrap();
        let round_tripped = converter.convert_rows(&rows).unwrap();
        assert_eq!(&round_tripped[0], array);
    }

    // After the loop, `rows` holds the encoding of the second array; it
    // must byte-compare equal to a fresh encoding of that array.
    let mut rows_expected = converter.empty_rows(3, 128);
    converter.append(&mut rows_expected, &arrays[1..]).unwrap();

    for (i, (actual, expected)) in rows.iter().zip(rows_expected.iter()).enumerate() {
        assert_eq!(
            actual, expected,
            "For row {i}: expected {expected:?}, actual: {actual:?}",
        );
    }
}
3444 | | |
#[test]
fn test_append_codec_dictionary_binary() {
    use DataType::*;
    // Converter for a Dictionary<Int32, Binary> column
    let converter = RowConverter::new(vec![SortField::new(Dictionary(
        Box::new(Int32),
        Box::new(Binary),
    ))])
    .unwrap();
    let mut rows = converter.empty_rows(4, 128);

    let keys = Int32Array::from_iter_values([0, 1, 2, 3]);
    let values = BinaryArray::from(vec![
        Some("a".as_bytes()),
        Some(b"b"),
        Some(b"c"),
        Some(b"d"),
    ]);
    let dict_array = DictionaryArray::new(keys, Arc::new(values));

    // Append into cleared Rows and convert back; the logical values must be
    // preserved (dictionary_eq compares logical content, not encoding)
    rows.clear();
    let array = Arc::new(dict_array) as ArrayRef;
    converter
        .append(&mut rows, std::slice::from_ref(&array))
        .unwrap();
    let round_tripped = converter.convert_rows(&rows).unwrap();

    dictionary_eq(&round_tripped[0], &array);
}
3474 | | |
#[test]
fn test_list_prefix() {
    // [None] must sort strictly before [None, None]: the shorter list is a
    // prefix of the longer one.
    let mut builder = ListBuilder::new(Int8Builder::new());
    builder.append_value([None]);
    builder.append_value([None, None]);
    let lists = builder.finish();

    let converter = RowConverter::new(vec![SortField::new(lists.data_type().clone())]).unwrap();
    let rows = converter.convert_columns(&[Arc::new(lists) as _]).unwrap();
    assert_eq!(rows.row(0).cmp(&rows.row(1)), Ordering::Less);
}
3486 | | |
#[test]
fn map_should_be_marked_as_unsupported() {
    // Build a Map<Utf8, Utf8> data type
    let map_data_type = Field::new_map(
        "map",
        "entries",
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Utf8, true),
        false,
        true,
    )
    .data_type()
    .clone();

    // supports_fields() must report Map columns as unsupported
    let is_supported = RowConverter::supports_fields(&[SortField::new(map_data_type)]);
    assert!(!is_supported, "Map should not be supported");
}
3504 | | |
#[test]
fn should_fail_to_create_row_converter_for_unsupported_map_type() {
    // Build a Map<Utf8, Utf8> data type
    let map_data_type = Field::new_map(
        "map",
        "entries",
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Utf8, true),
        false,
        true,
    )
    .data_type()
    .clone();

    // Constructing a converter over a Map column must fail with
    // NotYetImplemented (not succeed, and not fail with another error)
    let result = RowConverter::new(vec![SortField::new(map_data_type)]);
    match result {
        Err(ArrowError::NotYetImplemented(message)) => {
            assert!(
                message.contains("Row format support not yet implemented for"),
                "Expected NotYetImplemented error for map data type, got: {message}",
            );
        }
        Err(e) => panic!("Expected NotYetImplemented error, got: {e}"),
        Ok(_) => panic!("Expected NotYetImplemented error for map data type"),
    }
}
3531 | | |
#[test]
fn test_values_buffer_smaller_when_utf8_validation_disabled() {
    // When decoding rows without UTF-8 validation, short (<=12 byte)
    // strings stay inline in the views and the variadic values buffer can
    // be smaller; with validation enabled (rows parsed via RowParser) the
    // inline string bytes are also written to the values buffer.
    // Returns (len without validation, len with validation).
    fn get_values_buffer_len(col: ArrayRef) -> (usize, usize) {
        // 1. Convert cols into rows
        let converter = RowConverter::new(vec![SortField::new(DataType::Utf8View)]).unwrap();

        // 2a. Convert rows into colsa (validate_utf8 = false)
        let rows = converter.convert_columns(&[col]).unwrap();
        let converted = converter.convert_rows(&rows).unwrap();
        let unchecked_values_len = converted[0].as_string_view().data_buffers()[0].len();

        // 2b. Convert rows into cols (validate_utf8 = true since Row is initialized through RowParser)
        let rows = rows.try_into_binary().expect("reasonable size");
        let parser = converter.parser();
        let converted = converter
            .convert_rows(rows.iter().map(|b| parser.parse(b.expect("valid bytes"))))
            .unwrap();
        let checked_values_len = converted[0].as_string_view().data_buffers()[0].len();
        (unchecked_values_len, checked_values_len)
    }

    // Case1. StringViewArray with inline strings
    let col = Arc::new(StringViewArray::from_iter([
        Some("hello"), // short(5)
        None,          // null
        Some("short"), // short(5)
        Some("tiny"),  // short(4)
    ])) as ArrayRef;

    let (unchecked_values_len, checked_values_len) = get_values_buffer_len(col);
    // Since there are no long (>12) strings, len of values buffer is 0
    assert_eq!(unchecked_values_len, 0);
    // When utf8 validation enabled, values buffer includes inline strings (5+5+4)
    assert_eq!(checked_values_len, 14);

    // Case2. StringViewArray with long(>12) strings
    let col = Arc::new(StringViewArray::from_iter([
        Some("this is a very long string over 12 bytes"),
        Some("another long string to test the buffer"),
    ])) as ArrayRef;

    let (unchecked_values_len, checked_values_len) = get_values_buffer_len(col);
    // Since there are no inline strings, expected length of values buffer is the same
    assert!(unchecked_values_len > 0);
    assert_eq!(unchecked_values_len, checked_values_len);

    // Case3. StringViewArray with both short and long strings
    let col = Arc::new(StringViewArray::from_iter([
        Some("tiny"),          // 4 (short)
        Some("thisisexact13"), // 13 (long)
        None,
        Some("short"), // 5 (short)
    ])) as ArrayRef;

    let (unchecked_values_len, checked_values_len) = get_values_buffer_len(col);
    // Since there is single long string, len of values buffer is 13
    assert_eq!(unchecked_values_len, 13);
    assert!(checked_values_len > unchecked_values_len);
}
3591 | | } |