Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions ax/adapter/transforms/derelativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ class Derelativize(Transform):
"""Changes relative constraints to not-relative constraints using a plug-in
estimate of the status quo value.

If status quo is in-design, uses model estimate at status quo. If not, uses
raw observation at status quo.
By default, if status quo is in-design, uses model estimate at status quo.
If not, uses raw observation at status quo. If flag "use_raw_status_quo" is
set to True in the transform config, will always use raw observation.

Will raise an error if status quo is in-design and model fails to predict
for it, unless the flag "use_raw_status_quo" is set to True in the
transform config, in which case it will fall back to using the observed
value in the training data.
If using model estimate and a fitted model is not yet available, the
transform will be no-op. This can happen when transforming the data before
fitting the model. During candidate generation, the fitted model will be
available and the predictions can be used to derelativize the optimization
config before using it for candidate generation.

Transform is done in-place.
"""
Expand Down Expand Up @@ -78,9 +80,20 @@ def transform_optimization_config(
):
# Only use model predictions if the status quo is in the search space
# (including parameter constraints) and `use_raw_sq` is false.
f, cov = adapter.predict(
observation_features=[sq.features], use_posterior_predictive=True
)
try:
f, cov = adapter.predict(
observation_features=[sq.features], use_posterior_predictive=True
)
except ValueError as e:
if "Generator must be fit" in str(e):
logger.debug(
"Returning optimization config as is since the fitted model "
"is not yet available."
)
return optimization_config
else:
raise e

# Warn if the raw SQ values are outside of the CI for the predictions.
_warn_if_raw_sq_is_out_of_CI(
optimization_config=optimization_config,
Expand Down
41 changes: 41 additions & 0 deletions ax/adapter/transforms/tests/test_derelativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,44 @@ def test_warning_if_raw_sq_is_out_of_CI(self) -> None:
), self.assertLogs(logger=logger) as mock_logs:
adapter.gen(n=1)
self.assertIn("deviate more than", mock_logs.records[0].getMessage())

def test_without_fitted_model(self) -> None:
# This test checks that the transform is no-op without a fitted model.
exp = get_branin_experiment(
with_status_quo=True,
with_completed_trial=True,
with_relative_constraint=True,
)
# Add data for SQ.
trial = exp.new_trial().add_arm(exp.status_quo).run().mark_completed()
exp.attach_data(get_branin_data(trials=[trial], metrics=exp.metrics))

adapter = TorchAdapter(
experiment=exp,
generator=BoTorchGenerator(),
transforms=[Derelativize],
fit_on_init=False,
)
t = Derelativize(search_space=exp.search_space, adapter=adapter)
opt_config = none_throws(exp.optimization_config)
with self.assertLogs(logger=logger, level="DEBUG") as mock_logs:
self.assertEqual(
t.transform_optimization_config(
optimization_config=opt_config.clone(), adapter=adapter
),
opt_config,
)
self.assertIn(
"fitted model is not yet available", mock_logs.records[0].getMessage()
)
# Check that it is transformed with a fitted model.
adapter = TorchAdapter(
experiment=exp, generator=BoTorchGenerator(), transforms=[Derelativize]
)
t = Derelativize(search_space=exp.search_space, adapter=adapter)
self.assertNotEqual(
t.transform_optimization_config(
optimization_config=opt_config.clone(), adapter=adapter
),
opt_config,
)