Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions ax/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from logging import Logger
from typing import Any
from typing import Any, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -44,6 +44,7 @@
TModelMean,
TModelPredict,
TParameterization,
TParamValue,
)
from ax.core.utils import get_target_trial_index, has_map_metrics
from ax.exceptions.core import UnsupportedError, UserInputError
Expand Down Expand Up @@ -99,6 +100,16 @@ class Adapter:
specification.
"""

# pyre-ignore [13] Assigned in _set_and_filter_training_data.
_training_data: ExperimentData

# The space used for optimization.
_search_space: SearchSpace

# The space used for modeling. Might be larger than the optimization
# space to cover training data.
_model_space: SearchSpace

def __init__(
self,
*,
Expand Down Expand Up @@ -185,17 +196,17 @@ def __init__(
t_fit_start = time.monotonic()
transforms = transforms or []
transforms = [Cast] + list(transforms)
transform_configs = {} if transform_configs is None else transform_configs
if "FillMissingParameters" in transform_configs:
self._transform_configs: Mapping[str, TConfig] = (
{} if transform_configs is None else {**transform_configs}
)
if "FillMissingParameters" in self._transform_configs:
transforms = [FillMissingParameters] + transforms
self._raw_transforms = transforms
self._transform_configs: Mapping[str, TConfig] = transform_configs
self._raw_transforms: list[type[Transform]] = transforms
self._set_search_space(search_space or experiment.search_space)

self.fit_time: float = 0.0
self.fit_time_since_gen: float = 0.0
self._metric_signatures: set[str] = set()
# pyre-ignore [13] Assigned in _set_and_filter_training_data.
self._training_data: ExperimentData
self._optimization_config: OptimizationConfig | None = optimization_config
self._training_in_design_idx: list[bool] = []
self._status_quo: Observation | None = None
Expand All @@ -204,12 +215,6 @@ def __init__(
self._model_key: str | None = None
self._model_kwargs: dict[str, Any] | None = None
self._bridge_kwargs: dict[str, Any] | None = None
# The space used for optimization.
search_space = search_space or experiment.search_space
self._search_space: SearchSpace = search_space.clone()
# The space used for modeling. Might be larger than the optimization
# space to cover training data.
self._model_space: SearchSpace = search_space.clone()
self._fit_tracking_metrics = fit_tracking_metrics
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
Expand Down Expand Up @@ -303,6 +308,7 @@ def _process_and_transform_data(
) -> tuple[ExperimentData, SearchSpace]:
r"""Processes the data into ``ExperimentData`` and returns the transformed
``ExperimentData`` and the search space. This packages the following methods:
* self._set_search_space
* self._set_and_filter_training_data
* self._set_status_quo
* self._transform_data
Expand All @@ -312,6 +318,10 @@ def _process_and_transform_data(
data_loader_config=self._data_loader_config,
data=data,
)
# If the search space has changed, we need to update the model space
if self._search_space != experiment.search_space:
self._set_search_space(experiment.search_space)
self._set_model_space(arm_data=experiment_data.arm_data)
experiment_data = self._set_and_filter_training_data(
experiment_data=experiment_data, search_space=self._model_space
)
Expand Down Expand Up @@ -353,6 +363,38 @@ def _transform_data(
self.transforms[t.__name__] = t_instance
return experiment_data, search_space

def _set_search_space(
self,
search_space: SearchSpace,
) -> None:
"""Set search space and model space. Adds a FillMissingParameters transform for
newly added parameters."""
self._search_space = search_space.clone()
self._model_space = search_space.clone()
# Add FillMissingParameters transform if search space has parameters with
# backfill values.
backfill_values = search_space.backfill_values()
if len(backfill_values) > 0:
fill_missing_values_transform = self._transform_configs.get(
"FillMissingParameters", {}
)
current_fill_values = cast(
Mapping[str, TParamValue],
fill_missing_values_transform.get("fill_values", {}),
)
# Override backfill_values with fill values already in the transform
fill_missing_values_transform["fill_values"] = {
**backfill_values,
**current_fill_values,
}
self._transform_configs = {
**self._transform_configs,
"FillMissingParameters": fill_missing_values_transform,
}
# Add FillMissingParameters transform if not already present.
if FillMissingParameters not in self._raw_transforms:
self._raw_transforms = [FillMissingParameters] + self._raw_transforms

def _set_and_filter_training_data(
self, experiment_data: ExperimentData, search_space: SearchSpace
) -> ExperimentData:
Expand Down
28 changes: 27 additions & 1 deletion ax/adapter/tests/test_base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None:
# Fit model without filling missing parameters
m = Adapter(experiment=experiment, generator=Generator())
self.assertEqual(
[t.__name__ for t in m._raw_transforms], # pyre-ignore[16]
[t.__name__ for t in m._raw_transforms],
["Cast"],
)
# Check that SQ and all trial 1 are OOD
Expand Down Expand Up @@ -1112,3 +1112,29 @@ def test_get_training_data(self) -> None:
in_design_training_data.observation_data,
training_data.observation_data.iloc[[0, 2]],
)

