Merged
3 changes: 3 additions & 0 deletions .gitignore
@@ -75,3 +75,6 @@ coverage.xml
*,cover
.hypothesis/
prof/

# Mypy
.mypy_cache/
22 changes: 19 additions & 3 deletions autosklearn/automl.py
@@ -64,6 +64,7 @@
CoalescenseChoice
)
from autosklearn.pipeline.components.data_preprocessing.rescaling import RescalingChoice
from autosklearn.util.single_thread_client import SingleThreadedClient


def _model_predict(model, X, batch_size, logger, task):
@@ -222,6 +223,16 @@ def __init__(self,
# The ensemble performance history through time
self.ensemble_performance_history = []

# Single-core, local runs should use fork to avoid the
# if __name__ == '__main__' requirement that spawn would
# impose on the examples. Multi-process runs, however,
# require spawn to reduce the possibility of a deadlock.
self._multiprocessing_context = 'spawn'
if self._n_jobs == 1 and self._dask_client is None:
self._multiprocessing_context = 'fork'
self._dask_client = SingleThreadedClient()

if not isinstance(self._time_for_task, int):
raise ValueError("time_left_for_this_task not of type integer, "
"but %s" % str(type(self._time_for_task)))
@@ -241,7 +252,7 @@ def _create_dask_client(self):
self._dask_client = dask.distributed.Client(
dask.distributed.LocalCluster(
n_workers=self._n_jobs,
processes=True,
processes=True if self._n_jobs != 1 else False,
threads_per_worker=1,
# We use the temporal directory to save the
# dask workers, because deleting workers
@@ -288,7 +299,8 @@ def _get_logger(self, name):
# under the above logging configuration setting
# We need to specify the logger_name so that received records
# are treated under the logger_name ROOT logger setting
context = multiprocessing.get_context('spawn')
context = multiprocessing.get_context(
self._multiprocessing_context)
self.stop_logging_server = context.Event()
port = context.Value('l') # be safe by using a long
port.value = -1
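For readers less familiar with the start methods involved, the fork/spawn distinction relied on above is standard Python multiprocessing behaviour: a spawned worker re-imports the __main__ module from scratch, so scripts must protect their entry point with an if __name__ == '__main__' guard, whereas a forked worker inherits the parent's memory and needs no guard. A minimal, self-contained sketch (not part of this PR) illustrating the guarded spawn case:

import multiprocessing


def square(x):
    return x * x


if __name__ == "__main__":
    # Under 'spawn' each worker re-imports this module, so everything below
    # the guard runs only in the parent process; 'fork' would copy the parent
    # instead, and the guard would not be strictly necessary.
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        print(pool.map(square, [1, 2, 3]))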
@@ -389,6 +401,7 @@ def _do_dummy_prediction(self, datamanager, num_run):
abort_on_first_run_crash=False,
cost_for_crash=get_cost_of_crash(self._metric),
port=self._logger_port,
pynisher_context=self._multiprocessing_context,
**self._resampling_strategy_arguments)

status, cost, runtime, additional_info = ta.run(num_run, cutoff=self._time_for_task)
@@ -553,6 +566,7 @@ def fit(
self._logger.debug(' resampling_strategy_arguments: %s',
str(self._resampling_strategy_arguments))
self._logger.debug(' n_jobs: %s', str(self._n_jobs))
self._logger.debug(' multiprocessing_context: %s', str(self._multiprocessing_context))
self._logger.debug(' dask_client: %s', str(self._dask_client))
self._logger.debug(' precision: %s', str(self.precision))
self._logger.debug(' disable_evaluator_output: %s', str(self._disable_evaluator_output))
@@ -662,6 +676,7 @@ def fit(
ensemble_memory_limit=self._memory_limit,
random_state=self._seed,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)

self._stopwatch.stop_task(ensemble_task_name)
@@ -737,6 +752,7 @@ def fit(
smac_scenario_args=self._smac_scenario_args,
scoring_functions=self._scoring_functions,
port=self._logger_port,
pynisher_context=self._multiprocessing_context,
ensemble_callback=proc_ensemble,
)

@@ -1015,10 +1031,10 @@ def fit_ensemble(self, y, task=None, precision=32,
ensemble_memory_limit=self._memory_limit,
random_state=self._seed,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)
manager.build_ensemble(self._dask_client)
future = manager.futures.pop()
dask.distributed.wait([future]) # wait for the ensemble process to finish
result = future.result()
if result is None:
raise ValueError("Error building the ensemble - please check the log file and command "
9 changes: 6 additions & 3 deletions autosklearn/ensemble_builder.py
@@ -58,6 +58,7 @@ def __init__(
ensemble_memory_limit: Optional[int],
random_state: int,
logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
pynisher_context: str = 'fork',
):
""" SMAC callback to handle ensemble building

@@ -105,6 +106,8 @@ def __init__(
read at most n new prediction files in each iteration
logger_port: int
port that receives logging records
pynisher_context: str
The multiprocessing context for pynisher. One of spawn/fork/forkserver.

Returns
-------
@@ -128,6 +131,7 @@ def __init__(
self.ensemble_memory_limit = ensemble_memory_limit
self.random_state = random_state
self.logger_port = logger_port
self.pynisher_context = pynisher_context

# Store something similar to SMAC's runhistory
self.history = []
@@ -155,7 +159,6 @@ def __call__(
def build_ensemble(
self,
dask_client: dask.distributed.Client,
pynisher_context: str = 'spawn',
unit_test: bool = False
) -> None:

@@ -229,7 +232,7 @@ def build_ensemble(
iteration=self.iteration,
return_predictions=False,
priority=100,
pynisher_context=pynisher_context,
pynisher_context=self.pynisher_context,
logger_port=self.logger_port,
unit_test=unit_test,
))
@@ -573,7 +576,7 @@ def run(
end_at: Optional[float] = None,
time_buffer=5,
return_predictions: bool = False,
pynisher_context: str = 'spawn', # only change for unit testing!
pynisher_context: str = 'spawn',
):

if time_left is None and end_at is None:
4 changes: 4 additions & 0 deletions autosklearn/smbo.py
@@ -224,6 +224,7 @@ def __init__(self, config_space, dataset_name,
smac_scenario_args=None,
get_smac_object_callback=None,
scoring_functions=None,
pynisher_context='spawn',
ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
):
super(AutoMLSMBO, self).__init__()
@@ -269,6 +270,8 @@ def __init__(self, config_space, dataset_name,
self.get_smac_object_callback = get_smac_object_callback
self.scoring_functions = scoring_functions

self.pynisher_context = pynisher_context

self.ensemble_callback = ensemble_callback

dataset_name_ = "" if dataset_name is None else dataset_name
@@ -448,6 +451,7 @@ def run_smbo(self):
disable_file_output=self.disable_file_output,
scoring_functions=self.scoring_functions,
port=self.port,
pynisher_context=self.pynisher_context,
**self.resampling_strategy_args
)
ta = ExecuteTaFuncWithQueue
86 changes: 86 additions & 0 deletions autosklearn/util/single_thread_client.py
@@ -0,0 +1,86 @@
import typing
from pathlib import Path

import dask.distributed


class DummyFuture(dask.distributed.Future):
    """
    A class that mimics a distributed Future, the outcome of
    performing submit on a distributed client.
    """
    def __init__(self, result: typing.Any) -> None:
        self._result = result  # type: typing.Any

    def result(self, timeout: typing.Optional[int] = None) -> typing.Any:
        return self._result

    def cancel(self) -> None:
        pass

    def done(self) -> bool:
        return True

    def __repr__(self) -> str:
        return "DummyFuture: {}".format(self._result)

    def __del__(self) -> None:
        pass


class SingleThreadedClient(dask.distributed.Client):
    """
    A class to mock the distributed Client class, in case
    Auto-Sklearn is meant to run in the current thread.
    """
    def __init__(self) -> None:

        # Raise a NotImplementedError for any Client method that is not
        # explicitly re-implemented below
        implemented_methods = ['submit', 'close', 'shutdown', 'write_scheduler_file',
                               '_get_scheduler_info', 'nthreads']
        method_list = [func for func in dir(dask.distributed.Client) if callable(
            getattr(dask.distributed.Client, func)) and not func.startswith('__')]
        for method in method_list:
            if method in implemented_methods:
                continue
            setattr(self, method, self._unsupported_method)
        pass

    def _unsupported_method(self) -> None:
        raise NotImplementedError()

    def submit(
        self,
        func: typing.Callable,
        *args: typing.List,
        priority: int = 0,
        **kwargs: typing.Dict,
    ) -> typing.Any:
        return DummyFuture(func(*args, **kwargs))

    def close(self) -> None:
        pass

    def shutdown(self) -> None:
        pass

    def write_scheduler_file(self, scheduler_file: str) -> None:
        Path(scheduler_file).touch()
        return

    def _get_scheduler_info(self) -> typing.Dict:
        return {
            'workers': ['127.0.0.1'],
            'type': 'Scheduler',
        }

    def nthreads(self) -> typing.Dict:
        return {
            '127.0.0.1': 1,
        }

    def __repr__(self) -> str:
        return 'SingleThreadedClient()'

    def __del__(self) -> None:
        pass
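As a minimal usage sketch (not part of the PR), the client evaluates submitted work synchronously and returns an already-completed DummyFuture:

from autosklearn.util.single_thread_client import SingleThreadedClient

client = SingleThreadedClient()
future = client.submit(pow, 2, 10)  # runs pow(2, 10) immediately in the current thread
assert future.done()
print(future.result())  # 1024
client.close()
# Calling any Client method that is not re-implemented above raises NotImplementedError.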
A contributor left the following review comment on this file:
Can we automatically check, using inspection, which other methods dask implements, and automatically add these to the single-threaded client so they raise a NotImplementedError and we don't run into any issues in the future?
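As a rough illustration of what the reviewer is suggesting, a sketch along these lines (hypothetical, not part of this PR; the class and helper names are made up) could discover the real Client's methods via inspection and stub out everything that is not re-implemented:

import inspect

import dask.distributed


def _make_unsupported(name):
    # Return a stub that raises NotImplementedError, remembering which method it replaces
    def _raise(*args, **kwargs):
        raise NotImplementedError("SingleThreadedClient does not implement {}".format(name))
    return _raise


class InspectedSingleThreadedClient(dask.distributed.Client):
    _implemented = {'submit', 'close', 'shutdown', 'write_scheduler_file',
                    '_get_scheduler_info', 'nthreads'}

    def __init__(self) -> None:
        # Walk every callable attribute of the real Client through inspection
        # and replace anything not explicitly re-implemented with a stub.
        for name, member in inspect.getmembers(dask.distributed.Client, callable):
            if name.startswith('__') or name in self._implemented:
                continue
            setattr(self, name, _make_unsupported(name))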

65 changes: 32 additions & 33 deletions examples/20_basic/example_classification.py
@@ -13,36 +13,35 @@
import autosklearn.classification


if __name__ == "__main__":
    ############################################################################
    # Data Loading
    # ============

    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = \
        sklearn.model_selection.train_test_split(X, y, random_state=1)

    ############################################################################
    # Build and fit a classifier
    # ==========================

    automl = autosklearn.classification.AutoSklearnClassifier(
        time_left_for_this_task=120,
        per_run_time_limit=30,
        tmp_folder='/tmp/autosklearn_classification_example_tmp',
        output_folder='/tmp/autosklearn_classification_example_out',
    )
    automl.fit(X_train, y_train, dataset_name='breast_cancer')

    ############################################################################
    # Print the final ensemble constructed by auto-sklearn
    # ====================================================

    print(automl.show_models())

    ###########################################################################
    # Get the Score of the final ensemble
    # ===================================

    predictions = automl.predict(X_test)
    print("Accuracy score:", sklearn.metrics.accuracy_score(y_test, predictions))
############################################################################
# Data Loading
# ============

X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = \
    sklearn.model_selection.train_test_split(X, y, random_state=1)

############################################################################
# Build and fit a classifier
# ==========================

automl = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=120,
    per_run_time_limit=30,
    tmp_folder='/tmp/autosklearn_classification_example_tmp',
    output_folder='/tmp/autosklearn_classification_example_out',
)
automl.fit(X_train, y_train, dataset_name='breast_cancer')

############################################################################
# Print the final ensemble constructed by auto-sklearn
# ====================================================

print(automl.show_models())

###########################################################################
# Get the Score of the final ensemble
# ===================================

predictions = automl.predict(X_test)
print("Accuracy score:", sklearn.metrics.accuracy_score(y_test, predictions))