
Commit ae8b940
Fast table queries with interpolation search (#2122)
* add interpolation search
* update dataset and formatting
* update test_formatting
* test interpolation search
* docstrings
* add benchmark
* update benchmarks
* add indexed table test
1 parent 76c2a61 commit ae8b940
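
The interpolation search implementation itself is not in the files shown here; presumably it lands in src/datasets/table.py alongside the Table.fast_slice method these diffs call. The idea: a pyarrow table stores its rows in many chunks, and slicing at row i first requires finding the chunk that contains i. Given the sorted array of cumulative chunk offsets, interpolation search probes proportionally to where i falls in the value range instead of bisecting, which takes very few probes (expected O(log log n)) when chunks are evenly sized. The sketch below is illustrative only; interpolation_search and offsets are hypothetical names, not the commit's actual code:

    import numpy as np

    def interpolation_search(offsets: np.ndarray, i: int) -> int:
        """Return j such that offsets[j] <= i < offsets[j + 1].

        offsets[j] is the number of rows before chunk j; the last entry is
        the total number of rows.
        """
        low, high = 0, len(offsets) - 2
        while low <= high:
            # probe proportionally to where i falls between the bracketing
            # offset values, instead of bisecting like binary search
            span = offsets[high + 1] - offsets[low]
            guess = low + int((high + 1 - low) * (i - offsets[low]) / span)
            guess = min(max(guess, low), high)
            if offsets[guess] <= i < offsets[guess + 1]:
                return guess
            elif i < offsets[guess]:
                high = guess - 1
            else:
                low = guess + 1
        raise IndexError(f"row {i} is out of range")

    offsets = np.array([0, 10_000, 20_000, 30_000])  # three chunks of 10,000 rows
    assert interpolation_search(offsets, 25_000) == 2  # row 25,000 lives in chunk 2

This is exactly the access pattern the new 100B-row benchmark below stresses: with ten million equally sized chunks, a linear scan over chunks for every row lookup would dominate the query time.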

10 files changed, +327 -106 lines changed

.github/workflows/benchmarks.yaml

Lines changed: 4 additions & 4 deletions

@@ -16,8 +16,8 @@ jobs:
       pip install setuptools wheel
       pip install -e .[benchmarks]
 
-      # pyarrow==0.17.1
-      pip install pyarrow==0.17.1
+      # pyarrow==1.0.0
+      pip install pyarrow==1.0.0
 
       dvc repro --force
@@ -26,7 +26,7 @@ jobs:
 
       python ./benchmarks/format.py report.json report.md
 
-      echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==0.17.1\n" > final_report.md
+      echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==1.0.0\n" > final_report.md
       cat report.md >> final_report.md
 
       # pyarrow
@@ -39,7 +39,7 @@ jobs:
 
       python ./benchmarks/format.py report.json report.md
 
-      echo "\nPyArrow==1.0\n" >> final_report.md
+      echo "\nPyArrow==latest\n" >> final_report.md
       cat report.md >> final_report.md
       echo "\n</details>" >> final_report.md
benchmarks/benchmark_getitem_100B.py (new file)

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
import json
import os
from dataclasses import dataclass

import numpy as np
import pyarrow as pa
from utils import get_duration

import datasets


SPEED_TEST_N_EXAMPLES = 100_000_000_000
SPEED_TEST_CHUNK_SIZE = 10_000

RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
RESULTS_FILE_PATH = os.path.join(RESULTS_BASEPATH, "results", RESULTS_FILENAME.replace(".py", ".json"))


def generate_100B_dataset(num_examples: int, chunk_size: int) -> datasets.Dataset:
    table = pa.Table.from_pydict({"col": [0] * chunk_size})
    table = pa.concat_tables([table] * (num_examples // chunk_size))
    return datasets.Dataset(table, fingerprint="table_100B")


@dataclass
class RandIter:
    low: int
    high: int
    size: int
    seed: int

    def __post_init__(self):
        rng = np.random.default_rng(self.seed)
        self._sampled_values = rng.integers(low=self.low, high=self.high, size=self.size).tolist()

    def __iter__(self):
        return iter(self._sampled_values)

    def __len__(self):
        return self.size


@get_duration
def get_first_row(dataset: datasets.Dataset):
    _ = dataset[0]


@get_duration
def get_last_row(dataset: datasets.Dataset):
    _ = dataset[-1]


@get_duration
def get_batch_of_1024_rows(dataset: datasets.Dataset):
    _ = dataset[range(len(dataset) // 2, len(dataset) // 2 + 1024)]


@get_duration
def get_batch_of_1024_random_rows(dataset: datasets.Dataset):
    _ = dataset[RandIter(0, len(dataset), 1024, seed=42)]


def benchmark_table_100B():
    times = {"num examples": SPEED_TEST_N_EXAMPLES}
    functions = (get_first_row, get_last_row, get_batch_of_1024_rows, get_batch_of_1024_random_rows)
    print("generating dataset")
    dataset = generate_100B_dataset(num_examples=SPEED_TEST_N_EXAMPLES, chunk_size=SPEED_TEST_CHUNK_SIZE)
    print("Functions")
    for func in functions:
        print(func.__name__)
        times[func.__name__] = func(dataset)

    with open(RESULTS_FILE_PATH, "wb") as f:
        f.write(json.dumps(times).encode("utf-8"))


if __name__ == "__main__":  # useful to run the profiler
    benchmark_table_100B()
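
One property of generate_100B_dataset worth spelling out (an observation, not part of the diff): pa.concat_tables stacks references to the input chunks without copying their buffers, so the 100-billion-row table costs only one 10,000-row chunk of actual memory, referenced ten million times. A smaller-scale sketch of the same effect:

    import pyarrow as pa

    chunk = pa.Table.from_pydict({"col": [0] * 10_000})
    table = pa.concat_tables([chunk] * 1_000)  # ten million rows, zero copies

    print(table.num_rows)                  # 10000000
    print(table.column("col").num_chunks)  # 1000 chunks sharing one buffer

This layout is also what makes chunk lookup the bottleneck that fast_slice addresses: every row access must first locate one chunk among millions.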
benchmarks/results/benchmark_getitem_100B.json (new file)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{"num examples": 100000000000, "get_first_row": 0.00019991099999927542, "get_last_row": 5.4411000000698095e-05, "get_batch_of_1024_rows": 0.0004897069999998394, "get_batch_of_1024_random_rows": 0.01800621099999944}

dvc.yaml

Lines changed: 8 additions & 0 deletions

@@ -30,3 +30,11 @@ stages:
     metrics:
       - ./benchmarks/results/benchmark_iterating.json:
           cache: false
+
+  benchmark_getitem_100B:
+    cmd: python ./benchmarks/benchmark_getitem_100B.py
+    deps:
+      - ./benchmarks/benchmark_getitem_100B.py
+    metrics:
+      - ./benchmarks/results/benchmark_getitem_100B.json:
+          cache: false

src/datasets/arrow_dataset.py

Lines changed: 20 additions & 10 deletions

@@ -191,6 +191,18 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def _check_table(table) -> Table:
+    """We check the table type to make sure it's an instance of :class:`datasets.table.Table`"""
+    if isinstance(table, pa.Table):
+        # for a pyarrow table, we can just consider it as an in-memory table
+        # this is here for backward compatibility
+        return InMemoryTable(table)
+    elif isinstance(table, Table):
+        return table
+    else:
+        raise TypeError(f"Expected a pyarrow.Table or a datasets.table.Table object, but got {table}.")
+
+
 class Dataset(DatasetInfoMixin, IndexableMixin):
     """A Dataset backed by an Arrow table."""
 
@@ -206,8 +218,8 @@ def __init__(
         DatasetInfoMixin.__init__(self, info=info, split=split)
         IndexableMixin.__init__(self)
 
-        self._data: Table = arrow_table
-        self._indices: Optional[Table] = indices_table
+        self._data: Table = _check_table(arrow_table)
+        self._indices: Optional[Table] = _check_table(indices_table) if indices_table is not None else None
 
         self._format_type: Optional[str] = None
         self._format_kwargs: dict = {}
@@ -1157,9 +1169,7 @@ def _getitem(
         """
         format_kwargs = format_kwargs if format_kwargs is not None else {}
         formatter = get_formatter(format_type, **format_kwargs)
-        pa_subtable = query_table(
-            self._data, key, indices=self._indices.column(0) if self._indices is not None else None
-        )
+        pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
         formatted_output = format_table(
             pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
         )
@@ -1870,7 +1880,7 @@ def select(
         if self._indices is not None:
             indices_array = self._indices.column(0).take(indices_array)
 
-        indices_table = InMemoryTable.from_arrays([indices_array], names=["indices"])
+        indices_table = pa.Table.from_arrays([indices_array], names=["indices"])
 
         with writer:
             try:
@@ -2427,15 +2437,15 @@ def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Un
             return query_table(
                 table=self._data,
                 key=slice(0, len(self)),
-                indices=self._indices.column(0) if self._indices is not None else None,
+                indices=self._indices if self._indices is not None else None,
             ).to_pydict()
         else:
            batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
            return (
                query_table(
                    table=self._data,
                    key=slice(offset, offset + batch_size),
-                    indices=self._indices.column(0) if self._indices is not None else None,
+                    indices=self._indices if self._indices is not None else None,
                ).to_pydict()
                for offset in range(0, len(self), batch_size)
            )
@@ -2458,15 +2468,15 @@ def to_pandas(
             return query_table(
                 table=self._data,
                 key=slice(0, len(self)),
-                indices=self._indices.column(0) if self._indices is not None else None,
+                indices=self._indices if self._indices is not None else None,
             ).to_pandas()
         else:
            batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
            return (
                query_table(
                    table=self._data,
                    key=slice(offset, offset + batch_size),
-                    indices=self._indices.column(0) if self._indices is not None else None,
+                    indices=self._indices if self._indices is not None else None,
                ).to_pandas()
                for offset in range(0, len(self), batch_size)
            )
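
_check_table keeps the Dataset constructor backward compatible: a raw pyarrow.Table is wrapped into an InMemoryTable, while datasets.table.Table instances pass through untouched. A minimal sketch of what this allows, assuming the constructor signature used by the new benchmark above:

    import pyarrow as pa
    import datasets

    pa_table = pa.Table.from_pydict({"col": [1, 2, 3]})
    ds = datasets.Dataset(pa_table)  # accepted: _check_table wraps it as an InMemoryTable
    print(ds[0])                     # {'col': 1}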

src/datasets/formatting/formatting.py

Lines changed: 26 additions & 25 deletions

@@ -42,51 +42,55 @@ def _raise_bad_key_type(key: Any):
 
 
 def _query_table_with_indices_mapping(
-    pa_table: pa.Table, key: Union[int, slice, range, str, Iterable], indices: pa.lib.UInt64Array
+    table: Table, key: Union[int, slice, range, str, Iterable], indices: Table
 ) -> pa.Table:
     """
     Query a pyarrow Table to extract the subtable that corresponds to the given key.
     The :obj:`indices` parameter corresponds to the indices mapping in case we want to take into
     account a shuffling or an indices selection for example.
+    The indices table must contain one column named "indices" of type uint64.
     """
     if isinstance(key, int):
-        return _query_table(pa_table, indices[key].as_py())
+        key = indices.fast_slice(key % indices.num_rows, 1).column(0)[0].as_py()
+        return _query_table(table, key)
     if isinstance(key, slice):
-        key = range(*key.indices(pa_table.num_rows))
+        key = range(*key.indices(table.num_rows))
     if isinstance(key, range):
-        if _is_range_contiguous(key):
-            return _query_table(pa_table, [i.as_py() for i in indices.slice(key.start, key.stop - key.start)])
+        if _is_range_contiguous(key) and key.start >= 0:
+            return _query_table(
+                table, [i.as_py() for i in indices.fast_slice(key.start, key.stop - key.start).column(0)]
+            )
         else:
             pass  # treat as an iterable
     if isinstance(key, str):
-        pa_table = _query_table(pa_table, key)
-        return _query_table(pa_table, indices.to_pylist())
+        table = table.drop([column for column in table.column_names if column != key])
+        return _query_table(table, indices.column(0).to_pylist())
     if isinstance(key, Iterable):
-        return _query_table(pa_table, [indices[i].as_py() for i in key])
+        return _query_table(table, [indices.fast_slice(i, 1).column(0)[0].as_py() for i in key])
 
     _raise_bad_key_type(key)
 
 
-def _query_table(pa_table: pa.Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
+def _query_table(table: Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
     """
     Query a pyarrow Table to extract the subtable that corresponds to the given key.
     """
     if isinstance(key, int):
-        return pa_table.slice(key % pa_table.num_rows, 1)
+        return table.fast_slice(key % table.num_rows, 1)
     if isinstance(key, slice):
-        key = range(*key.indices(pa_table.num_rows))
+        key = range(*key.indices(table.num_rows))
     if isinstance(key, range):
         if _is_range_contiguous(key) and key.start >= 0:
-            return pa_table.slice(key.start, key.stop - key.start)
+            return table.fast_slice(key.start, key.stop - key.start)
         else:
             pass  # treat as an iterable
     if isinstance(key, str):
-        return pa_table.drop([column for column in pa_table.column_names if column != key])
+        return table.table.drop([column for column in table.column_names if column != key])
     if isinstance(key, Iterable):
         if len(key) == 0:
-            return pa_table.slice(0, 0)
+            return table.table.slice(0, 0)
         # don't use pyarrow.Table.take even for pyarrow >=1.0 (see https://issues.apache.org/jira/browse/ARROW-9773)
-        return pa.concat_tables(pa_table.slice(int(i) % pa_table.num_rows, 1) for i in key)
+        return pa.concat_tables(table.fast_slice(int(i) % table.num_rows, 1) for i in key)
 
     _raise_bad_key_type(key)
 
@@ -306,7 +310,7 @@ def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
 def query_table(
     table: Table,
     key: Union[int, slice, range, str, Iterable],
-    indices: Optional[pa.lib.UInt64Array] = None,
+    indices: Optional[Table] = None,
 ) -> pa.Table:
     """
     Query a Table to extract the subtable that corresponds to the given key.
@@ -319,30 +323,27 @@ def query_table(
         - a range(i, j, k): the subtable containing the rows that correspond to this range
         - a string c: the subtable containing all the rows but only the column c
         - an iterable l: the subtable that is the concatenation of all the i-th rows for all i in the iterable
-        indices (Optional ``pyarrow.lib.UInt64Array``): If not None, it is used to re-map the given key to the table rows.
+        indices (Optional ``datasets.table.Table``): If not None, it is used to re-map the given key to the table rows.
+            The indices table must contain one column named "indices" of type uint64.
             This is used in case of shuffling or rows selection.
 
 
     Returns:
         ``pyarrow.Table``: the result of the query on the input table
     """
-    if isinstance(table, Table):
-        pa_table = table.table
-    else:
-        pa_table = table
     # Check if key is valid
     if not isinstance(key, (int, slice, range, str, Iterable)):
         _raise_bad_key_type(key)
     if isinstance(key, str):
-        _check_valid_column_key(key, pa_table.column_names)
+        _check_valid_column_key(key, table.column_names)
     else:
-        size = len(indices) if indices is not None else pa_table.num_rows
+        size = indices.num_rows if indices is not None else table.num_rows
         _check_valid_index_key(key, size)
     # Query the main table
     if indices is None:
-        pa_subtable = _query_table(pa_table, key)
+        pa_subtable = _query_table(table, key)
     else:
-        pa_subtable = _query_table_with_indices_mapping(pa_table, key, indices=indices)
+        pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
     return pa_subtable
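
The visible contract change in this module: indices is no longer a bare pyarrow.lib.UInt64Array but a datasets.table.Table holding a single uint64 column named "indices", and every positional lookup now goes through fast_slice. A sketch of the new semantics, assuming InMemoryTable and query_table keep the import paths this diff suggests:

    import pyarrow as pa

    from datasets.formatting.formatting import query_table
    from datasets.table import InMemoryTable

    data = InMemoryTable(pa.Table.from_pydict({"col": ["a", "b", "c", "d"]}))
    # an indices mapping, e.g. after a shuffle: logical row 0 -> physical row 2
    indices = InMemoryTable(
        pa.Table.from_arrays([pa.array([2, 0, 3, 1], type=pa.uint64())], names=["indices"])
    )
    print(query_table(data, 0, indices=indices).to_pydict())  # {'col': ['c']}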

src/datasets/io/csv.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def _write(
             batch = query_table(
                 table=self.dataset._data,
                 key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices.column(0) if self.dataset._indices is not None else None,
+                indices=self.dataset._indices if self.dataset._indices is not None else None,
             )
             csv_str = batch.to_pandas().to_csv(
                 path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs
