From 5cb018c8c9359994c69a5e8c27462a1d0136e348 Mon Sep 17 00:00:00 2001
From: Cesar Cardoso
Date: Fri, 29 Aug 2025 15:10:21 -0700
Subject: [PATCH 1/4] Add backfill_value and default_value fields to Parameter
 (#4177)

Summary:
Add two properties to `Parameter`:

1. `backfill_value` - used for parameters added to experiments that already
   have run trials. Specifies the backfill value to use for trials that have
   already run.
2. `default_value` - used for parameters disabled in experiments that already
   have run trials. Specifies the default value to use in modeling for future
   trials.

Reviewed By: lena-kashtelyan

Differential Revision: D79260975
---
 ax/core/parameter.py                | 76 ++++++++++++++++++++++++-----
 ax/storage/sqa_store/decoder.py     |  8 +++
 ax/storage/sqa_store/encoder.py     |  8 +++
 ax/storage/sqa_store/sqa_classes.py |  2 +
 4 files changed, 83 insertions(+), 11 deletions(-)

diff --git a/ax/core/parameter.py b/ax/core/parameter.py
index 52cfe196b5e..1b3d9aa0d17 100644
--- a/ax/core/parameter.py
+++ b/ax/core/parameter.py
@@ -13,10 +13,10 @@
 from enum import Enum
 from logging import Logger
 from math import inf
-from typing import cast, Union
+from typing import Any, cast, Union
 from warnings import warn
 
-from ax.core.types import TNumeric, TParameterization, TParamValue, TParamValueList
+from ax.core.types import TNumeric, TParameterization, TParamValue
 from ax.exceptions.core import AxParameterWarning, UnsupportedError, UserInputError
 from ax.utils.common.base import SortableBase
 from ax.utils.common.logger import get_logger
@@ -98,6 +98,8 @@ class Parameter(SortableBase, metaclass=ABCMeta):
     _name: str
     _target_value: TParamValue = None
     _parameter_type: ParameterType
+    _backfill_value: TParamValue = None
+    _default_value: TParamValue = None
 
     def cast(self, value: TParamValue) -> TParamValue:
         if value is None:
@@ -154,6 +156,18 @@ def is_hierarchical(self) -> bool:
     def target_value(self) -> TParamValue:
         return self._target_value
 
+    @property
+    def backfill_value(self) -> TParamValue:
+        return self._backfill_value
+
+    @property
+    def default_value(self) -> TParamValue:
+        return self._default_value
+
+    @property
+    def is_disabled(self) -> bool:
+        return self.default_value is not None
+
     @property
     def parameter_type(self) -> ParameterType:
         return self._parameter_type
@@ -214,9 +228,9 @@ def available_flags(self) -> list[str]:
     @property
     def summary_dict(
         self,
-    ) -> dict[str, TParamValueList | TParamValue | str | list[str]]:
+    ) -> dict[str, Any]:
         # Assemble dict.
-        summary_dict = {
+        summary_dict: dict[str, Any] = {
             "name": self.name,
             "type": self.__class__.__name__.removesuffix("Parameter"),
             "domain": self.domain_repr,
@@ -239,17 +253,14 @@ def summary_dict(
         if flags:
             summary_dict["flags"] = ", ".join(flags)
         if getattr(self, "is_fidelity", False) or getattr(self, "is_task", False):
-            # pyre-fixme[6]: For 2nd argument expected `str` but got `Union[None,
-            #  bool, float, int, str]`.
             summary_dict["target_value"] = self.target_value
         if getattr(self, "is_hierarchical", False):
-            # pyre-fixme[6]: For 2nd argument expected `str` but got
-            #  `Dict[Union[None, bool, float, int, str], List[str]]`.
summary_dict["dependents"] = self.dependents + if getattr(self, "backfill_value", None) is not None: + summary_dict["backfill_value"] = self.backfill_value + if getattr(self, "default_value", None) is not None: + summary_dict["default_value"] = self.default_value - # pyre-fixme[7]: Expected `Dict[str, Union[None, List[Union[None, bool, - # float, int, str]], List[str], bool, float, int, str]]` but got `Dict[str, - # str]`. return summary_dict @@ -267,6 +278,8 @@ def __init__( digits: int | None = None, is_fidelity: bool = False, target_value: TParamValue = None, + backfill_value: TParamValue = None, + default_value: TParamValue = None, ) -> None: """Initialize RangeParameter @@ -283,6 +296,11 @@ def __init__( digits: Number of digits to round values to for float type. is_fidelity: Whether this parameter is a fidelity parameter. target_value: Target value of this parameter if it is a fidelity. + backfill_value: For parameters added to experiments that have already run + trials. + Used to backfill trials missing the parameter. + default_value: For parameters disabled in experiments that have already + run trials. Used as default value in modeling for future trials. """ if is_fidelity and (target_value is None): raise UserInputError( @@ -303,6 +321,12 @@ def __init__( self._target_value: TNumeric | None = ( self.cast(target_value) if target_value is not None else None ) + self._backfill_value: TNumeric | None = ( + self.cast(backfill_value) if backfill_value is not None else None + ) + self._default_value: TNumeric | None = ( + self.cast(default_value) if default_value is not None else None + ) self._validate_range_param( parameter_type=parameter_type, @@ -541,6 +565,8 @@ def clone(self) -> RangeParameter: digits=self._digits, is_fidelity=self._is_fidelity, target_value=self._target_value, + backfill_value=self._backfill_value, + default_value=self._default_value, ) def cast(self, value: TParamValue) -> TNumeric: @@ -588,6 +614,10 @@ class ChoiceParameter(Parameter): True. dependents: Optional mapping for parameters in hierarchical search spaces; format is { value -> list of dependent parameter names }. + backfill_value: For parameters added to experiments that have already run. + Used to backfill trials missing the parameter. + default_value: For parameters disabled in experiments that have already + run. Used as default value in modeling for future trials. """ def __init__( @@ -601,6 +631,8 @@ def __init__( target_value: TParamValue = None, sort_values: bool | None = None, dependents: dict[TParamValue, list[str]] | None = None, + backfill_value: TParamValue = None, + default_value: TParamValue = None, ) -> None: if (is_fidelity or is_task) and (target_value is None): ptype = "fidelity" if is_fidelity else "task" @@ -616,6 +648,12 @@ def __init__( self._target_value: TParamValue = ( self.cast(target_value) if target_value is not None else None ) + self._backfill_value: TParamValue = ( + self.cast(backfill_value) if backfill_value is not None else None + ) + self._default_value: TParamValue = ( + self.cast(default_value) if default_value is not None else None + ) # A choice parameter with only one value is a FixedParameter. 
if not len(values) > 1: raise UserInputError(f"{self._name}({values}): {FIXED_CHOICE_PARAM_ERROR}") @@ -789,6 +827,8 @@ def clone(self) -> ChoiceParameter: target_value=self._target_value, sort_values=self._sort_values, dependents=deepcopy(self._dependents), + backfill_value=self._backfill_value, + default_value=self._default_value, ) def __repr__(self) -> str: @@ -826,6 +866,8 @@ def __init__( is_fidelity: bool = False, target_value: TParamValue = None, dependents: dict[TParamValue, list[str]] | None = None, + backfill_value: TParamValue = None, + default_value: TParamValue = None, ) -> None: """Initialize FixedParameter @@ -838,6 +880,10 @@ def __init__( target_value: Target value of this parameter if it is a fidelity. dependents: Optional mapping for parameters in hierarchical search spaces; format is { value -> list of dependent parameter names }. + backfill_value: For parameters added to experiments that have already run. + Used to backfill trials missing the parameter. + default_value: For parameters disabled in experiments that have already + run. Used as default value in modeling for future trials. """ if is_fidelity and (target_value is None): raise UserInputError( @@ -852,6 +898,12 @@ def __init__( self._target_value: TParamValue = ( self.cast(target_value) if target_value is not None else None ) + self._backfill_value: TParamValue = ( + self.cast(backfill_value) if backfill_value is not None else None + ) + self._default_value: TParamValue = ( + self.cast(default_value) if default_value is not None else None + ) # NOTE: We don't need to check that dependent parameters actually exist as # that is done in `HierarchicalSearchSpace` constructor. if dependents: @@ -909,6 +961,8 @@ def clone(self) -> FixedParameter: is_fidelity=self._is_fidelity, target_value=self._target_value, dependents=self._dependents, + backfill_value=self._backfill_value, + default_value=self._default_value, ) def __repr__(self) -> str: diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index b9047a29894..275c4408e1b 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -391,6 +391,8 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: digits=parameter_sqa.digits, is_fidelity=parameter_sqa.is_fidelity or False, target_value=parameter_sqa.target_value, + backfill_value=parameter_sqa.backfill_value, + default_value=parameter_sqa.default_value, ) elif parameter_sqa.domain_type == DomainType.CHOICE: target_value = parameter_sqa.target_value @@ -414,6 +416,8 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: is_ordered=parameter_sqa.is_ordered, is_task=bool(parameter_sqa.is_task), dependents=parameter_sqa.dependents, + backfill_value=parameter_sqa.backfill_value, + default_value=parameter_sqa.default_value, ) elif parameter_sqa.domain_type == DomainType.FIXED: # Don't throw an error if parameter_sqa.fixed_value is None; @@ -425,6 +429,8 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: is_fidelity=parameter_sqa.is_fidelity or False, target_value=parameter_sqa.target_value, dependents=parameter_sqa.dependents, + backfill_value=parameter_sqa.backfill_value, + default_value=parameter_sqa.default_value, ) elif parameter_sqa.domain_type == DomainType.DERIVED: parameter = DerivedParameter( @@ -543,6 +549,8 @@ def environmental_variable_from_sqa(self, parameter_sqa: SQAParameter) -> Parame digits=parameter_sqa.digits, is_fidelity=parameter_sqa.is_fidelity or False, 
target_value=parameter_sqa.target_value, + backfill_value=parameter_sqa.backfill_value, + default_value=parameter_sqa.default_value, ) else: raise SQADecodeError( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 75383e63a0b..d26e1cb2f21 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -277,6 +277,8 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter: is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, dependents=parameter.dependents if parameter.is_hierarchical else None, + backfill_value=parameter.backfill_value, + default_value=parameter.default_value, ) elif isinstance(parameter, ChoiceParameter): # pyre-fixme[29]: `SQAParameter` is not a function. @@ -291,6 +293,8 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter: is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, dependents=parameter.dependents if parameter.is_hierarchical else None, + backfill_value=parameter.backfill_value, + default_value=parameter.default_value, ) elif isinstance(parameter, FixedParameter): # pyre-fixme[29]: `SQAParameter` is not a function. @@ -303,6 +307,8 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter: is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, dependents=parameter.dependents if parameter.is_hierarchical else None, + backfill_value=parameter.backfill_value, + default_value=parameter.default_value, ) elif isinstance(parameter, DerivedParameter): # pyre-fixme[29]: `SQAParameter` is not a function. @@ -424,6 +430,8 @@ def environmental_variable_to_sqa(self, parameter: Parameter) -> SQAParameter: digits=parameter.digits, is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, + backfill_value=parameter.backfill_value, + default_value=parameter.default_value, ) else: raise SQAEncodeError( diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 2544eec4f43..b23b9c445e4 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -71,6 +71,8 @@ class SQAParameter(Base): ) is_fidelity: Column[bool | None] = Column(Boolean) target_value: Column[TParamValue | None] = Column(JSONEncodedObject) + backfill_value: Column[TParamValue | None] = Column(JSONEncodedObject) + default_value: Column[TParamValue | None] = Column(JSONEncodedObject) # Attributes for Range Parameters digits: Column[int | None] = Column(Integer) From 3539e982a3ac4e9b0196a51b62c67300a67e9d34 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Fri, 29 Aug 2025 15:10:21 -0700 Subject: [PATCH 2/4] New methods to add/remove search space parameters (#4178) Summary: In `Experiment` add two methods `add_parameters_to_search_space` and `disable_parameters_in_search_space`. In `Adapter._process_and_transform_data` check if the experiment's search space has been updated and update the adapter's search space and model space. Use `backfill_values` for `FillMissingParameters` transform. In `GenerationNode._determine_fixed_features_from_node` check for any disabled parameters. Add their `default_value` as fixed features. 
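
For orientation, a minimal usage sketch of the two new `Experiment` methods
(illustrative only; the `experiment` instance, the parameter name, and the
values are assumed here, not taken from the diff):

    from ax.core.parameter import ParameterType, RangeParameter

    # Extend the search space after trials have run. `backfill_value` is the
    # value the already-run trials are assumed to have used for the new
    # parameter.
    experiment.add_parameters_to_search_space(
        parameters=[
            RangeParameter(
                name="x3",
                parameter_type=ParameterType.FLOAT,
                lower=0.0,
                upper=1.0,
                backfill_value=0.5,
            )
        ],
        status_quo_values={"x3": 0.5},  # only needed if a status quo arm exists
    )

    # Narrow the search space: all future trials pin x3 to the given default.
    experiment.disable_parameters_in_search_space({"x3": 0.5})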
Reviewed By: lena-kashtelyan Differential Revision: D79263457 --- ax/adapter/base.py | 68 ++++++++--- ax/adapter/tests/test_base_adapter.py | 28 ++++- ax/core/experiment.py | 94 +++++++++++++++- ax/core/parameter.py | 14 +++ ax/core/search_space.py | 89 +++++++++++++++ ax/core/tests/test_experiment.py | 106 ++++++++++++++++++ ax/generation_strategy/generation_node.py | 43 ++++--- .../tests/test_generation_node.py | 23 ++++ 8 files changed, 436 insertions(+), 29 deletions(-) diff --git a/ax/adapter/base.py b/ax/adapter/base.py index 2bc3501db33..b2646e2e20d 100644 --- a/ax/adapter/base.py +++ b/ax/adapter/base.py @@ -12,7 +12,7 @@ from copy import deepcopy from dataclasses import dataclass, field from logging import Logger -from typing import Any +from typing import Any, cast import numpy as np import pandas as pd @@ -43,6 +43,7 @@ TModelMean, TModelPredict, TParameterization, + TParamValue, ) from ax.core.utils import extract_map_keys_from_opt_config, get_target_trial_index from ax.exceptions.core import UnsupportedError, UserInputError @@ -98,6 +99,16 @@ class Adapter: specification. """ + # pyre-ignore [13] Assigned in _set_and_filter_training_data. + _training_data: ExperimentData + + # The space used for optimization. + _search_space: SearchSpace + + # The space used for modeling. Might be larger than the optimization + # space to cover training data. + _model_space: SearchSpace + def __init__( self, *, @@ -184,17 +195,17 @@ def __init__( t_fit_start = time.monotonic() transforms = transforms or [] transforms = [Cast] + list(transforms) - transform_configs = {} if transform_configs is None else transform_configs - if "FillMissingParameters" in transform_configs: + self._transform_configs: Mapping[str, TConfig] = ( + {} if transform_configs is None else {**transform_configs} + ) + if "FillMissingParameters" in self._transform_configs: transforms = [FillMissingParameters] + transforms - self._raw_transforms = transforms - self._transform_configs: Mapping[str, TConfig] = transform_configs + self._raw_transforms: list[type[Transform]] = transforms + self._set_search_space(search_space or experiment.search_space) self.fit_time: float = 0.0 self.fit_time_since_gen: float = 0.0 self._metric_names: set[str] = set() - # pyre-ignore [13] Assigned in _set_and_filter_training_data. - self._training_data: ExperimentData self._optimization_config: OptimizationConfig | None = optimization_config self._training_in_design_idx: list[bool] = [] self._status_quo: Observation | None = None @@ -203,12 +214,6 @@ def __init__( self._model_key: str | None = None self._model_kwargs: dict[str, Any] | None = None self._bridge_kwargs: dict[str, Any] | None = None - # The space used for optimization. - search_space = search_space or experiment.search_space - self._search_space: SearchSpace = search_space.clone() - # The space used for modeling. Might be larger than the optimization - # space to cover training data. - self._model_space: SearchSpace = search_space.clone() self._fit_tracking_metrics = fit_tracking_metrics self.outcomes: list[str] = [] self._experiment_has_immutable_search_space_and_opt_config: bool = ( @@ -302,6 +307,7 @@ def _process_and_transform_data( ) -> tuple[ExperimentData, SearchSpace]: r"""Processes the data into ``ExperimentData`` and returns the transformed ``ExperimentData`` and the search space. 
This packages the following methods: + * self._set_search_space * self._set_and_filter_training_data * self._set_status_quo * self._transform_data @@ -311,6 +317,10 @@ def _process_and_transform_data( data_loader_config=self._data_loader_config, data=data, ) + # If the search space has changed, we need to update the model space + if self._search_space != experiment.search_space: + self._set_search_space(experiment.search_space) + self._set_model_space(arm_data=experiment_data.arm_data) experiment_data = self._set_and_filter_training_data( experiment_data=experiment_data, search_space=self._model_space ) @@ -352,6 +362,38 @@ def _transform_data( self.transforms[t.__name__] = t_instance return experiment_data, search_space + def _set_search_space( + self, + search_space: SearchSpace, + ) -> None: + """Set search space and model space. Adds a FillMissingParameters transform for + newly added parameters.""" + self._search_space = search_space.clone() + self._model_space = search_space.clone() + # Add FillMissingParameters transform if search space has parameters with + # backfill values. + backfill_values = search_space.backfill_values() + if len(backfill_values) > 0: + fill_missing_values_transform = self._transform_configs.get( + "FillMissingParameters", {} + ) + current_fill_values = cast( + Mapping[str, TParamValue], + fill_missing_values_transform.get("fill_values", {}), + ) + # Override backfill_values with fill values already in the transform + fill_missing_values_transform["fill_values"] = { + **backfill_values, + **current_fill_values, + } + self._transform_configs = { + **self._transform_configs, + "FillMissingParameters": fill_missing_values_transform, + } + # Add FillMissingParameters transform if not already present. + if FillMissingParameters not in self._raw_transforms: + self._raw_transforms = [FillMissingParameters] + self._raw_transforms + def _set_and_filter_training_data( self, experiment_data: ExperimentData, search_space: SearchSpace ) -> ExperimentData: diff --git a/ax/adapter/tests/test_base_adapter.py b/ax/adapter/tests/test_base_adapter.py index 08b05720662..72c8d03444a 100644 --- a/ax/adapter/tests/test_base_adapter.py +++ b/ax/adapter/tests/test_base_adapter.py @@ -849,7 +849,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None: # Fit model without filling missing parameters m = Adapter(experiment=experiment, generator=Generator()) self.assertEqual( - [t.__name__ for t in m._raw_transforms], # pyre-ignore[16] + [t.__name__ for t in m._raw_transforms], ["Cast"], ) # Check that SQ and all trial 1 are OOD @@ -1111,3 +1111,29 @@ def test_get_training_data(self) -> None: in_design_training_data.observation_data, training_data.observation_data.iloc[[0, 2]], ) + + def test_added_parameters(self) -> None: + exp = get_branin_experiment() + adapter = Adapter(experiment=exp, generator=Generator()) + data, ss = adapter._process_and_transform_data(experiment=exp) + self.assertEqual(ss, exp.search_space) + self.assertListEqual(list(data.arm_data.columns), ["x1", "x2", "metadata"]) + # Add new parameter + exp.add_parameters_to_search_space( + [ + RangeParameter( + name="x3", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + backfill_value=0.5, + ) + ] + ) + self.assertNotEqual(exp.search_space, adapter._search_space) + adapter._process_and_transform_data(experiment=exp) + data, ss = adapter._process_and_transform_data(experiment=exp) + self.assertEqual(ss, exp.search_space) + self.assertListEqual( + list(data.arm_data.columns), ["x1", "x2", "x3", 
"metadata"] + ) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index a6794f6cde0..46558dfac76 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -31,7 +31,7 @@ from ax.core.metric import Metric, MetricFetchE, MetricFetchResult from ax.core.objective import MultiObjective from ax.core.optimization_config import ObjectiveThreshold, OptimizationConfig -from ax.core.parameter import Parameter +from ax.core.parameter import DerivedParameter, Parameter from ax.core.runner import Runner from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.core.trial import Trial @@ -179,7 +179,7 @@ def __init__( self.add_tracking_metrics(tracking_metrics or []) # call setters defined below - self.search_space = search_space + self.search_space: SearchSpace = search_space self.status_quo = status_quo if optimization_config is not None: self.optimization_config = optimization_config @@ -277,6 +277,96 @@ def search_space(self, search_space: SearchSpace) -> None: ) self._search_space = search_space + def add_parameters_to_search_space( + self, + parameters: Sequence[Parameter], + status_quo_values: TParameterization | None = None, + ) -> None: + """ + Add new parameters to the experiment's search space. This allows extending + the search space after the experiment has run some trials. + + Backfill values must be provided for all new parameters if the experiment has + already run some trials. The backfill values represent the parameter values + that were used in the existing trials. + + Args: + parameters: A sequence of parameter configurations to add to the search + space. + status_quo_values: Optional parameter values for the new parameters to + use in the status quo (baseline) arm, if one is defined. + """ + status_quo_values = status_quo_values or {} + + # Additional checks iff a trial exists + if len(self.trials) != 0: + if any(parameter.backfill_value is None for parameter in parameters): + raise UserInputError( + "Must provide backfill values for all new parameters when " + "adding parameters to an experiment with existing trials." + ) + if any(isinstance(parameter, DerivedParameter) for parameter in parameters): + raise UserInputError( + "Cannot add derived parameters to an experiment with existing " + "trials." + ) + + # Validate status quo values + status_quo = self._status_quo + if status_quo_values is not None and status_quo is None: + logger.warning( + "Status quo values specified, but experiment does not have a " + "status quo. Ignoring provided status quo values." + ) + if status_quo is not None: + parameter_names = {parameter.name for parameter in parameters} + status_quo_parameters = status_quo_values.keys() + disabled_parameters = { + parameter.name + for parameter in self._search_space.parameters.values() + if parameter.is_disabled + } + extra_status_quo_values = status_quo_parameters - parameter_names + if extra_status_quo_values: + logger.warning( + "Status quo value provided for parameters " + f"`{extra_status_quo_values}` which is are being added to " + "the search space. Ignoring provided status quo values." + ) + mising_status_quo_values = ( + parameter_names - disabled_parameters - status_quo_parameters + ) + if mising_status_quo_values: + raise UserInputError( + "No status quo value provided for parameters " + f"`{mising_status_quo_values}` which are being added to " + "the search space." 
+ ) + for parameter_name, value in status_quo_values.items(): + status_quo._parameters[parameter_name] = value + + # Add parameters to search space + self._search_space.add_parameters(parameters) + + def disable_parameters_in_search_space( + self, default_parameter_values: TParameterization + ) -> None: + """ + Disable parameters in the experiment. This allows narrowing the search space + after the experiment has run some trials. + + When parameters are disabled, they are effectively removed from the search + space for future trial generation. Existing trials remain valid, and the + disabled parameters are replaced with fixed default values for all subsequent + trials. + + Args: + default_parameter_values: Fixed values to use for the disabled parameters + in all future trials. These values will be used for the parameter in + all subsequent trials. + """ + self._search_space.disable_parameters(default_parameter_values) + @property def status_quo(self) -> Arm | None: """The existing arm that new arms will be compared against.""" diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1b3d9aa0d17..cae6948172f 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -186,6 +186,20 @@ def dependents(self) -> dict[TParamValue, list[str]]: def clone(self) -> Parameter: pass + def disable(self, default_value: TParamValue) -> None: + """ + Effectively remove parameter from the search space for future trial generation. + Existing trials remain valid, and the disabled parameter is replaced with the + default_value for all subsequent trials. + """ + if self.is_disabled: + logger.warning( + f"Parameter {self.name} is already disabled with " + f"default value {self.default_value}. " + f"Updating default value to {default_value}." + ) + self._default_value = default_value + @property def _unique_id(self) -> str: return str(self) diff --git a/ax/core/search_space.py b/ax/core/search_space.py index 8207dab7b30..4d2b287e77b 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -154,6 +154,87 @@ def set_parameter_constraints( self._parameter_constraints: list[ParameterConstraint] = parameter_constraints + def add_parameters( + self, + parameters: Sequence[Parameter], + ) -> None: + """ + Add new parameters to the experiment's search space. This allows extending + the search space after the experiment has run some trials. + + Backfill values must be provided for all new parameters if the experiment has + already run some trials. The backfill values represent the parameter values + that were used in the existing trials. + """ + # Disabled parameters should be updated + parameters_to_add = [] + parameters_to_update = [] + + # Check which parameters to add to the search space and which to update + for parameter in parameters: + # Parameters already exist in search space + if parameter.name in self.parameters.keys(): + existing_parameter = self.parameters[parameter.name] + # Only disabled parameters can be re-added + if not existing_parameter.is_disabled: + raise UserInputError( + f"Parameter `{parameter.name}` already exists in search space." + ) + if type(parameter) is not type(existing_parameter): + raise UserInputError( + f"Parameter `{parameter.name}` already exists in search " + "space. Cannot change parameter type from " + f"{type(existing_parameter)} to {type(parameter)}." 
+ ) + parameters_to_update.append(parameter) + + # Parameter does not exist in search space + else: + parameters_to_add.append(parameter) + + # Add new parameters to search space and status quo + for parameter in parameters_to_add: + self.add_parameter(parameter) + + # Update disabled parameters in search space + for parameter in parameters_to_update: + self.update_parameter(parameter) + + def disable_parameters(self, default_parameter_values: TParameterization) -> None: + """ + Disable parameters in the experiment. This allows narrowing the search space + after the experiment has run some trials. + + When parameters are disabled, they are effectively removed from the search + space for future trial generation. Existing trials remain valid, and the + disabled parameters are replaced with fixed default values for all subsequent + trials. + + Args: + default_parameter_values: Fixed values to use for the disabled parameters + in all future trials. These values will be used for the parameter in + all subsequent trials. + """ + parameters_to_disable = set(default_parameter_values.keys()) + search_space_parameters = set(self.parameters.keys()) + parameters_not_in_search_space = parameters_to_disable - search_space_parameters + + # Validate that all parameters to disable are in the search space + if len(parameters_not_in_search_space) > 0: + raise UserInputError( + f"Cannot disable parameters {parameters_not_in_search_space} " + "because they are not in the search space." + ) + + # Validate that all parameters to disable have a valid default + for parameter_to_disable, default_value in default_parameter_values.items(): + parameter = self.parameters[parameter_to_disable] + parameter.validate(default_value, raises=True) + + # Disable parameters + for parameter_to_disable, default_value in default_parameter_values.items(): + self.parameters[parameter_to_disable].disable(default_value) + def add_parameter(self, parameter: Parameter) -> None: if parameter.name in self.parameters.keys(): raise ValueError( @@ -451,6 +532,14 @@ def _validate_derived_parameter(self, parameter: DerivedParameter) -> None: "to add an fixed value to a derived parameter." 
) + def backfill_values(self) -> TParameterization: + """Backfill values for parameters that have a backfill value.""" + return { + name: parameter.backfill_value + for name, parameter in self.parameters.items() + if parameter.backfill_value is not None + } + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 0de5533581f..37e96f48d7e 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -39,6 +39,7 @@ OptimizationNotConfiguredError, RunnerNotFoundError, UnsupportedError, + UserInputError, ) from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner @@ -352,6 +353,111 @@ def test_SearchSpaceSetter(self) -> None: with self.assertRaises(ValueError): self.experiment.search_space = extra_param_ss + def test_AddSearchSpaceParameters(self) -> None: + new_param = RangeParameter( + name="new_param", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + ) + + with self.subTest("Add parameter to experiment with no trials"): + experiment = self.experiment.clone_with(trial_indices=[]) + experiment.add_parameters_to_search_space( + parameters=[new_param], + status_quo_values={new_param.name: 0.0}, + ) + # Verify parameter was added + self.assertIn("new_param", experiment.search_space.parameters) + self.assertEqual(experiment.search_space.parameters["new_param"], new_param) + # Verify backfill value was used as status quo + self.assertIsNotNone(experiment.status_quo) + self.assertIn("new_param", experiment.status_quo.parameters) + self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0) + + with self.subTest("Add parameter with status quo value"): + experiment = self.experiment.clone_with(trial_indices=[]) + experiment.add_parameters_to_search_space( + parameters=[new_param], status_quo_values={"new_param": 1.0} + ) + # Verify parameter was added + self.assertIn("new_param", experiment.search_space.parameters) + self.assertEqual(experiment.search_space.parameters["new_param"], new_param) + # Verify backfill value was used as status quo + self.assertIsNotNone(experiment.status_quo) + self.assertIn("new_param", experiment.status_quo.parameters) + self.assertEqual(experiment.status_quo.parameters["new_param"], 1.0) + + with self.subTest("Test error when adding parameter that already exists"): + experiment = self.experiment.clone_with(trial_indices=[]) + existing_param = self.experiment.search_space.parameters["w"] + with self.assertRaises(UserInputError): + experiment.add_parameters_to_search_space(parameters=[existing_param]) + + with self.subTest( + "Test error when adding parameters to experiment with trials but no " + "backfill values" + ): + experiment = self.experiment.clone_with() + experiment.new_batch_trial() + with self.assertRaises(UserInputError): + experiment.add_parameters_to_search_space(parameters=[new_param]) + + with self.subTest( + "Test successfully adding parameters with backfill values when trials exist" + ): + experiment = self.experiment.clone_with() + experiment.new_batch_trial() + new_param._backfill_value = 0.5 + experiment.add_parameters_to_search_space( + parameters=[new_param], status_quo_values={new_param.name: 0.0} + ) + # Verify parameter was added + self.assertIn("new_param", experiment.search_space.parameters) + self.assertEqual(experiment.search_space.parameters["new_param"], new_param) + # Verify backfill value was used as status quo + self.assertIsNotNone(experiment.status_quo) + 
self.assertIn("new_param", experiment.status_quo.parameters) + self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0) + + def test_DisableSearchSpaceParameters(self) -> None: + with self.subTest( + "Test error when trying to disable parameter not in search space" + ): + experiment = self.experiment.clone_with() + with self.assertRaises(UserInputError): + experiment.disable_parameters_in_search_space({"nonexistent": 1.0}) + + with self.subTest("Test error when providing invalid default value"): + experiment = self.experiment.clone_with() + with self.assertRaises(UserInputError): + experiment.disable_parameters_in_search_space({"w": "string_value"}) + + with self.subTest("Test successfully disabling parameter"): + experiment = self.experiment.clone_with() + experiment.disable_parameters_in_search_space({"w": 2.5}) + # Verify parameter was disabled (has default value) + self.assertEqual(experiment.search_space.parameters["w"].default_value, 2.5) + + with self.subTest("Test re-enable parameter"): + # Using the same experiment as above + parameter = experiment.search_space.parameters["w"].clone() + parameter._default_value = None + experiment.add_parameters_to_search_space(parameters=[parameter]) + # Verify parameter was re-enabled + self.assertIsNone(experiment.search_space.parameters["w"].default_value) + + def test_OptimizationConfigSetter(self) -> None: + # Establish current metrics size + self.assertEqual( + len(get_optimization_config().metrics) + 1, len(self.experiment.metrics) + ) + + # Add optimization config with 1 different metric + opt_config = get_optimization_config() + opt_config.outcome_constraints[0].metric = Metric(name="m3") + self + def test_StatusQuoSetter(self) -> None: sq_parameters = self.experiment.status_quo.parameters diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 8289f839dab..36fc4070042 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -942,25 +942,42 @@ def _determine_fixed_features_from_node( An object of ObservationFeatures that represents the fixed features to pass into the model. """ + node_fixed_features = None # passed_fixed_features represents the fixed features that were passed by the # user to the gen method as overrides. 
passed_fixed_features = gen_kwargs.get("fixed_features") if passed_fixed_features is not None: - return passed_fixed_features + node_fixed_features = passed_fixed_features + else: + input_constructors_module = gs_module.generation_node_input_constructors + purpose_fixed_features = ( + input_constructors_module.InputConstructorPurpose.FIXED_FEATURES + ) + if purpose_fixed_features in self.input_constructors: + node_fixed_features = self.input_constructors[purpose_fixed_features]( + previous_node=self.previous_node, + next_node=self, + gs_gen_call_kwargs=gen_kwargs, + experiment=experiment, + ) + # also pass default parameter values as fixed features for disabled parameters + disabled_parameters_parameterization = { + name: parameter.default_value + for name, parameter in experiment.search_space.parameters.items() + if parameter.is_disabled + } + if len(disabled_parameters_parameterization) == 0: + return node_fixed_features - node_fixed_features = None - input_constructors_module = gs_module.generation_node_input_constructors - purpose_fixed_features = ( - input_constructors_module.InputConstructorPurpose.FIXED_FEATURES + if node_fixed_features is None: + return ObservationFeatures(parameters=disabled_parameters_parameterization) + + return node_fixed_features.clone( + replace_parameters={ + **disabled_parameters_parameterization, + **node_fixed_features.parameters, + } ) - if purpose_fixed_features in self.input_constructors: - node_fixed_features = self.input_constructors[purpose_fixed_features]( - previous_node=self.previous_node, - next_node=self, - gs_gen_call_kwargs=gen_kwargs, - experiment=experiment, - ) - return node_fixed_features class GenerationStep(GenerationNode, SortableBase): diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 8549c6de370..f07402ec9be 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -349,6 +349,29 @@ def test_single_fixed_features(self) -> None: ObservationFeatures(parameters={"x": 0}), ) + def test_disabled_parameters(self) -> None: + input_constructors = self.sobol_generation_node.apply_input_constructors( + experiment=self.branin_experiment, gen_kwargs={} + ) + self.assertIsNone(input_constructors["fixed_features"]) + # Disable parameter + self.branin_experiment.disable_parameters_in_search_space({"x1": 1.2345}) + input_constructors = self.sobol_generation_node.apply_input_constructors( + experiment=self.branin_experiment, gen_kwargs={} + ) + expected_fixed_features = ObservationFeatures(parameters={"x1": 1.2345}) + self.assertEqual(input_constructors["fixed_features"], expected_fixed_features) + # Test fixed features override + input_constructors = self.sobol_generation_node.apply_input_constructors( + experiment=self.branin_experiment, + gen_kwargs={ + "fixed_features": ObservationFeatures(parameters={"x1": 0.0, "x2": 0.0}) + }, + ) + # The passed fixed feature overrides the disabled parameter default value + expected_fixed_features = ObservationFeatures(parameters={"x1": 0.0, "x2": 0.0}) + self.assertEqual(input_constructors["fixed_features"], expected_fixed_features) + class TestGenerationStep(TestCase): def setUp(self) -> None: From 4cdce76be134fd1b0e99f68826a1bfa7ddd2760b Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Fri, 29 Aug 2025 15:10:21 -0700 Subject: [PATCH 3/4] Fix pyre errors in experiment.py (#4179) Summary: Fix remaining Pyre errors in experiment.py Reviewed By: mpolson64, 
lena-kashtelyan Differential Revision: D80579073 --- ax/core/experiment.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 46558dfac76..c148e03c31f 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -126,8 +126,7 @@ def __init__( # pyre-fixme[13]: Attribute `_search_space` is never initialized. self._search_space: SearchSpace self._status_quo: Arm | None = None - # pyre-fixme[13]: Attribute `_is_test` is never initialized. - self._is_test: bool + self._is_test: bool = False self._name = name self.description = description @@ -136,14 +135,12 @@ def __init__( self._data_by_trial: dict[int, OrderedDict[int, Data]] = {} self._experiment_type: str | None = experiment_type - # pyre-fixme[4]: Attribute must be annotated. - self._optimization_config = None + self._optimization_config: OptimizationConfig | None = None self._tracking_metrics: dict[str, Metric] = {} self._time_created: datetime = datetime.now() self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} - # pyre-fixme[4]: Attribute must be annotated. - self._default_data_type = default_data_type or DataType.DATA + self._default_data_type: DataType = default_data_type or DataType.DATA # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's # expiration often, having this attribute helps avoid unnecessary @@ -402,20 +399,20 @@ def status_quo(self, status_quo: Arm | None) -> None: # If old status_quo not present in any trials, # remove from _arms_by_signature if self._status_quo is not None: + old_status_quo_name = self._status_quo.name + old_status_quo_signature = self._status_quo.signature logger.warning( "Experiment's status_quo is updated. " "Generally the status_quo should not be changed after being set." ) persist_old_sq = False for trial in self._trials.values(): - # pyre-fixme[16]: `Optional` has no attribute `name`. - if self._status_quo.name in trial.arms_by_name: + if old_status_quo_name in trial.arms_by_name: persist_old_sq = True break if not persist_old_sq: - # pyre-fixme[16]: `Optional` has no attribute `signature`. - self._arms_by_signature.pop(self._status_quo.signature) - self._arms_by_name.pop(self._status_quo.name) + self._arms_by_signature.pop(old_status_quo_signature) + self._arms_by_name.pop(old_status_quo_name) self._status_quo = status_quo From 2b3a21e52b044cd599e1d83d59f8fe3802484026 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Fri, 29 Aug 2025 15:10:21 -0700 Subject: [PATCH 4/4] API to add/disable parameters Summary: Client API to expose SS updates. Wrappers for `experiment.add_parameters_to_search_space` and `experiment.disable_parameters_in_search_space`. 
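
For orientation, a minimal usage sketch of the new `Client` wrappers
(illustrative only; the configured `client`, parameter names, and values are
assumed, as is the `ax.api.configs` import path):

    from ax.api.configs import RangeParameterConfig

    # Extend the search space. `backfill_values` are the values the existing
    # trials are assumed to have used for the new parameter.
    client.add_parameters(
        parameters=[
            RangeParameterConfig(name="x3", parameter_type="float", bounds=(-1, 1)),
        ],
        backfill_values={"x3": 0.0},
        status_quo_values={"x3": 0.0},
    )

    # Pin an existing parameter to a fixed value for all future trials.
    client.disable_parameters(default_parameter_values={"x2": 0.0})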
Differential Revision: D80633169
---
 ax/api/client.py            | 78 ++++++++++++++++++++++++++++-
 ax/api/tests/test_client.py | 97 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 173 insertions(+), 2 deletions(-)

diff --git a/ax/api/client.py b/ax/api/client.py
index 62ff9e4660c..e684219a146 100644
--- a/ax/api/client.py
+++ b/ax/api/client.py
@@ -21,6 +21,7 @@
 from ax.api.protocols.runner import IRunner
 from ax.api.types import TOutcome, TParameterization
 from ax.api.utils.generation_strategy_dispatch import choose_generation_strategy
+from ax.api.utils.instantiation.from_config import parameter_from_config
 from ax.api.utils.instantiation.from_string import optimization_config_from_string
 from ax.api.utils.instantiation.from_struct import experiment_from_struct
 from ax.api.utils.storage import db_settings_from_storage_config
@@ -38,7 +39,7 @@
     BaseEarlyStoppingStrategy,
     PercentileEarlyStoppingStrategy,
 )
-from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
+from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.service.orchestrator import Orchestrator, OrchestratorOptions
 from ax.service.utils.best_point_mixin import BestPointMixin
@@ -203,6 +204,81 @@ def configure_generation_strategy(
         )
         self.set_generation_strategy(generation_strategy=generation_strategy)
 
+    def add_parameters(
+        self,
+        parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
+        backfill_values: TParameterization,
+        status_quo_values: TParameterization | None = None,
+    ) -> None:
+        """
+        Add new parameters to the experiment's search space. This allows extending
+        the search space after the experiment has run some trials.
+
+        Backfill values must be provided for all new parameters to ensure existing
+        trials in the experiment remain valid within the expanded search space. The
+        backfill values represent the parameter values that were used in the
+        existing trials.
+
+        Args:
+            parameters: A sequence of parameter configurations to add to the search
+                space.
+            backfill_values: Parameter values to assign to existing trials for the
+                new parameters being added. All new parameter names must have
+                corresponding backfill values provided.
+            status_quo_values: Optional parameter values for the new parameters to
+                use in the status quo (baseline) arm, if one is defined. If None,
+                the backfill values will be used for the status quo.
+        """
+        parameters_to_add = [
+            parameter_from_config(parameter_config) for parameter_config in parameters
+        ]
+        parameter_names = {parameter.name for parameter in parameters_to_add}
+        missing_backfill_values = parameter_names - backfill_values.keys()
+        if missing_backfill_values:
+            raise UserInputError(
+                "You must provide backfill values for all parameters being added. "
+                f"Missing values for parameters: {missing_backfill_values}."
+            )
+        extra_backfill_values = backfill_values.keys() - parameter_names
+        if extra_backfill_values:
+            logger.warning(
+                "Backfill values provided for parameters not being added: "
+                f"{extra_backfill_values}. Will ignore these values."
+ ) + for parameter in parameters_to_add: + if parameter.name in backfill_values: + parameter._backfill_value = backfill_values[parameter.name] + self._experiment.add_parameters_to_search_space( + parameters=parameters_to_add, + # pyre-fixme[6]: Type narrowing broken because core Ax + # TParameterization is dict not Mapping + status_quo_values=status_quo_values, + ) + + def disable_parameters( + self, + default_parameter_values: TParameterization, + ) -> None: + """ + Disable parameters in the experiment. This allows narrowing the search space + after the experiment has run some trials. + + When parameters are disabled, they are effectively removed from the search + space for future trial generation. Existing trials remain valid, and the + disabled parameters are replaced with fixed default values for all subsequent + trials. + + Args: + default_parameter_values: Fixed values to use for the disabled parameters + in all future trials. These values will be used for the parameter in + all subsequent trials. + """ + self._experiment.disable_parameters_in_search_space( + # pyre-fixme[6]: Type narrowing broken because core Ax + # TParameterization is dict not Mapping + default_parameter_values=default_parameter_values + ) + # -------------------- Section 1.1: Configure Automation ------------------------ def configure_runner(self, runner: IRunner) -> None: """ diff --git a/ax/api/tests/test_client.py b/ax/api/tests/test_client.py index 9783d3e7bac..314ba3d8ae7 100644 --- a/ax/api/tests/test_client.py +++ b/ax/api/tests/test_client.py @@ -36,7 +36,7 @@ from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy -from ax.exceptions.core import UnsupportedError +from ax.exceptions.core import UnsupportedError, UserInputError from ax.service.utils.with_db_settings_base import ( _save_generation_strategy_to_db_if_possible, ) @@ -1461,6 +1461,101 @@ def test_overwrite_metric(self) -> None: ) self.assertIn(qux_metric_scalar, scalar.metrics) + def test_add_parameters(self) -> None: + client = Client() + + client.configure_experiment( + parameters=[ + RangeParameterConfig(name="x1", parameter_type="float", bounds=(-1, 1)), + RangeParameterConfig(name="x2", parameter_type="float", bounds=(-1, 1)), + ], + name="test_exp", + ) + client.configure_optimization(objective="foo") + client.configure_metrics(metrics=[DummyMetric(name="foo")]) + client._set_runner(DummyRunner()) + client.attach_baseline({"x1": 0.0, "x2": 0.0}) + + # Run a trial + client.run_trials(1) + + # Can't add parameter without a backfill value + with self.assertRaises(UserInputError): + client.add_parameters( + parameters=[ + RangeParameterConfig( + name="x3", + parameter_type="float", + bounds=(-1, 1), + ) + ], + backfill_values={}, + ) + + # Ignores extra backfill values + with self.assertLogs(logger="ax.api.client", level="WARNING") as lg: + client.add_parameters( + parameters=[], + backfill_values={"x3": 0.0}, + ) + self.assertTrue( + any( + ("Backfill values provided for parameters not being added") in msg + for msg in lg.output + ) + ) + + # Successfully adds parameter + client.add_parameters( + parameters=[ + RangeParameterConfig( + name="x3", + parameter_type="float", + bounds=(-1, 1), + ) + ], + backfill_values={"x3": 0.0}, + status_quo_values={"x3": 0.0}, + ) + + # Run one more trial + client.run_trials(1) + self.assertEqual( + client._experiment.trials[2].arms[0].parameters.keys(), {"x1", "x2", "x3"} + ) + + def test_disable_parameters(self) 
-> None: + client = Client() + + client.configure_experiment( + parameters=[ + RangeParameterConfig(name="x1", parameter_type="float", bounds=(-1, 1)), + ChoiceParameterConfig( + name="x2", parameter_type="str", values=["value_a", "value_b"] + ), + ], + name="test_exp", + ) + client.configure_optimization(objective="foo") + client.configure_metrics(metrics=[DummyMetric(name="foo")]) + client._set_runner(DummyRunner()) + client.attach_baseline({"x1": 0.0, "x2": "value_a"}) + + # Run a trial + client.run_trials(1) + + # Successfully disables parameter + client.disable_parameters( + default_parameter_values={"x2": "value_b"}, + ) + + # Run one more trial + client.run_trials(1) + self.assertEqual( + client._experiment.trials[2].arms[0].parameters["x2"], + "value_b", + ) + class DummyRunner(IRunner): @override
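
For reference, the fixed-features merge that patch 2 adds to
`GenerationNode._determine_fixed_features_from_node` reduces to the following
simplified sketch (`experiment` and the user-passed `passed` features are
assumed to be in scope; names mirror the diff above):

    from ax.core.observation import ObservationFeatures

    # Collect the default for every disabled parameter in the search space.
    disabled = {
        name: p.default_value
        for name, p in experiment.search_space.parameters.items()
        if p.is_disabled
    }
    if disabled:
        if passed is None:
            fixed = ObservationFeatures(parameters=disabled)
        else:
            # User-supplied fixed features take precedence over the defaults.
            fixed = passed.clone(
                replace_parameters={**disabled, **passed.parameters}
            )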