def test_added_parameters(self) -> None:
    """Processing data after the experiment's search space gains a new
    parameter should refresh the adapter's search space and surface the
    new parameter as a column in the processed arm data."""
    experiment = get_branin_experiment()
    adapter = Adapter(experiment=experiment, generator=Generator())
    experiment_data, search_space = adapter._process_and_transform_data(
        experiment=experiment
    )
    self.assertEqual(search_space, experiment.search_space)
    self.assertListEqual(
        list(experiment_data.arm_data.columns), ["x1", "x2", "metadata"]
    )
    # Extend the search space with a new parameter.
    new_parameter = RangeParameter(
        name="x3",
        parameter_type=ParameterType.FLOAT,
        lower=0.0,
        upper=1.0,
        backfill_value=0.5,
    )
    experiment.add_parameters_to_search_space([new_parameter])
    # The adapter still holds the stale search space at this point.
    self.assertNotEqual(experiment.search_space, adapter._search_space)
    adapter._process_and_transform_data(experiment=experiment)
    experiment_data, search_space = adapter._process_and_transform_data(
        experiment=experiment
    )
    self.assertEqual(search_space, experiment.search_space)
    self.assertListEqual(
        list(experiment_data.arm_data.columns), ["x1", "x2", "x3", "metadata"]
    )
