From a5e764f986175591ad3f91f98a1d775514168647 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 4 Sep 2025 06:20:15 -0700 Subject: [PATCH] Make model based Derelativize.transform_optimization_config no-op if a fitted model is not yet available Summary: In order to use model predicted status quo values, we first need to have a fitted model. This prevents transforming optimization config in `Adapter._transform_data` and using it when instantiating the subsequent transforms, which we want to do to avoid the need to separately derelativize the opt config in transforms like `Winsorize` & `BilogY` (WIP D81597506). This diff changes the behavior of `Derelativize.transform_optimization_config` to be no-op rather than error out when a fitted model is not yet available. This will leave the `optimization_config` unchanged in `Adapter._transform_data` and untransform it using the model predictions in `Adapter.gen`. Reviewed By: dme65 Differential Revision: D81624306 --- ax/adapter/transforms/derelativize.py | 31 ++++++++++---- .../tests/test_derelativize_transform.py | 41 +++++++++++++++++++ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/ax/adapter/transforms/derelativize.py b/ax/adapter/transforms/derelativize.py index 8fa3892706b..caf8f058ece 100644 --- a/ax/adapter/transforms/derelativize.py +++ b/ax/adapter/transforms/derelativize.py @@ -34,13 +34,15 @@ class Derelativize(Transform): """Changes relative constraints to not-relative constraints using a plug-in estimate of the status quo value. - If status quo is in-design, uses model estimate at status quo. If not, uses - raw observation at status quo. + By default, if status quo is in-design, uses model estimate at status quo. + If not, uses raw observation at status quo. If flag "use_raw_status_quo" is + set to True in the transform config, will always use raw observation. - Will raise an error if status quo is in-design and model fails to predict - for it, unless the flag "use_raw_status_quo" is set to True in the - transform config, in which case it will fall back to using the observed - value in the training data. + If using model estimate and a fitted model is not yet available, the + transform will be no-op. This can happen when transforming the data before + fitting the model. During candidate generation, the fitted model will be + available and the predictions can be used to derelativize the optimization + config before using it for candidate generation. Transform is done in-place. """ @@ -78,9 +80,20 @@ def transform_optimization_config( ): # Only use model predictions if the status quo is in the search space # (including parameter constraints) and `use_raw_sq` is false. - f, cov = adapter.predict( - observation_features=[sq.features], use_posterior_predictive=True - ) + try: + f, cov = adapter.predict( + observation_features=[sq.features], use_posterior_predictive=True + ) + except ValueError as e: + if "Generator must be fit" in str(e): + logger.debug( + "Returning optimization config as is since the fitted model " + "is not yet available." + ) + return optimization_config + else: + raise e + # Warn if the raw SQ values are outside of the CI for the predictions. _warn_if_raw_sq_is_out_of_CI( optimization_config=optimization_config, diff --git a/ax/adapter/transforms/tests/test_derelativize_transform.py b/ax/adapter/transforms/tests/test_derelativize_transform.py index 417303a84b5..59bdeb24a41 100644 --- a/ax/adapter/transforms/tests/test_derelativize_transform.py +++ b/ax/adapter/transforms/tests/test_derelativize_transform.py @@ -403,3 +403,44 @@ def test_warning_if_raw_sq_is_out_of_CI(self) -> None: ), self.assertLogs(logger=logger) as mock_logs: adapter.gen(n=1) self.assertIn("deviate more than", mock_logs.records[0].getMessage()) + + def test_without_fitted_model(self) -> None: + # This test checks that the transform is no-op without a fitted model. + exp = get_branin_experiment( + with_status_quo=True, + with_completed_trial=True, + with_relative_constraint=True, + ) + # Add data for SQ. + trial = exp.new_trial().add_arm(exp.status_quo).run().mark_completed() + exp.attach_data(get_branin_data(trials=[trial], metrics=exp.metrics)) + + adapter = TorchAdapter( + experiment=exp, + generator=BoTorchGenerator(), + transforms=[Derelativize], + fit_on_init=False, + ) + t = Derelativize(search_space=exp.search_space, adapter=adapter) + opt_config = none_throws(exp.optimization_config) + with self.assertLogs(logger=logger, level="DEBUG") as mock_logs: + self.assertEqual( + t.transform_optimization_config( + optimization_config=opt_config.clone(), adapter=adapter + ), + opt_config, + ) + self.assertIn( + "fitted model is not yet available", mock_logs.records[0].getMessage() + ) + # Check that it is transformed with a fitted model. + adapter = TorchAdapter( + experiment=exp, generator=BoTorchGenerator(), transforms=[Derelativize] + ) + t = Derelativize(search_space=exp.search_space, adapter=adapter) + self.assertNotEqual( + t.transform_optimization_config( + optimization_config=opt_config.clone(), adapter=adapter + ), + opt_config, + )