Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions ax/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from logging import Logger
from typing import Any
from typing import Any, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -43,6 +43,7 @@
TModelMean,
TModelPredict,
TParameterization,
TParamValue,
)
from ax.core.utils import extract_map_keys_from_opt_config, get_target_trial_index
from ax.exceptions.core import UnsupportedError, UserInputError
Expand Down Expand Up @@ -98,6 +99,16 @@ class Adapter:
specification.
"""

# pyre-ignore [13] Assigned in _set_and_filter_training_data.
_training_data: ExperimentData

# The space used for optimization.
_search_space: SearchSpace

# The space used for modeling. Might be larger than the optimization
# space to cover training data.
_model_space: SearchSpace

def __init__(
self,
*,
Expand Down Expand Up @@ -184,17 +195,17 @@ def __init__(
t_fit_start = time.monotonic()
transforms = transforms or []
transforms = [Cast] + list(transforms)
transform_configs = {} if transform_configs is None else transform_configs
if "FillMissingParameters" in transform_configs:
self._transform_configs: Mapping[str, TConfig] = (
{} if transform_configs is None else {**transform_configs}
)
if "FillMissingParameters" in self._transform_configs:
transforms = [FillMissingParameters] + transforms
self._raw_transforms = transforms
self._transform_configs: Mapping[str, TConfig] = transform_configs
self._raw_transforms: list[type[Transform]] = transforms
self._set_search_space(search_space or experiment.search_space)

self.fit_time: float = 0.0
self.fit_time_since_gen: float = 0.0
self._metric_names: set[str] = set()
# pyre-ignore [13] Assigned in _set_and_filter_training_data.
self._training_data: ExperimentData
self._optimization_config: OptimizationConfig | None = optimization_config
self._training_in_design_idx: list[bool] = []
self._status_quo: Observation | None = None
Expand All @@ -203,12 +214,6 @@ def __init__(
self._model_key: str | None = None
self._model_kwargs: dict[str, Any] | None = None
self._bridge_kwargs: dict[str, Any] | None = None
# The space used for optimization.
search_space = search_space or experiment.search_space
self._search_space: SearchSpace = search_space.clone()
# The space used for modeling. Might be larger than the optimization
# space to cover training data.
self._model_space: SearchSpace = search_space.clone()
self._fit_tracking_metrics = fit_tracking_metrics
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
Expand Down Expand Up @@ -302,6 +307,7 @@ def _process_and_transform_data(
) -> tuple[ExperimentData, SearchSpace]:
r"""Processes the data into ``ExperimentData`` and returns the transformed
``ExperimentData`` and the search space. This packages the following methods:
* self._set_search_space
* self._set_and_filter_training_data
* self._set_status_quo
* self._transform_data
Expand All @@ -311,6 +317,10 @@ def _process_and_transform_data(
data_loader_config=self._data_loader_config,
data=data,
)
# If the search space has changed, we need to update the model space
if self._search_space != experiment.search_space:
self._set_search_space(experiment.search_space)
self._set_model_space(arm_data=experiment_data.arm_data)
experiment_data = self._set_and_filter_training_data(
experiment_data=experiment_data, search_space=self._model_space
)
Expand Down Expand Up @@ -352,6 +362,38 @@ def _transform_data(
self.transforms[t.__name__] = t_instance
return experiment_data, search_space

def _set_search_space(
    self,
    search_space: SearchSpace,
) -> None:
    """Set the optimization and model search spaces.

    Both ``self._search_space`` (the space used for optimization) and
    ``self._model_space`` (the space used for modeling, which may later be
    expanded to cover training data) are set to clones of ``search_space``.

    If the search space contains parameters with backfill values, a
    ``FillMissingParameters`` transform is registered (if not already in
    ``self._raw_transforms``) and its ``fill_values`` config is populated
    with those backfill values. Fill values already present in the
    user-provided transform config take precedence over backfill values.

    Args:
        search_space: The search space to use for optimization and modeling.
    """
    self._search_space = search_space.clone()
    self._model_space = search_space.clone()
    # Add FillMissingParameters transform if search space has parameters
    # with backfill values.
    backfill_values = search_space.backfill_values()
    if len(backfill_values) > 0:
        existing_config = self._transform_configs.get("FillMissingParameters", {})
        current_fill_values = cast(
            Mapping[str, TParamValue],
            existing_config.get("fill_values", {}),
        )
        # Build a merged config rather than mutating `existing_config` in
        # place: `existing_config` is the nested dict held inside
        # `self._transform_configs` (and, via the shallow copy made in
        # `__init__`, possibly the caller's own config dict), so in-place
        # mutation would leak this change to the caller.
        # Fill values already configured override the backfill values.
        merged_config = {
            **existing_config,
            "fill_values": {**backfill_values, **current_fill_values},
        }
        self._transform_configs = {
            **self._transform_configs,
            "FillMissingParameters": merged_config,
        }
        # Add FillMissingParameters transform if not already present.
        if FillMissingParameters not in self._raw_transforms:
            self._raw_transforms = [FillMissingParameters] + self._raw_transforms

def _set_and_filter_training_data(
self, experiment_data: ExperimentData, search_space: SearchSpace
) -> ExperimentData:
Expand Down
28 changes: 27 additions & 1 deletion ax/adapter/tests/test_base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None:
# Fit model without filling missing parameters
m = Adapter(experiment=experiment, generator=Generator())
self.assertEqual(
[t.__name__ for t in m._raw_transforms], # pyre-ignore[16]
[t.__name__ for t in m._raw_transforms],
["Cast"],
)
# Check that SQ and all trial 1 are OOD
Expand Down Expand Up @@ -1111,3 +1111,29 @@ def test_get_training_data(self) -> None:
in_design_training_data.observation_data,
training_data.observation_data.iloc[[0, 2]],
)

