From acb0f03e021c82ea2f5032abea1042d68dd1ab02 Mon Sep 17 00:00:00 2001
From: Charles Bensimon
Date: Sun, 28 Feb 2021 12:57:26 +0100
Subject: [PATCH 1/7] Prevent threads from being stuck in DynamicBatcher

Ensure the input batch can be safely fed with new samples at any time
Remove the waiter mechanism
Use a safer way to generate a thread ID
---
 .../serve/cortex_internal/lib/api/batching.py | 39 +++++++++----------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
index bde2edd637..f5c30dd6a9 100644
--- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py
+++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import threading as td
 import time
 import traceback
@@ -32,16 +33,14 @@ def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval
         self.batch_max_size = max_batch_size
         self.batch_interval = batch_interval  # measured in seconds
 
-        # waiter prevents new threads from modifying the input batch while a batch prediction is in progress
-        self.waiter = td.Event()
-        self.waiter.set()
-
-        self.barrier = td.Barrier(self.batch_max_size + 1, action=self.waiter.clear)
+        self.barrier = td.Barrier(self.batch_max_size + 1)
 
         self.samples = {}
         self.predictions = {}
         td.Thread(target=self._batch_engine).start()
 
+        self.thread_id_generator = itertools.count()
+
     def _batch_engine(self):
         while True:
             if len(self.predictions) > 0:
@@ -54,10 +53,11 @@ def _batch_engine(self):
                 pass
 
             self.predictions = {}
+            thread_ids = list(self.samples.keys())
 
             try:
                 if self.samples:
-                    batch = self._make_batch(self.samples)
+                    batch = self._make_batch(thread_ids)
 
                     predictions = self.predictor_impl.predict(**batch)
                     if not isinstance(predictions, list):
@@ -65,31 +65,28 @@ def _batch_engine(self):
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
-                    self.predictions = dict(zip(self.samples.keys(), predictions))
+                    self.predictions = dict(zip(thread_ids, predictions))
             except Exception as e:
-                self.predictions = {thread_id: e for thread_id in self.samples}
+                self.predictions = {thread_id: e for thread_id in thread_ids}
                 logger.error(traceback.format_exc())
             finally:
-                self.samples = {}
+                for thread_id in thread_ids:
+                    del self.samples[thread_id]
                 self.barrier.reset()
-                self.waiter.set()
 
-    @staticmethod
-    def _make_batch(samples: Dict[int, Dict[str, Any]]) -> Dict[str, List[Any]]:
+    def _make_batch(self, thread_ids: List[int]) -> Dict[str, List[Any]]:
         batched_samples = defaultdict(list)
-        for thread_id in samples:
-            for key, sample in samples[thread_id].items():
+        for thread_id in thread_ids:
+            for key, sample in self.samples[thread_id].items():
                 batched_samples[key].append(sample)
 
         return dict(batched_samples)
 
-    def _enqueue_request(self, **kwargs):
+    def _enqueue_request(self, thread_id: int, **kwargs):
         """
         Enqueue sample for batch inference. This is a blocking method.
         """
-        thread_id = td.get_ident()
-        self.waiter.wait()
         self.samples[thread_id] = kwargs
         try:
             self.barrier.wait()
@@ -101,15 +98,15 @@ def predict(self, **kwargs):
         Queues a request to be batched with other incoming requests, waits for the response
         and returns the prediction result. This is a blocking method.
""" - self._enqueue_request(**kwargs) - prediction = self._get_prediction() + thread_id = next(self.thread_id_generator) + self._enqueue_request(thread_id, **kwargs) + prediction = self._get_prediction(thread_id) return prediction - def _get_prediction(self) -> Any: + def _get_prediction(self, thread_id: int) -> Any: """ Return the prediction. This is a blocking method. """ - thread_id = td.get_ident() while thread_id not in self.predictions: time.sleep(0.001) From d2b18973cee6dccae4c50d87c693cab4eafce307 Mon Sep 17 00:00:00 2001 From: Robert Lucian Chiriac Date: Wed, 3 Mar 2021 22:27:06 +0200 Subject: [PATCH 2/7] Enforce max batch size --- pkg/cortex/serve/cortex_internal/lib/api/batching.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py index f5c30dd6a9..a910507e4d 100644 --- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py +++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py @@ -53,8 +53,7 @@ def _batch_engine(self): pass self.predictions = {} - thread_ids = list(self.samples.keys()) - + thread_ids = self._get_thread_ids(self.batch_max_size) try: if self.samples: batch = self._make_batch(thread_ids) @@ -74,6 +73,11 @@ def _batch_engine(self): del self.samples[thread_id] self.barrier.reset() + def _get_thread_ids(self, max_number: int) -> List[int]: + if len(self.samples) <= max_number: + return list(self.samples.keys()) + return sorted(self.samples)[:max_number] + def _make_batch(self, thread_ids: List[int]) -> Dict[str, List[Any]]: batched_samples = defaultdict(list) for thread_id in thread_ids: From f0df3470bd7008322e610d48001d113d91b18e9e Mon Sep 17 00:00:00 2001 From: Robert Lucian Chiriac Date: Wed, 3 Mar 2021 22:27:20 +0200 Subject: [PATCH 3/7] Don't use relative imports --- pkg/cortex/serve/cortex_internal/lib/api/batching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py index a910507e4d..15e5b19a2c 100644 --- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py +++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py @@ -22,8 +22,8 @@ from starlette.responses import Response -from ..exceptions import UserRuntimeException -from ..log import logger +from cortex_internal.lib.exceptions import UserRuntimeException +from cortex_internal.lib.log import logger class DynamicBatcher: From 4420a715bf579348a9c0c6bfb409b9ac5795ae5d Mon Sep 17 00:00:00 2001 From: Robert Lucian Chiriac Date: Wed, 3 Mar 2021 22:41:58 +0200 Subject: [PATCH 4/7] Add test file for dynamic batcher for dev (to be added to test suite) --- dev/tests/dynamic_batcher_test.py | 67 +++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 dev/tests/dynamic_batcher_test.py diff --git a/dev/tests/dynamic_batcher_test.py b/dev/tests/dynamic_batcher_test.py new file mode 100644 index 0000000000..a34e28ad26 --- /dev/null +++ b/dev/tests/dynamic_batcher_test.py @@ -0,0 +1,67 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Dependencies: cortex_internal, pystuck
+# While this is running, you can inspect the threads that get stuck upon exiting (by pressing CTRL-C) by running the pystuck CLI in a different window
+
+import time
+import itertools
+import threading as td
+import signal
+import sys
+
+import pystuck
+from cortex_internal.lib.api import batching
+
+pystuck.run_server()
+
+
+class Predictor:
+    def predict(self, payload):
+        print("received payload:", payload)
+        time.sleep(0.2)
+        return payload
+
+
+db = batching.DynamicBatcher(Predictor(), max_batch_size=32, batch_interval=0.1)
+counter = itertools.count(1)
+event = td.Event()
+global_list = []
+running_threads = []
+
+
+def submitter():
+    while not event.is_set():
+        global_list.append(db.predict(payload=next(counter)))
+        time.sleep(0.1)
+
+
+def signal_handler(sig, frame):
+    event.set()
+    print("ctrl-c has been pressed; exiting ...")
+    print(
+        f"len(global_list)*(len(global_list)+1)/2={int(len(global_list) * (len(global_list) + 1) / 2)} || sum(global_list)={sum(global_list)}"
+    )
+    for thread in running_threads:
+        print("joining on", thread.getName())
+        thread.join()
+    sys.exit(0)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+for i in range(128):
+    thread = td.Thread(target=submitter)
+    thread.start()
+    running_threads.append(thread)

From ea75ffa8764dd2a7a97f3a6b2c97a35770df0095 Mon Sep 17 00:00:00 2001
From: Robert Lucian Chiriac
Date: Thu, 4 Mar 2021 23:34:16 +0200
Subject: [PATCH 5/7] Rename thread_ids to sample_ids

---
 .../serve/cortex_internal/lib/api/batching.py | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
index 15e5b19a2c..9adef95730 100644
--- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py
+++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -39,7 +39,7 @@ def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval
         self.predictions = {}
         td.Thread(target=self._batch_engine).start()
 
-        self.thread_id_generator = itertools.count()
+        self.sample_id_generator = itertools.count()
 
     def _batch_engine(self):
         while True:
@@ -53,10 +53,10 @@ def _batch_engine(self):
                 pass
 
             self.predictions = {}
-            thread_ids = self._get_thread_ids(self.batch_max_size)
+            sample_ids = self._get_sample_ids(self.batch_max_size)
             try:
                 if self.samples:
-                    batch = self._make_batch(thread_ids)
+                    batch = self._make_batch(sample_ids)
 
                     predictions = self.predictor_impl.predict(**batch)
                     if not isinstance(predictions, list):
@@ -64,34 +64,34 @@ def _batch_engine(self):
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
-                    self.predictions = dict(zip(thread_ids, predictions))
+                    self.predictions = dict(zip(sample_ids, predictions))
             except Exception as e:
-                self.predictions = {thread_id: e for thread_id in thread_ids}
+                self.predictions = {sample_id: e for sample_id in sample_ids}
                 logger.error(traceback.format_exc())
             finally:
-                for thread_id in thread_ids:
-                    del self.samples[thread_id]
+                for sample_id in sample_ids:
+                    del self.samples[sample_id]
                 self.barrier.reset()
 
-    def _get_thread_ids(self, max_number: int) -> List[int]:
+    def _get_sample_ids(self, max_number: int) -> List[int]:
         if len(self.samples) <= max_number:
             return list(self.samples.keys())
         return sorted(self.samples)[:max_number]
 
-    def _make_batch(self, thread_ids: List[int]) -> Dict[str, List[Any]]:
+    def _make_batch(self, sample_ids: List[int]) -> Dict[str, List[Any]]:
         batched_samples = defaultdict(list)
-        for thread_id in thread_ids:
-            for key, sample in self.samples[thread_id].items():
+        for sample_id in sample_ids:
+            for key, sample in self.samples[sample_id].items():
                 batched_samples[key].append(sample)
 
         return dict(batched_samples)
 
-    def _enqueue_request(self, thread_id: int, **kwargs):
+    def _enqueue_request(self, sample_id: int, **kwargs):
         """
         Enqueue sample for batch inference. This is a blocking method.
         """
-        self.samples[thread_id] = kwargs
+        self.samples[sample_id] = kwargs
         try:
             self.barrier.wait()
         except td.BrokenBarrierError:
@@ -102,20 +102,20 @@ def predict(self, **kwargs):
         Queues a request to be batched with other incoming requests, waits for the response
         and returns the prediction result. This is a blocking method.
         """
-        thread_id = next(self.thread_id_generator)
-        self._enqueue_request(thread_id, **kwargs)
-        prediction = self._get_prediction(thread_id)
+        sample_id = next(self.sample_id_generator)
+        self._enqueue_request(sample_id, **kwargs)
+        prediction = self._get_prediction(sample_id)
         return prediction
 
-    def _get_prediction(self, thread_id: int) -> Any:
+    def _get_prediction(self, sample_id: int) -> Any:
         """
         Return the prediction. This is a blocking method.
         """
-        while thread_id not in self.predictions:
+        while sample_id not in self.predictions:
             time.sleep(0.001)
 
-        prediction = self.predictions[thread_id]
-        del self.predictions[thread_id]
+        prediction = self.predictions[sample_id]
+        del self.predictions[sample_id]
 
         if isinstance(prediction, Exception):
             return Response(

From d179a119cdb032266b687107a78464a2668b197f Mon Sep 17 00:00:00 2001
From: Robert Lucian Chiriac
Date: Fri, 5 Mar 2021 00:59:48 +0200
Subject: [PATCH 6/7] Convert manual dynamic_batcher_test to unit test

---
 dev/tests/dynamic_batcher_test.py                  |  67 -----------
 images/test/Dockerfile                             |   4 +
 images/test/run.sh                                 |   6 +
 .../serve/cortex_internal.requirements.txt         |   1 +
 .../serve/cortex_internal/lib/api/batching.py      |  15 ++-
 .../lib/test/dynamic_batching_test.py              | 112 ++++++++++++++++++
 6 files changed, 136 insertions(+), 69 deletions(-)
 delete mode 100644 dev/tests/dynamic_batcher_test.py
 create mode 100644 pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py

diff --git a/dev/tests/dynamic_batcher_test.py b/dev/tests/dynamic_batcher_test.py
deleted file mode 100644
index a34e28ad26..0000000000
--- a/dev/tests/dynamic_batcher_test.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2021 Cortex Labs, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Dependencies: cortex_internal, pystuck
-# While this is running, you can inspect the threads that get stuck upon exiting (by pressing CTRL-C) by running the pystuck CLI in a different window
-
-import time
-import itertools
-import threading as td
-import signal
-import sys
-
-import pystuck
-from cortex_internal.lib.api import batching
-
-pystuck.run_server()
-
-
-class Predictor:
-    def predict(self, payload):
-        print("received payload:", payload)
-        time.sleep(0.2)
-        return payload
-
-
-db = batching.DynamicBatcher(Predictor(), max_batch_size=32, batch_interval=0.1)
-counter = itertools.count(1)
-event = td.Event()
-global_list = []
-running_threads = []
-
-
-def submitter():
-    while not event.is_set():
-        global_list.append(db.predict(payload=next(counter)))
-        time.sleep(0.1)
-
-
-def signal_handler(sig, frame):
-    event.set()
-    print("ctrl-c has been pressed; exiting ...")
-    print(
-        f"len(global_list)*(len(global_list)+1)/2={int(len(global_list) * (len(global_list) + 1) / 2)} || sum(global_list)={sum(global_list)}"
-    )
-    for thread in running_threads:
-        print("joining on", thread.getName())
-        thread.join()
-    sys.exit(0)
-
-
-signal.signal(signal.SIGINT, signal_handler)
-
-for i in range(128):
-    thread = td.Thread(target=submitter)
-    thread.start()
-    running_threads.append(thread)
diff --git a/images/test/Dockerfile b/images/test/Dockerfile
index 55f1fc5415..3ca81f9e01 100644
--- a/images/test/Dockerfile
+++ b/images/test/Dockerfile
@@ -10,6 +10,10 @@ RUN pip install --upgrade pip && \
 COPY pkg /src
 COPY images/test/run.sh /src/run.sh
 
+COPY pkg/cortex/serve/log_config.yaml /src/cortex/serve/log_config.yaml
+ENV CORTEX_LOG_LEVEL DEBUG
+ENV CORTEX_LOG_CONFIG_FILE /src/cortex/serve/log_config.yaml
+
 RUN pip install --no-deps /src/cortex/serve/ && \
     rm -rf /root/.cache/pip*
 
diff --git a/images/test/run.sh b/images/test/run.sh
index 699ba49ab8..041aa905da 100644
--- a/images/test/run.sh
+++ b/images/test/run.sh
@@ -18,6 +18,12 @@ err=0
 
 trap 'err=1' ERR
 
+function substitute_env_vars() {
+    file_to_run_substitution=$1
+    python -c "from cortex_internal.lib import util; import os; util.expand_environment_vars_on_file('$file_to_run_substitution')"
+}
+
+substitute_env_vars $CORTEX_LOG_CONFIG_FILE
 pytest lib/test
 
 test $err = 0
diff --git a/pkg/cortex/serve/cortex_internal.requirements.txt b/pkg/cortex/serve/cortex_internal.requirements.txt
index eb6b745471..a7232acde2 100644
--- a/pkg/cortex/serve/cortex_internal.requirements.txt
+++ b/pkg/cortex/serve/cortex_internal.requirements.txt
@@ -1,3 +1,4 @@
+grpcio==1.32.0
 boto3==1.14.53
 google-cloud-storage==1.32.0
 datadog==0.39.0
diff --git a/pkg/cortex/serve/cortex_internal/lib/api/batching.py b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
index 9adef95730..29e0a040bc 100644
--- a/pkg/cortex/serve/cortex_internal/lib/api/batching.py
+++ b/pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -27,17 +27,25 @@
 
 
 class DynamicBatcher:
-    def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval: int):
+    def __init__(
+        self,
+        predictor_impl: Callable,
+        max_batch_size: int,
+        batch_interval: int,
+        test_mode: bool = False,
+    ):
         self.predictor_impl = predictor_impl
 
         self.batch_max_size = max_batch_size
         self.batch_interval = batch_interval  # measured in seconds
+        self.test_mode = test_mode  # only for unit testing
+        self._test_batch_lengths = []  # only when unit testing
 
         self.barrier = td.Barrier(self.batch_max_size + 1)
 
         self.samples = {}
         self.predictions = {}
-        td.Thread(target=self._batch_engine).start()
+        td.Thread(target=self._batch_engine, daemon=True).start()
 
         self.sample_id_generator = itertools.count()
 
@@ -64,6 +72,9 @@ def _batch_engine(self):
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
+                    if self.test_mode:
+                        self._test_batch_lengths.append(len(predictions))
+
                     self.predictions = dict(zip(sample_ids, predictions))
             except Exception as e:
                 self.predictions = {sample_id: e for sample_id in sample_ids}
diff --git a/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
new file mode 100644
index 0000000000..17dfcdd661
--- /dev/null
+++ b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
@@ -0,0 +1,112 @@
+# Copyright 2021 Cortex Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading as td
+import itertools
+import time
+
+import cortex_internal.lib.api.batching as batching
+
+
+class Predictor:
+    def predict(self, payload):
+        time.sleep(0.2)
+        return payload
+
+
+def test_dynamic_batching_while_hitting_max_batch_size():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=0.1, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(128):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(60)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to the max batch size
+    assert len(set(batch_lengths)) == 1
+    assert max_batch_size in batch_lengths
+
+
+def test_dynamic_batching_while_hitting_max_interval():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=1.0, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(2):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(30)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to number of running threads
+    assert len(set(batch_lengths)) == 1
+    assert len(running_threads) in batch_lengths

From d4327cfa7ec1487529952f7079e4af0e764345a6 Mon Sep 17 00:00:00 2001
From: Robert Lucian Chiriac
Date: Fri, 5 Mar 2021 01:04:17 +0200
Subject: [PATCH 7/7] Missing word in python comment

---
 .../serve/cortex_internal/lib/test/dynamic_batching_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
index 17dfcdd661..1cb3f9574a 100644
--- a/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
+++ b/pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
@@ -107,6 +107,6 @@ def submitter():
     batch_lengths = dynamic_batcher._test_batch_lengths
     batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
 
-    # verify that the batch size is always equal to number of running threads
+    # verify that the batch size is always equal to the number of running threads
     assert len(set(batch_lengths)) == 1
    assert len(running_threads) in batch_lengths
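
For reference, below is a minimal usage sketch of the DynamicBatcher API as it stands at the end of this series. Only the DynamicBatcher constructor and the blocking predict() call come from the patches above; the EchoPredictor class, the payload values, and the max_batch_size=8 / batch_interval=0.1 settings are hypothetical choices for illustration.

import threading as td
import time

from cortex_internal.lib.api.batching import DynamicBatcher


class EchoPredictor:
    # hypothetical predictor: with server-side batching, `payload` arrives as a
    # list holding the `payload=` kwarg of every request in the batch, and a
    # list of the same length must be returned (one prediction per sample)
    def predict(self, payload):
        time.sleep(0.1)  # stand-in for running the model on the whole batch
        return payload


batcher = DynamicBatcher(EchoPredictor(), max_batch_size=8, batch_interval=0.1)

# each caller blocks inside predict() until its own sample's prediction is
# ready; the background engine thread merges concurrent calls into one batch
results = []
threads = [
    td.Thread(target=lambda i=i: results.append(batcher.predict(payload=i)), daemon=True)
    for i in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results))  # -> [0, 1, 2, 3, 4, 5, 6, 7]

With eight concurrent submitters and max_batch_size=8, the barrier (sized max_batch_size + 1) trips as soon as all eight samples are enqueued, so the predictor receives one full batch; with fewer concurrent callers, the engine instead times out after batch_interval and processes a partial batch, which is the behavior the two unit tests above exercise.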