/Users/andrewlamb/Software/arrow-rs/arrow-ord/src/rank.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Provides `rank` function to assign a rank to each value in an array |
19 | | |
20 | | use arrow_array::cast::AsArray; |
21 | | use arrow_array::types::*; |
22 | | use arrow_array::{ |
23 | | downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, GenericByteArray, |
24 | | }; |
25 | | use arrow_buffer::NullBuffer; |
26 | | use arrow_schema::{ArrowError, DataType, SortOptions}; |
27 | | use std::cmp::Ordering; |
28 | | |
29 | | /// Whether `arrow_ord::rank` can rank an array of given data type. |
30 | 0 | pub(crate) fn can_rank(data_type: &DataType) -> bool { |
31 | 0 | data_type.is_primitive() |
32 | 0 | || matches!( |
33 | 0 | data_type, |
34 | | DataType::Boolean |
35 | | | DataType::Utf8 |
36 | | | DataType::LargeUtf8 |
37 | | | DataType::Binary |
38 | | | DataType::LargeBinary |
39 | | ) |
40 | 0 | } |
41 | | |
42 | | /// Assigns a rank to each value in `array` based on its position in the sorted order |
43 | | /// |
44 | | /// Where values are equal, they will be assigned the highest of their ranks, |
45 | | /// leaving gaps in the overall rank assignment |
46 | | /// |
47 | | /// ``` |
48 | | /// # use arrow_array::StringArray; |
49 | | /// # use arrow_ord::rank::rank; |
50 | | /// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]); |
51 | | /// let ranks = rank(&array, None).unwrap(); |
52 | | /// assert_eq!(ranks, &[5, 2, 5, 2, 3]); |
53 | | /// ``` |
54 | 0 | pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>, ArrowError> { |
55 | 0 | let options = options.unwrap_or_default(); |
56 | 0 | let ranks = downcast_primitive_array! { |
57 | 0 | array => primitive_rank(array.values(), array.nulls(), options), |
58 | 0 | DataType::Boolean => boolean_rank(array.as_boolean(), options), |
59 | 0 | DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options), |
60 | 0 | DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options), |
61 | 0 | DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options), |
62 | 0 | DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), options), |
63 | 0 | d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank"))) |
64 | | }; |
65 | 0 | Ok(ranks) |
66 | 0 | } |
67 | | |
68 | | #[inline(never)] |
69 | 0 | fn primitive_rank<T: ArrowNativeTypeOp>( |
70 | 0 | values: &[T], |
71 | 0 | nulls: Option<&NullBuffer>, |
72 | 0 | options: SortOptions, |
73 | 0 | ) -> Vec<u32> { |
74 | 0 | let len: u32 = values.len().try_into().unwrap(); |
75 | 0 | let to_sort = match nulls.filter(|n| n.null_count() > 0) { |
76 | 0 | Some(n) => n |
77 | 0 | .valid_indices() |
78 | 0 | .map(|idx| (values[idx], idx as u32)) |
79 | 0 | .collect(), |
80 | 0 | None => values.iter().copied().zip(0..len).collect(), |
81 | | }; |
82 | 0 | rank_impl(values.len(), to_sort, options, T::compare, T::is_eq) |
83 | 0 | } |
84 | | |
85 | | #[inline(never)] |
86 | 0 | fn bytes_rank<T: ByteArrayType>(array: &GenericByteArray<T>, options: SortOptions) -> Vec<u32> { |
87 | 0 | let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) { |
88 | 0 | Some(n) => n |
89 | 0 | .valid_indices() |
90 | 0 | .map(|idx| (array.value(idx).as_ref(), idx as u32)) |
91 | 0 | .collect(), |
92 | 0 | None => (0..array.len()) |
93 | 0 | .map(|idx| (array.value(idx).as_ref(), idx as u32)) |
94 | 0 | .collect(), |
95 | | }; |
96 | 0 | rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq) |
97 | 0 | } |
98 | | |
99 | 0 | fn rank_impl<T, C, E>( |
100 | 0 | len: usize, |
101 | 0 | mut valid: Vec<(T, u32)>, |
102 | 0 | options: SortOptions, |
103 | 0 | compare: C, |
104 | 0 | eq: E, |
105 | 0 | ) -> Vec<u32> |
106 | 0 | where |
107 | 0 | T: Copy, |
108 | 0 | C: Fn(T, T) -> Ordering, |
109 | 0 | E: Fn(T, T) -> bool, |
110 | | { |
111 | | // We can use an unstable sort as we combine equal values later |
112 | 0 | valid.sort_unstable_by(|a, b| compare(a.0, b.0)); |
113 | 0 | if options.descending { |
114 | 0 | valid.reverse(); |
115 | 0 | } |
116 | | |
117 | 0 | let (mut valid_rank, null_rank) = match options.nulls_first { |
118 | 0 | true => (len as u32, (len - valid.len()) as u32), |
119 | 0 | false => (valid.len() as u32, len as u32), |
120 | | }; |
121 | | |
122 | 0 | let mut out: Vec<_> = vec![null_rank; len]; |
123 | 0 | if let Some(v) = valid.last() { |
124 | 0 | out[v.1 as usize] = valid_rank; |
125 | 0 | } |
126 | | |
127 | 0 | let mut count = 1; // Number of values in rank |
128 | 0 | for w in valid.windows(2).rev() { |
129 | 0 | match eq(w[0].0, w[1].0) { |
130 | 0 | true => { |
131 | 0 | count += 1; |
132 | 0 | out[w[0].1 as usize] = valid_rank; |
133 | 0 | } |
134 | | false => { |
135 | 0 | valid_rank -= count; |
136 | 0 | count = 1; |
137 | 0 | out[w[0].1 as usize] = valid_rank |
138 | | } |
139 | | } |
140 | | } |
141 | | |
142 | 0 | out |
143 | 0 | } |
144 | | |
145 | | /// Returns the index of the rank to use when ranking a boolean array |
146 | | /// |
147 | | /// The index is calculated as follows: |
148 | | /// if is_null is true, the index is 2 |
149 | | /// if is_null is false and the value is true, the index is 1 |
150 | | /// otherwise, the index is 0 |
151 | | /// |
152 | | /// `false` is 0 and `true` is 1 because these are their values when cast to a number |
153 | | #[inline] |
154 | 0 | fn get_boolean_rank_index(value: bool, is_null: bool) -> usize { |
155 | 0 | let is_null_num = is_null as usize; |
156 | 0 | (is_null_num << 1) | (value as usize & !is_null_num) |
157 | 0 | } |
158 | | |
159 | | #[inline(never)] |
160 | 0 | fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> { |
161 | 0 | let null_count = array.null_count() as u32; |
162 | 0 | let true_count = array.true_count() as u32; |
163 | 0 | let false_count = array.len() as u32 - null_count - true_count; |
164 | | |
165 | | // Rank values for [false, true, null] in that order |
166 | | // |
167 | | // The rank of a value is the rank of the value sorted before it plus its own count; |
168 | | // this means that if we have the following order: `false`, `true` and then `null` |
169 | | // the ranks will be: |
170 | | // - false: false_count |
171 | | // - true: false_count + true_count |
172 | | // - null: false_count + true_count + null_count |
173 | | // |
174 | | // If we have the following order: `null`, `false` and then `true` |
175 | | // the ranks will be: |
176 | | // - false: null_count + false_count |
177 | | // - true: null_count + false_count + true_count |
178 | | // - null: null_count |
179 | | // |
180 | | // Note that the last rank always equals the total length of the array, but we spell out the sum instead so that it stays readable how each rank is calculated |
181 | 0 | let ranks_index: [u32; 3] = match (options.descending, options.nulls_first) { |
182 | | // The order is null, true, false |
183 | 0 | (true, true) => [ |
184 | 0 | null_count + true_count + false_count, |
185 | 0 | null_count + true_count, |
186 | 0 | null_count, |
187 | 0 | ], |
188 | | // The order is true, false, null |
189 | 0 | (true, false) => [ |
190 | 0 | true_count + false_count, |
191 | 0 | true_count, |
192 | 0 | true_count + false_count + null_count, |
193 | 0 | ], |
194 | | // The order is null, false, true |
195 | 0 | (false, true) => [ |
196 | 0 | null_count + false_count, |
197 | 0 | null_count + false_count + true_count, |
198 | 0 | null_count, |
199 | 0 | ], |
200 | | // The order is false, true, null |
201 | 0 | (false, false) => [ |
202 | 0 | false_count, |
203 | 0 | false_count + true_count, |
204 | 0 | false_count + true_count + null_count, |
205 | 0 | ], |
206 | | }; |
207 | | |
208 | 0 | match array.nulls().filter(|n| n.null_count() > 0) { |
209 | 0 | Some(n) => array |
210 | 0 | .values() |
211 | 0 | .iter() |
212 | 0 | .zip(n.iter()) |
213 | 0 | .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, !is_valid)]) |
214 | 0 | .collect::<Vec<u32>>(), |
215 | 0 | None => array |
216 | 0 | .values() |
217 | 0 | .iter() |
218 | 0 | .map(|value| ranks_index[value as usize]) |
219 | 0 | .collect::<Vec<u32>>(), |
220 | | } |
221 | 0 | } |
222 | | |
223 | | #[cfg(test)] |
224 | | mod tests { |
225 | | use super::*; |
226 | | use arrow_array::*; |
227 | | |
228 | | #[test] |
229 | | fn test_primitive() { |
230 | | let descending = SortOptions { |
231 | | descending: true, |
232 | | nulls_first: true, |
233 | | }; |
234 | | |
235 | | let nulls_last = SortOptions { |
236 | | descending: false, |
237 | | nulls_first: false, |
238 | | }; |
239 | | |
240 | | let nulls_last_descending = SortOptions { |
241 | | descending: true, |
242 | | nulls_first: false, |
243 | | }; |
244 | | |
245 | | let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]); |
246 | | let res = rank(&a, None).unwrap(); |
247 | | assert_eq!(res, &[3, 3, 1, 5, 5, 6]); |
248 | | |
249 | | let res = rank(&a, Some(descending)).unwrap(); |
250 | | assert_eq!(res, &[6, 6, 1, 4, 4, 2]); |
251 | | |
252 | | let res = rank(&a, Some(nulls_last)).unwrap(); |
253 | | assert_eq!(res, &[2, 2, 6, 4, 4, 5]); |
254 | | |
255 | | let res = rank(&a, Some(nulls_last_descending)).unwrap(); |
256 | | assert_eq!(res, &[5, 5, 6, 3, 3, 1]); |
257 | | |
258 | | // Test with non-zero values stored in the null slots |
259 | | let nulls = NullBuffer::from(vec![true, true, false, true, false, false]); |
260 | | let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls)); |
261 | | let res = rank(&a, None).unwrap(); |
262 | | assert_eq!(res, &[4, 6, 3, 6, 3, 3]); |
263 | | } |
264 | | |
265 | | #[test] |
266 | | fn test_get_boolean_rank_index() { |
267 | | assert_eq!(get_boolean_rank_index(true, true), 2); |
268 | | assert_eq!(get_boolean_rank_index(false, true), 2); |
269 | | assert_eq!(get_boolean_rank_index(true, false), 1); |
270 | | assert_eq!(get_boolean_rank_index(false, false), 0); |
271 | | } |
272 | | |
273 | | #[test] |
274 | | fn test_nullable_booleans() { |
275 | | let descending = SortOptions { |
276 | | descending: true, |
277 | | nulls_first: true, |
278 | | }; |
279 | | |
280 | | let nulls_last = SortOptions { |
281 | | descending: false, |
282 | | nulls_first: false, |
283 | | }; |
284 | | |
285 | | let nulls_last_descending = SortOptions { |
286 | | descending: true, |
287 | | nulls_first: false, |
288 | | }; |
289 | | |
290 | | let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]); |
291 | | let res = rank(&a, None).unwrap(); |
292 | | assert_eq!(res, &[5, 5, 1, 3, 3]); |
293 | | |
294 | | let res = rank(&a, Some(descending)).unwrap(); |
295 | | assert_eq!(res, &[3, 3, 1, 5, 5]); |
296 | | |
297 | | let res = rank(&a, Some(nulls_last)).unwrap(); |
298 | | assert_eq!(res, &[4, 4, 5, 2, 2]); |
299 | | |
300 | | let res = rank(&a, Some(nulls_last_descending)).unwrap(); |
301 | | assert_eq!(res, &[2, 2, 5, 4, 4]); |
302 | | |
303 | | // Test with a `true` value stored in the null slot |
304 | | let nulls = NullBuffer::from(vec![true, true, false, true, true]); |
305 | | let a = BooleanArray::new(vec![true, true, true, false, false].into(), Some(nulls)); |
306 | | let res = rank(&a, None).unwrap(); |
307 | | assert_eq!(res, &[5, 5, 1, 3, 3]); |
308 | | } |
309 | | |
310 | | #[test] |
311 | | fn test_booleans() { |
312 | | let descending = SortOptions { |
313 | | descending: true, |
314 | | nulls_first: true, |
315 | | }; |
316 | | |
317 | | let nulls_last = SortOptions { |
318 | | descending: false, |
319 | | nulls_first: false, |
320 | | }; |
321 | | |
322 | | let nulls_last_descending = SortOptions { |
323 | | descending: true, |
324 | | nulls_first: false, |
325 | | }; |
326 | | |
327 | | let a = BooleanArray::from(vec![true, false, false, false, true]); |
328 | | let res = rank(&a, None).unwrap(); |
329 | | assert_eq!(res, &[5, 3, 3, 3, 5]); |
330 | | |
331 | | let res = rank(&a, Some(descending)).unwrap(); |
332 | | assert_eq!(res, &[2, 5, 5, 5, 2]); |
333 | | |
334 | | let res = rank(&a, Some(nulls_last)).unwrap(); |
335 | | assert_eq!(res, &[5, 3, 3, 3, 5]); |
336 | | |
337 | | let res = rank(&a, Some(nulls_last_descending)).unwrap(); |
338 | | assert_eq!(res, &[2, 5, 5, 5, 2]); |
339 | | } |
340 | | |
341 | | #[test] |
342 | | fn test_bytes() { |
343 | | let v = vec!["foo", "fo", "bar", "bar"]; |
344 | | let values = StringArray::from(v.clone()); |
345 | | let res = rank(&values, None).unwrap(); |
346 | | assert_eq!(res, &[4, 3, 2, 2]); |
347 | | |
348 | | let values = LargeStringArray::from(v.clone()); |
349 | | let res = rank(&values, None).unwrap(); |
350 | | assert_eq!(res, &[4, 3, 2, 2]); |
351 | | |
352 | | let v: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]]; |
353 | | let values = LargeBinaryArray::from(v.clone()); |
354 | | let res = rank(&values, None).unwrap(); |
355 | | assert_eq!(res, &[3, 1, 4, 3]); |
356 | | |
357 | | let values = BinaryArray::from(v); |
358 | | let res = rank(&values, None).unwrap(); |
359 | | assert_eq!(res, &[3, 1, 4, 3]); |
360 | | } |
361 | | } |