def test_added_parameters(self) -> None:
    """Adapter picks up parameters added to the experiment's search space."""
    experiment = get_branin_experiment()
    adapter = Adapter(experiment=experiment, generator=Generator())
    experiment_data, search_space = adapter._process_and_transform_data(
        experiment=experiment
    )
    self.assertEqual(search_space, experiment.search_space)
    self.assertListEqual(
        list(experiment_data.arm_data.columns), ["x1", "x2", "metadata"]
    )
    # Add new parameter
    new_parameter = RangeParameter(
        name="x3",
        parameter_type=ParameterType.FLOAT,
        lower=0.0,
        upper=1.0,
        backfill_value=0.5,
    )
    experiment.add_parameters_to_search_space([new_parameter])
    self.assertNotEqual(experiment.search_space, adapter._search_space)
    # First call syncs the adapter's search space to the experiment's;
    # the second returns data against the updated space.
    adapter._process_and_transform_data(experiment=experiment)
    experiment_data, search_space = adapter._process_and_transform_data(
        experiment=experiment
    )
    self.assertEqual(search_space, experiment.search_space)
    self.assertListEqual(
        list(experiment_data.arm_data.columns), ["x1", "x2", "x3", "metadata"]
    )
78 changes: 77 additions & 1 deletion ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ax.api.protocols.runner import IRunner
from ax.api.types import TOutcome, TParameterization
from ax.api.utils.generation_strategy_dispatch import choose_generation_strategy
from ax.api.utils.instantiation.from_config import parameter_from_config
from ax.api.utils.instantiation.from_string import optimization_config_from_string
from ax.api.utils.instantiation.from_struct import experiment_from_struct
from ax.api.utils.storage import db_settings_from_storage_config
Expand All @@ -38,7 +39,7 @@
BaseEarlyStoppingStrategy,
PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.orchestrator import Orchestrator, OrchestratorOptions
from ax.service.utils.best_point_mixin import BestPointMixin
Expand Down Expand Up @@ -203,6 +204,81 @@ def configure_generation_strategy(
)
self.set_generation_strategy(generation_strategy=generation_strategy)

def add_parameters(
    self,
    parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
    backfill_values: TParameterization,
    status_quo_values: TParameterization | None = None,
) -> None:
    """
    Add new parameters to the experiment's search space. This allows extending
    the search space after the experiment has run some trials.

    Backfill values must be provided for all new parameters to ensure existing
    trials in the experiment remain valid within the expanded search space. The
    backfill values represent the parameter values that were used in the existing
    trials.

    Args:
        parameters: A sequence of parameter configurations to add to the search
            space.
        backfill_values: Parameter values to assign to existing trials for the
            new parameters being added. All new parameter names must have
            corresponding backfill values provided.
        status_quo_values: Optional parameter values for the new parameters to
            use in the status quo (baseline) arm, if one is defined. If None,
            the backfill values will be used for the status quo.

    Raises:
        UserInputError: If any new parameter is missing a backfill value.
    """
    parameters_to_add = [
        parameter_from_config(parameter_config) for parameter_config in parameters
    ]
    parameter_names = {parameter.name for parameter in parameters_to_add}
    # Every new parameter must have a backfill value so existing trials
    # remain in-design after the search space is expanded.
    missing_backfill_values = parameter_names - backfill_values.keys()
    if missing_backfill_values:
        raise UserInputError(
            "You must provide backfill values for all parameters being added. "
            f"Missing values for parameters: {missing_backfill_values}."
        )
    # Extra backfill values are harmless but likely indicate a user mistake,
    # so warn rather than raise. (Typo fix: "ingore" -> "ignore".)
    extra_backfill_values = backfill_values.keys() - parameter_names
    if extra_backfill_values:
        logger.warning(
            "Backfill values provided for parameters not being added: "
            f"{extra_backfill_values}. Will ignore these values."
        )
    for parameter in parameters_to_add:
        if parameter.name in backfill_values:
            parameter._backfill_value = backfill_values[parameter.name]
    self._experiment.add_parameters_to_search_space(
        parameters=parameters_to_add,
        # pyre-fixme[6]: Type narrowing broken because core Ax
        # TParameterization is dict not Mapping
        status_quo_values=status_quo_values,
    )

