diff --git a/autosklearn/automl.py b/autosklearn/automl.py index 222f9de727..e0aae596e1 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -210,6 +210,7 @@ def __init__( scoring_functions=None, get_trials_callback=None, dataset_compression: Union[bool, Mapping[str, Any]] = True, + allow_string_features: bool = True, ): super(AutoML, self).__init__() self.configuration_space = None @@ -281,6 +282,7 @@ def __init__( self._dataset_compression = validate_dataset_compression_arg( dataset_compression, memory_limit=self._memory_limit ) + self.allow_string_features = allow_string_features self._datamanager = None self._dataset_name = None @@ -687,6 +689,7 @@ def fit( is_classification=is_classification, feat_type=feat_type, logger_port=self._logger_port, + allow_string_features=self.allow_string_features, ) self.InputValidator.fit(X_train=X, y_train=y, X_test=X_test, y_test=y_test) X, y = self.InputValidator.transform(X, y) diff --git a/autosklearn/data/feature_validator.py b/autosklearn/data/feature_validator.py index 85bb3a900c..07c108390c 100644 --- a/autosklearn/data/feature_validator.py +++ b/autosklearn/data/feature_validator.py @@ -42,9 +42,11 @@ def __init__( self, feat_type: Optional[List[str]] = None, logger: Optional[PickableLoggerAdapter] = None, + allow_string_features: bool = True, ) -> None: # If a dataframe was provided, we populate - # this attribute with a mapping from column to {numerical | categorical} + # this attribute with a mapping from column to + # {numerical | categorical | string} self.feat_type: Optional[Dict[Union[str, int], str]] = None if feat_type is not None: if isinstance(feat_type, dict): @@ -52,7 +54,7 @@ def __init__( elif not isinstance(feat_type, List): raise ValueError( "Auto-Sklearn expects a list of categorical/" - "numerical feature types, yet a" + "numerical/string feature types, yet a" " {} was provided".format(type(feat_type)) ) else: @@ -68,6 +70,7 @@ def __init__( self.logger = logger if logger is not None else logging.getLogger(__name__) self._is_fitted = False + self.allow_string_features = allow_string_features def fit( self, @@ -300,7 +303,14 @@ def get_feat_type_from_columns( elif X[column].dtype.name in ["category", "bool"]: feat_type[column] = "categorical" elif X[column].dtype.name == "string": - feat_type[column] = "string" + if self.allow_string_features: + feat_type[column] = "string" + else: + feat_type[column] = "categorical" + warnings.warn( + f"you disabled text encoding column {column} will be " + f"encoded as category" + ) # Move away from np.issubdtype as it causes # TypeError: data type not understood in certain pandas types elif not is_numeric_dtype(X[column]): @@ -311,7 +321,14 @@ def get_feat_type_from_columns( f"Please ensure that this setting is suitable for your task.", UserWarning, ) - feat_type[column] = "string" + if self.allow_string_features: + feat_type[column] = "string" + else: + feat_type[column] = "categorical" + warnings.warn( + f"you disabled text encoding column {column} will be" + f"encoded as category" + ) elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype( X[column].dtype ): diff --git a/autosklearn/data/validation.py b/autosklearn/data/validation.py index 89aaca85c0..324bab4895 100644 --- a/autosklearn/data/validation.py +++ b/autosklearn/data/validation.py @@ -80,6 +80,7 @@ def __init__( feat_type: Optional[List[str]] = None, is_classification: bool = False, logger_port: Optional[int] = None, + allow_string_features: bool = True, ) -> None: self.feat_type = feat_type self.is_classification = is_classification @@ -92,8 +93,11 @@ def __init__( else: self.logger = logging.getLogger("Validation") + self.allow_string_features = allow_string_features self.feature_validator = FeatureValidator( - feat_type=self.feat_type, logger=self.logger + feat_type=self.feat_type, + logger=self.logger, + allow_string_features=self.allow_string_features, ) self.target_validator = TargetValidator( is_classification=self.is_classification, logger=self.logger diff --git a/autosklearn/estimators.py b/autosklearn/estimators.py index 491309a7b8..eb98e54673 100644 --- a/autosklearn/estimators.py +++ b/autosklearn/estimators.py @@ -51,6 +51,7 @@ def __init__( load_models: bool = True, get_trials_callback=None, dataset_compression: Union[bool, Mapping[str, Any]] = True, + allow_string_features: bool = True, ): """ Parameters @@ -322,6 +323,10 @@ def __init__( accordingly. We guarantee that at least one occurrence of each label is included in the sampled set. + allow_string_features: bool = True + Whether autosklearn should process string features. By default the + textpreprocessing is enabled. + Attributes ---------- cv_results_ : dict of numpy (masked) ndarrays @@ -367,6 +372,7 @@ def __init__( self.load_models = load_models self.get_trials_callback = get_trials_callback self.dataset_compression = dataset_compression + self.allow_string_features = allow_string_features self.automl_ = None # type: Optional[AutoML] @@ -415,6 +421,7 @@ def build_automl(self): scoring_functions=self.scoring_functions, get_trials_callback=self.get_trials_callback, dataset_compression=self.dataset_compression, + allow_string_features=self.allow_string_features, ) return automl diff --git a/autosklearn/experimental/askl2.py b/autosklearn/experimental/askl2.py index 7068270a8e..65ef9b2def 100644 --- a/autosklearn/experimental/askl2.py +++ b/autosklearn/experimental/askl2.py @@ -206,6 +206,7 @@ def __init__( scoring_functions: Optional[List[Scorer]] = None, load_models: bool = True, dataset_compression: Union[bool, Mapping[str, Any]] = True, + allow_string_features: bool = True, ): """ @@ -363,6 +364,7 @@ def __init__( metric=metric, scoring_functions=scoring_functions, load_models=load_models, + allow_string_features=allow_string_features, ) def fit( diff --git a/doc/manual.rst b/doc/manual.rst index b1a4d9353a..cdfcd3cbe0 100644 --- a/doc/manual.rst +++ b/doc/manual.rst @@ -301,23 +301,29 @@ Other Supported formats for these training and testing pairs are: np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix and python lists. - If your data contains categorical values (in the features or targets), autosklearn will automatically encode your - data using a `sklearn.preprocessing.LabelEncoder `_ - for unidimensional data and a `sklearn.preprocessing.OrdinalEncoder `_ - for multidimensional data. - - Regarding the features, there are two methods to guide *auto-sklearn* to properly encode categorical columns: + Regarding the features, there are multiple things to consider: * Providing a X_train/X_test numpy array with the optional flag feat_type. For further details, you can check the Example :ref:`sphx_glr_examples_40_advanced_example_feature_types.py`. * You can provide a pandas DataFrame, with properly formatted columns. If a column has numerical - dtype, *auto-sklearn* will not encode it and it will be passed directly to scikit-learn. If the - column has a categorical/boolean class, it will be encoded. If the column is of any other type - (Object or Timeseries), an error will be raised. For further details on how to properly encode - your data, you can check the Pandas Example - `Working with categorical data `_). - If you are working with time series, it is recommended that you follow this approach + dtype, *auto-sklearn* will not encode it and it will be passed directly to scikit-learn. *auto-sklearn* + supports both categorical or string as column type. Please ensure that you are using the correct + dtype for your task. By default *auto-sklearn* treats object and string columns as strings and + encodes the data using `sklearn.feature_extraction.text.CountVectorizer `_ + * If your data contains categorical values (in the features or targets), ensure that you explicitly label them as categorical. + data labeled as categorical is encoded by using a `sklearn.preprocessing.LabelEncoder `_ + for unidimensional data and a `sklearn.preprodcessing.OrdinalEncoder `_ for multidimensional data. + * For further details on how to properly encode your data, you can check the Pandas Example + `Working with categorical data `_). If you are working with time series, it is recommended that you follow this approach `Working with time data `_. + * If you prefer not using the string option at all you can disable this option. In this case + objects, strings and categorical columns are encoded as categorical. + + .. code:: python + + import autosklearn.classification + automl = autosklearn.classification.AutoSklearnClassifier(allow_string_features=False) + automl.fit(X_train, y_train) Regarding the targets (y_train/y_test), if the task involves a classification problem, such features will be automatically encoded. It is recommended to provide both y_train and y_test during fit, so that a common encoding @@ -336,14 +342,15 @@ Other In order to obtain *vanilla auto-sklearn* as used in `Efficient and Robust Automated Machine Learning `_ - set ``ensemble_size=1`` and ``initial_configurations_via_metalearning=0``: + set ``ensemble_size=1``, ``initial_configurations_via_metalearning=0`` and ``allow_string_features=False``: .. code:: python import autosklearn.classification automl = autosklearn.classification.AutoSklearnClassifier( ensemble_size=1, - initial_configurations_via_metalearning=0 + initial_configurations_via_metalearning=0, + allow_string_features=False, ) An ensemble of size one will result in always choosing the current best model diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index 216f286936..a06b91b0f3 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -524,15 +524,15 @@ def dummy_func(self): dummy_object = Dummy(1) lst = [1, 2, 3] array = np.array([1, 2, 3]) - dummy_stirng = "dummy string" + dummy_string = "dummy string" df = pd.DataFrame( { "dummy_object": [dummy_object] * 4, "dummy_lst": [lst] * 4, "dummy_array": [array] * 4, - "dummy_string": [dummy_stirng] * 4, - "type_mix_column": [dummy_stirng, dummy_object, array, lst], + "dummy_string": [dummy_string] * 4, + "type_mix_column": [dummy_string, dummy_object, array, lst], "cat_column": ["a", "b", "a", "b"], } ) @@ -560,3 +560,29 @@ def dummy_func(self): } assert feat_type == column_types + + +def test_allow_string_feature(): + df = pd.DataFrame({"Text": ["Hello", "how are you?"]}) + with pytest.warns( + UserWarning, + match=r"Input Column Text has generic type object. " + r"Autosklearn will treat this column as string. " + r"Please ensure that this setting is suitable for your task.", + ): + validator = FeatureValidator(allow_string_features=False) + feat_type = validator.get_feat_type_from_columns(df) + + column_types = {"Text": "categorical"} + assert feat_type == column_types + + df["Text"] = df["Text"].astype("string") + with pytest.warns( + UserWarning, + match=r"you disabled text encoding column Text will be " r"encoded as category", + ): + validator = FeatureValidator(allow_string_features=False) + feat_type = validator.get_feat_type_from_columns(df) + + column_types = {"Text": "categorical"} + assert feat_type == column_types