diff --git a/images/test/Dockerfile b/images/test/Dockerfile
index 55f1fc5415..3ca81f9e01 100644
--- a/images/test/Dockerfile
+++ b/images/test/Dockerfile
@@ -10,6 +10,10 @@ RUN pip install --upgrade pip && \
 COPY pkg /src
 COPY images/test/run.sh /src/run.sh
 
+COPY pkg/cortex/serve/log_config.yaml /src/cortex/serve/log_config.yaml
+ENV CORTEX_LOG_LEVEL DEBUG
+ENV CORTEX_LOG_CONFIG_FILE /src/cortex/serve/log_config.yaml
+
 RUN pip install --no-deps /src/cortex/serve/ && \
     rm -rf /root/.cache/pip*
 
diff --git a/images/test/run.sh b/images/test/run.sh
index 699ba49ab8..041aa905da 100644
--- a/images/test/run.sh
+++ b/images/test/run.sh
@@ -18,6 +18,12 @@
 err=0
 trap 'err=1' ERR
 
+function substitute_env_vars() {
+    file_to_run_substitution=$1
+    python -c "from cortex_internal.lib import util; import os; util.expand_environment_vars_on_file('$file_to_run_substitution')"
+}
+
+substitute_env_vars $CORTEX_LOG_CONFIG_FILE
 pytest lib/test
 
 test $err = 0
diff --git a/pkg/cortex/serve/cortex_internal.requirements.txt b/pkg/cortex/serve/cortex_internal.requirements.txt
index eb6b745471..a7232acde2 100644
--- a/pkg/cortex/serve/cortex_internal.requirements.txt
+++ b/pkg/cortex/serve/cortex_internal.requirements.txt
@@ -1,3 +1,4 @@
+grpcio==1.32.0
 boto3==1.14.53
 google-cloud-storage==1.32.0
 datadog==0.39.0
diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
index bde2edd637..29e0a040bc 100644
--- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py
+++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import threading as td
 import time
 import traceback
@@ -21,26 +22,32 @@
 from starlette.responses import Response
 
-from ..exceptions import UserRuntimeException
-from ..log import logger
+from cortex_internal.lib.exceptions import UserRuntimeException
+from cortex_internal.lib.log import logger
 
 
 class DynamicBatcher:
-    def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval: int):
+    def __init__(
+        self,
+        predictor_impl: Callable,
+        max_batch_size: int,
+        batch_interval: int,
+        test_mode: bool = False,
+    ):
         self.predictor_impl = predictor_impl
 
         self.batch_max_size = max_batch_size
         self.batch_interval = batch_interval  # measured in seconds
+        self.test_mode = test_mode  # only for unit testing
+        self._test_batch_lengths = []  # only when unit testing
 
-        # waiter prevents new threads from modifying the input batch while a batch prediction is in progress
-        self.waiter = td.Event()
-        self.waiter.set()
-
-        self.barrier = td.Barrier(self.batch_max_size + 1, action=self.waiter.clear)
+        self.barrier = td.Barrier(self.batch_max_size + 1)
 
         self.samples = {}
         self.predictions = {}
-        td.Thread(target=self._batch_engine).start()
+        td.Thread(target=self._batch_engine, daemon=True).start()
+
+        self.sample_id_generator = itertools.count()
 
     def _batch_engine(self):
         while True:
@@ -54,10 +61,10 @@ def _batch_engine(self):
                 pass
 
             self.predictions = {}
-
+            sample_ids = self._get_sample_ids(self.batch_max_size)
             try:
                 if self.samples:
-                    batch = self._make_batch(self.samples)
+                    batch = self._make_batch(sample_ids)
                     predictions = self.predictor_impl.predict(**batch)
 
                     if not isinstance(predictions, list):
@@ -65,32 +72,37 @@ def _batch_engine(self):
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
-                self.predictions = dict(zip(self.samples.keys(), predictions))
+                if self.test_mode:
+                    self._test_batch_lengths.append(len(predictions))
+
+                self.predictions = dict(zip(sample_ids, predictions))
             except Exception as e:
-                self.predictions = {thread_id: e for thread_id in self.samples}
+                self.predictions = {sample_id: e for sample_id in sample_ids}
                 logger.error(traceback.format_exc())
             finally:
-                self.samples = {}
+                for sample_id in sample_ids:
+                    del self.samples[sample_id]
                 self.barrier.reset()
-                self.waiter.set()
 
-    @staticmethod
-    def _make_batch(samples: Dict[int, Dict[str, Any]]) -> Dict[str, List[Any]]:
+    def _get_sample_ids(self, max_number: int) -> List[int]:
+        if len(self.samples) <= max_number:
+            return list(self.samples.keys())
+        return sorted(self.samples)[:max_number]
+
+    def _make_batch(self, sample_ids: List[int]) -> Dict[str, List[Any]]:
         batched_samples = defaultdict(list)
 
-        for thread_id in samples:
-            for key, sample in samples[thread_id].items():
+        for sample_id in sample_ids:
+            for key, sample in self.samples[sample_id].items():
                 batched_samples[key].append(sample)
 
         return dict(batched_samples)
 
-    def _enqueue_request(self, **kwargs):
+    def _enqueue_request(self, sample_id: int, **kwargs):
         """
         Enqueue sample for batch inference. This is a blocking method.
         """
-        thread_id = td.get_ident()
-        self.waiter.wait()
-        self.samples[thread_id] = kwargs
+        self.samples[sample_id] = kwargs
         try:
             self.barrier.wait()
         except td.BrokenBarrierError:
@@ -101,20 +113,20 @@ def predict(self, **kwargs):
         Queues a request to be batched with other incoming request, waits for the response
         and returns the prediction result. This is a blocking method.
         """
-        self._enqueue_request(**kwargs)
-        prediction = self._get_prediction()
+        sample_id = next(self.sample_id_generator)
+        self._enqueue_request(sample_id, **kwargs)
+        prediction = self._get_prediction(sample_id)
         return prediction
 
-    def _get_prediction(self) -> Any:
+    def _get_prediction(self, sample_id: int) -> Any:
         """
         Return the prediction. This is a blocking method.
         """
-        thread_id = td.get_ident()
-        while thread_id not in self.predictions:
+        while sample_id not in self.predictions:
             time.sleep(0.001)
 
-        prediction = self.predictions[thread_id]
-        del self.predictions[thread_id]
+        prediction = self.predictions[sample_id]
+        del self.predictions[sample_id]
 
         if isinstance(prediction, Exception):
             return Response(
diff --git a/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
new file mode 100644
index 0000000000..1cb3f9574a
--- /dev/null
+++ b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
@@ -0,0 +1,112 @@
+# Copyright 2021 Cortex Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading as td
+import itertools
+import time
+
+import cortex_internal.lib.api.batching as batching
+
+
+class Predictor:
+    def predict(self, payload):
+        time.sleep(0.2)
+        return payload
+
+
+def test_dynamic_batching_while_hitting_max_batch_size():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=0.1, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(128):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(60)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to the max batch size
+    assert len(set(batch_lengths)) == 1
+    assert max_batch_size in batch_lengths
+
+
+def test_dynamic_batching_while_hitting_max_interval():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=1.0, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(2):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(30)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to the number of running threads
+    assert len(set(batch_lengths)) == 1
+    assert len(running_threads) in batch_lengths