def disable_parameters(
    self,
    default_parameter_values: TParameterization,
) -> None:
    """
    Disable parameters in the experiment. This allows narrowing the search space
    after the experiment has run some trials.

    When parameters are disabled, they are effectively removed from the search
    space for future trial generation. Existing trials remain valid, and the
    disabled parameters are replaced with fixed default values for all subsequent
    trials.

    Args:
        default_parameter_values: Fixed values to use for the disabled parameters
            in all future trials. These values will be used for the parameter in
            all subsequent trials.
    """
    # Delegate to the core experiment, which performs the actual narrowing.
    self._experiment.disable_parameters_in_search_space(
        # pyre-fixme[6]: Type narrowing broken because core Ax
        # TParameterization is dict not Mapping
        default_parameter_values=default_parameter_values
    )

# -------------------- Section 1.1: Configure Automation ------------------------
def configure_runner(self, runner: IRunner) -> None:
"""
Expand Down
97 changes: 96 additions & 1 deletion ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.service.utils.with_db_settings_base import (
_save_generation_strategy_to_db_if_possible,
)
Expand Down Expand Up @@ -1461,6 +1461,101 @@ def test_overwrite_metric(self) -> None:
)
self.assertIn(qux_metric_scalar, scalar.metrics)

def test_add_parameters(self) -> None:
    """Client.add_parameters validates backfill values and expands the space."""
    client = Client()

    client.configure_experiment(
        parameters=[
            RangeParameterConfig(name="x1", parameter_type="float", bounds=(-1, 1)),
            RangeParameterConfig(name="x2", parameter_type="float", bounds=(-1, 1)),
        ],
        name="test_exp",
    )
    client.configure_optimization(objective="foo")
    client.configure_metrics(metrics=[DummyMetric(name="foo")])
    client._set_runner(DummyRunner())
    client.attach_baseline({"x1": 0.0, "x2": 0.0})

    # Run a trial
    client.run_trials(1)

    # Can't add parameter without a backfill value
    new_param_config = RangeParameterConfig(
        name="x3",
        parameter_type="float",
        bounds=(-1, 1),
    )
    with self.assertRaises(UserInputError):
        client.add_parameters(
            parameters=[new_param_config],
            backfill_values={},
        )

    # Ignores extra backfill values
    with self.assertLogs(logger="ax.api.client", level="WARNING") as lg:
        client.add_parameters(parameters=[], backfill_values={"x3": 0.0})
    matching = [
        msg
        for msg in lg.output
        if "Backfill values provided for parameters not being added" in msg
    ]
    self.assertTrue(len(matching) > 0)

    # Successfully adds parameter
    client.add_parameters(
        parameters=[
            RangeParameterConfig(
                name="x3",
                parameter_type="float",
                bounds=(-1, 1),
            )
        ],
        backfill_values={"x3": 0.0},
        status_quo_values={"x3": 0.0},
    )

    # Run one more trial
    client.run_trials(1)
    generated_arm = client._experiment.trials[2].arms[0]
    self.assertEqual(generated_arm.parameters.keys(), {"x1", "x2", "x3"})

def test_disable_parameters(self) -> None:
    """Disabled parameters take the fixed default value in new trials."""
    client = Client()

    client.configure_experiment(
        parameters=[
            RangeParameterConfig(name="x1", parameter_type="float", bounds=(-1, 1)),
            ChoiceParameterConfig(
                name="x2", parameter_type="str", values=["value_a", "value_b"]
            ),
        ],
        name="test_exp",
    )
    client.configure_optimization(objective="foo")
    client.configure_metrics(metrics=[DummyMetric(name="foo")])
    client._set_runner(DummyRunner())
    client.attach_baseline({"x1": 0.0, "x2": "value_a"})

    # Run a trial
    client.run_trials(1)

    # Successfully disables parameter
    client.disable_parameters(default_parameter_values={"x2": "value_b"})

    # Run one more trial
    client.run_trials(1)
    newest_arm = client._experiment.trials[2].arms[0]
    self.assertEqual(newest_arm.parameters["x2"], "value_b")


class DummyRunner(IRunner):
@override
Expand Down
Loading