103 changes: 95 additions & 8 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.objective import MultiObjective
from ax.core.optimization_config import ObjectiveThreshold, OptimizationConfig
from ax.core.parameter import Parameter
from ax.core.parameter import DerivedParameter, Parameter
from ax.core.runner import Runner
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.core.trial import Trial
Expand Down Expand Up @@ -126,8 +126,7 @@ def __init__(
# pyre-fixme[13]: Attribute `_search_space` is never initialized.
self._search_space: SearchSpace
self._status_quo: Arm | None = None
# pyre-fixme[13]: Attribute `_is_test` is never initialized.
self._is_test: bool
self._is_test: bool = False

self._name = name
self.description = description
Expand All @@ -136,14 +135,12 @@ def __init__(

self._data_by_trial: dict[int, OrderedDict[int, Data]] = {}
self._experiment_type: str | None = experiment_type
# pyre-fixme[4]: Attribute must be annotated.
self._optimization_config = None
self._optimization_config: OptimizationConfig | None = None
self._tracking_metrics: dict[str, Metric] = {}
self._time_created: datetime = datetime.now()
self._trials: dict[int, BaseTrial] = {}
self._properties: dict[str, Any] = properties or {}
# pyre-fixme[4]: Attribute must be annotated.
self._default_data_type = default_data_type or DataType.DATA
self._default_data_type: DataType = default_data_type or DataType.DATA
# Used to keep track of whether any trials on the experiment
# specify a TTL. Since trials need to be checked for their TTL's
# expiration often, having this attribute helps avoid unnecessary
Expand Down Expand Up @@ -179,7 +176,7 @@ def __init__(
self.add_tracking_metrics(tracking_metrics or [])

# call setters defined below
self.search_space = search_space
self.search_space: SearchSpace = search_space
self.status_quo = status_quo
if optimization_config is not None:
self.optimization_config = optimization_config
Expand Down Expand Up @@ -277,6 +274,96 @@ def search_space(self, search_space: SearchSpace) -> None:
)
self._search_space = search_space

def add_parameters_to_search_space(
    self,
    parameters: Sequence[Parameter],
    status_quo_values: TParameterization | None = None,
) -> None:
    """
    Add new parameters to the experiment's search space. This allows extending
    the search space after the experiment has run some trials.

    Backfill values must be provided for all new parameters if the experiment has
    already run some trials. The backfill values represent the parameter values
    that were used in the existing trials.

    Args:
        parameters: A sequence of parameter configurations to add to the search
            space.
        status_quo_values: Optional parameter values for the new parameters to
            use in the status quo (baseline) arm, if one is defined. Values
            supplied for parameters that are not being added are ignored
            (with a warning).

    Raises:
        UserInputError: If trials exist and a new parameter lacks a backfill
            value, if a derived parameter is added to an experiment with
            existing trials, or if a status quo arm exists and no status quo
            value is provided for a non-disabled new parameter.
    """
    status_quo_values = status_quo_values or {}

    # Additional checks iff a trial exists
    if len(self.trials) != 0:
        if any(parameter.backfill_value is None for parameter in parameters):
            raise UserInputError(
                "Must provide backfill values for all new parameters when "
                "adding parameters to an experiment with existing trials."
            )
        if any(isinstance(parameter, DerivedParameter) for parameter in parameters):
            raise UserInputError(
                "Cannot add derived parameters to an experiment with existing "
                "trials."
            )

    # Validate status quo values
    status_quo = self._status_quo
    # NOTE: `status_quo_values` was normalized to a dict above, so an
    # `is not None` check here would always pass; check truthiness so the
    # warning only fires when the caller actually supplied values.
    if status_quo_values and status_quo is None:
        logger.warning(
            "Status quo values specified, but experiment does not have a "
            "status quo. Ignoring provided status quo values."
        )
    if status_quo is not None:
        parameter_names = {parameter.name for parameter in parameters}
        status_quo_parameters = status_quo_values.keys()
        # Disabled parameters do not require a status quo value.
        disabled_parameters = {
            parameter.name
            for parameter in self._search_space.parameters.values()
            if parameter.is_disabled
        }
        extra_status_quo_values = status_quo_parameters - parameter_names
        if extra_status_quo_values:
            logger.warning(
                "Status quo values provided for parameters "
                f"`{extra_status_quo_values}` which are not among the "
                "parameters being added to the search space. Ignoring "
                "these values."
            )
        missing_status_quo_values = (
            parameter_names - disabled_parameters - status_quo_parameters
        )
        if missing_status_quo_values:
            raise UserInputError(
                "No status quo value provided for parameters "
                f"`{missing_status_quo_values}` which are being added to "
                "the search space."
            )
        for parameter_name, value in status_quo_values.items():
            # Only apply values for the parameters being added; extra
            # values were warned about above and must actually be ignored.
            if parameter_name in parameter_names:
                status_quo._parameters[parameter_name] = value

    # Add parameters to search space
    self._search_space.add_parameters(parameters)

def disable_parameters_in_search_space(
    self,
    default_parameter_values: TParameterization,
) -> None:
    """
    Disable parameters in the experiment, narrowing the search space after
    some trials have already run.

    Disabled parameters are effectively removed from the search space for
    future trial generation: existing trials remain valid, while every
    subsequent trial uses the fixed default value supplied here in place of
    the disabled parameter.

    Args:
        default_parameter_values: Fixed values to use for the disabled
            parameters in all future trials.
    """
    # The search space owns the disabling logic; delegate directly.
    self._search_space.disable_parameters(default_parameter_values)

@property
def status_quo(self) -> Arm | None:
"""The existing arm that new arms will be compared against."""
Expand Down
Loading