@@ -42,51 +42,55 @@ def _raise_bad_key_type(key: Any):
4242
4343
def _query_table_with_indices_mapping(
    table: Table, key: Union[int, slice, range, str, Iterable], indices: Table
) -> pa.Table:
    """
    Query a Table through an indices mapping to extract the subtable that corresponds to the given key.

    The :obj:`indices` parameter corresponds to the indices mapping in case we want to take into
    account a shuffling or an indices selection for example: the key is first looked up in the
    mapping, and the values found there are then used to query :obj:`table`.
    The indices table must contain one column named "indices" of type uint64.

    Args:
        table: the table that holds the data.
        key: an int, a slice, a range, a column name or an iterable of ints. Row positions are
            positions in the indices mapping, not in :obj:`table`.
        indices: the indices mapping, its first column is the one that is read.

    Returns:
        ``pyarrow.Table``: the result of the query on the input table after re-mapping.
    """
    num_indices = indices.num_rows
    if isinstance(key, int):
        key = indices.fast_slice(key % num_indices, 1).column(0)[0].as_py()
        return _query_table(table, key)
    if isinstance(key, slice):
        # The key addresses positions of the mapping, so the slice must be resolved against the
        # length of the mapping and not against the length of the underlying table.
        key = range(*key.indices(num_indices))
    if isinstance(key, range):
        if _is_range_contiguous(key) and key.start >= 0:
            return _query_table(
                table, [i.as_py() for i in indices.fast_slice(key.start, key.stop - key.start).column(0)]
            )
        else:
            pass  # treat as an iterable
    if isinstance(key, str):
        # Keep only the requested column, then apply the whole mapping to it.
        table = table.drop([column for column in table.column_names if column != key])
        return _query_table(table, indices.column(0).to_pylist())
    if isinstance(key, Iterable):
        # Wrap positions around the mapping length, like the int branch above does, so that
        # negative positions (e.g. coming from a range with a negative start) are supported.
        return _query_table(
            table, [indices.fast_slice(int(i) % num_indices, 1).column(0)[0].as_py() for i in key]
        )

    _raise_bad_key_type(key)
6872
6973
def _query_table(table: Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
    """
    Query a Table to extract the subtable that corresponds to the given key.

    Args:
        table: the table to query.
        key: one of
            - an int: the row at this position, taken modulo the number of rows
            - a slice or a range: the matching rows; a contiguous range with a non-negative start
              is served with a single slice, anything else is handled row by row
            - a string: all the rows but only the column with this name
            - an iterable of ints: the concatenation of the rows at these positions, each taken
              modulo the number of rows

    Returns:
        ``pyarrow.Table``: the result of the query on the input table.
    """
    if isinstance(key, int):
        return table.fast_slice(key % table.num_rows, 1)
    if isinstance(key, slice):
        key = range(*key.indices(table.num_rows))
    if isinstance(key, range):
        if _is_range_contiguous(key) and key.start >= 0:
            return table.fast_slice(key.start, key.stop - key.start)
        else:
            pass  # treat as an iterable
    if isinstance(key, str):
        return table.table.drop([column for column in table.column_names if column != key])
    if isinstance(key, Iterable):
        if not hasattr(key, "__len__"):
            # Generators and other one-shot iterables have no length: materialize them once so
            # that the emptiness check below does not fail and the items are not consumed twice.
            key = list(key)
        if len(key) == 0:
            return table.table.slice(0, 0)
        num_rows = table.num_rows
        # don't use pyarrow.Table.take even for pyarrow >=1.0 (see https://issues.apache.org/jira/browse/ARROW-9773)
        return pa.concat_tables(table.fast_slice(int(i) % num_rows, 1) for i in key)

    _raise_bad_key_type(key)
9296
@@ -306,7 +310,7 @@ def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
306310def query_table (
307311 table : Table ,
308312 key : Union [int , slice , range , str , Iterable ],
309- indices : Optional [pa . lib . UInt64Array ] = None ,
313+ indices : Optional [Table ] = None ,
310314) -> pa .Table :
311315 """
312316 Query a Table to extract the subtable that correspond to the given key.
@@ -319,30 +323,27 @@ def query_table(
319323 - a range(i, j, k): the subtable containing the rows that correspond to this range
320324 - a string c: the subtable containing all the rows but only the column c
321325 - an iterable l: the subtable that is the concatenation of all the i-th rows for all i in the iterable
322- indices (Optional ``pyarrow.lib.UInt64Array``): If not None, it is used to re-map the given key to the table rows.
326+ indices (Optional ``datasets.table.Table``): If not None, it is used to re-map the given key to the table rows.
327+ The indices table must contain one column named "indices" of type uint64.
323328 This is used in case of shuffling or rows selection.
324329
325330
326331 Returns:
327332 ``pyarrow.Table``: the result of the query on the input table
328333 """
329- if isinstance (table , Table ):
330- pa_table = table .table
331- else :
332- pa_table = table
333334 # Check if key is valid
334335 if not isinstance (key , (int , slice , range , str , Iterable )):
335336 _raise_bad_key_type (key )
336337 if isinstance (key , str ):
337- _check_valid_column_key (key , pa_table .column_names )
338+ _check_valid_column_key (key , table .column_names )
338339 else :
339- size = len ( indices ) if indices is not None else pa_table .num_rows
340+ size = indices . num_rows if indices is not None else table .num_rows
340341 _check_valid_index_key (key , size )
341342 # Query the main table
342343 if indices is None :
343- pa_subtable = _query_table (pa_table , key )
344+ pa_subtable = _query_table (table , key )
344345 else :
345- pa_subtable = _query_table_with_indices_mapping (pa_table , key , indices = indices )
346+ pa_subtable = _query_table_with_indices_mapping (table , key , indices = indices )
346347 return pa_subtable
347348
348349
0 commit comments