Skip to content

Commit bad7154

Browse files
Soft-deprecate the `combine_with_last_data` arg, defaulting it to `not overwrite_existing_data`
Summary: Diff contents:
* Soft-deprecates the `combine_with_last_data` arg in the `Experiment.attach_data` signature, docs, and logic.
* For all passthrough callsites, updates the signature and docs.
* Removes all direct specification of `combine_with_last_data=True/False`, ensuring that the same behavior remains after the change (i.e., `overwrite_existing_data is not combine_with_last_data`).

Reviewed By: saitcakmak

Differential Revision: D75696517
1 parent 50540a6 commit bad7154

File tree

10 files changed

+51
-40
lines changed

10 files changed

+51
-40
lines changed

ax/adapter/tests/test_base_adapter.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,10 +646,7 @@ def test_set_status_quo_with_multiple_observations(self) -> None:
646646
if additional_fetch:
647647
# Fetch constraint metric an additional time. This will lead to two
648648
# separate observations for the status quo arm.
649-
exp.fetch_data(
650-
metrics=[exp.metrics["branin_map_constraint"]],
651-
combine_with_last_data=True,
652-
)
649+
exp.fetch_data(metrics=[exp.metrics["branin_map_constraint"]])
653650
with self.assertNoLogs(logger=logger, level="WARN"), mock.patch(
654651
"ax.adapter.base._combine_multiple_status_quo_observations",
655652
wraps=_combine_multiple_status_quo_observations,

ax/api/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def attach_data(
495495

496496
trial = assert_is_instance(self._experiment.trials[trial_index], Trial)
497497
trial.update_trial_data(
498-
raw_data=data_with_progression, combine_with_last_data=True
498+
raw_data=data_with_progression,
499499
)
500500

501501
self._save_or_update_trial_in_db_if_possible(

ax/core/experiment.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def get_metrics(self, metric_names: list[str] | None) -> list[Metric]:
555555
def fetch_data_results(
556556
self,
557557
metrics: list[Metric] | None = None,
558-
combine_with_last_data: bool = False,
558+
combine_with_last_data: bool | None = None,
559559
overwrite_existing_data: bool = False,
560560
**kwargs: Any,
561561
) -> dict[int, dict[str, MetricFetchResult]]:
@@ -593,7 +593,7 @@ def fetch_trials_data_results(
593593
self,
594594
trial_indices: Iterable[int],
595595
metrics: list[Metric] | None = None,
596-
combine_with_last_data: bool = False,
596+
combine_with_last_data: bool | None = None,
597597
overwrite_existing_data: bool = False,
598598
**kwargs: Any,
599599
) -> dict[int, dict[str, MetricFetchResult]]:
@@ -629,7 +629,7 @@ def fetch_data(
629629
self,
630630
trial_indices: Iterable[int] | None = None,
631631
metrics: list[Metric] | None = None,
632-
combine_with_last_data: bool = False,
632+
combine_with_last_data: bool | None = None,
633633
overwrite_existing_data: bool = False,
634634
**kwargs: Any,
635635
) -> Data:
@@ -676,7 +676,7 @@ def _lookup_or_fetch_trials_results(
676676
self,
677677
trials: list[BaseTrial],
678678
metrics: Iterable[Metric] | None = None,
679-
combine_with_last_data: bool = False,
679+
combine_with_last_data: bool | None = None,
680680
overwrite_existing_data: bool = False,
681681
**kwargs: Any,
682682
) -> dict[int, dict[str, MetricFetchResult]]:
@@ -755,23 +755,27 @@ def _fetch_trial_data(
755755
def attach_data(
756756
self,
757757
data: Data,
758-
combine_with_last_data: bool = False,
758+
# TODO[bbeckerman]: Deprecate this argument.
759+
combine_with_last_data: bool | None = None,
759760
overwrite_existing_data: bool = False,
760761
) -> int:
761762
"""Attach data to experiment. Stores data in `experiment._data_by_trial`,
762763
to be looked up via `experiment.lookup_data_for_trial`.
763764
764765
Args:
765766
data: Data object to store.
766-
combine_with_last_data: By default, when attaching data, it's identified
767+
combine_with_last_data [DEPRECATED]: This argument will be removed in Ax
768+
1.3.0. Please leave this as ``None``, in which case it will be assigned
769+
as ``not overwrite_existing_data``.
770+
By default, when attaching data, it's identified
767771
by its timestamp, and `experiment.lookup_data_for_trial` returns
768772
data by most recent timestamp. Sometimes, however, we want to combine
769773
the data from multiple calls to `attach_data` into one dataframe.
770774
This might be because:
771775
- We attached data for some metrics at one point and data for
772776
the rest of the metrics later on.
773777
- We attached data for some fidelity at one point and data for
774-
another fidelity later one.
778+
another fidelity later on.
775779
To achieve that goal, set `combine_with_last_data` to `True`.
776780
In this case, we will take the most recent previously attached
777781
data, append the newly attached data to it, attach a new
@@ -785,10 +789,34 @@ def attach_data(
785789
the incoming data contains all the information we need for a given
786790
trial, we can replace the existing data for that trial, thereby
787791
reducing the amount we need to store in the database.
792+
If set to False, we will combine the data in the present call with that
793+
from prior calls to `attach_data`, into one dataframe. Reasons to do
794+
this may include:
795+
- We attached data for some metrics at one point and data for
796+
the rest of the metrics later on.
797+
- We attached data for some fidelity at one point and data for
798+
another fidelity later on.
799+
In this case, we will take the most recent previously attached
800+
data, append the newly attached data to it, and attach a new
801+
Data object with the merged result and delete the old one.
802+
Afterwards, calls to `lookup_data_for_trial` will return this
803+
new combined data object. This operation will also validate that the
804+
newly added data does not contain observations for metrics that
805+
already have observations at the same fidelity in the most recent data.
806+
788807
789808
Returns:
790809
Timestamp of storage in millis.
791810
"""
811+
if combine_with_last_data is None:
812+
combine_with_last_data = not overwrite_existing_data
813+
else:
814+
# logger.warning(
815+
raise DeprecationWarning(
816+
"The `combine_with_last_data` argument is deprecated and will be "
817+
"removed soon. Please leave this as None, in which case it will be "
818+
"assigned as `not overwrite_existing_data`."
819+
)
792820
if combine_with_last_data and overwrite_existing_data:
793821
raise UnsupportedError(
794822
"Cannot set both combine_with_last_data=True and "
@@ -921,7 +949,7 @@ def _get_last_data_without_similar_rows(
921949
def attach_fetch_results(
922950
self,
923951
results: Mapping[int, Mapping[str, MetricFetchResult]],
924-
combine_with_last_data: bool = False,
952+
combine_with_last_data: bool | None = None,
925953
overwrite_existing_data: bool = False,
926954
) -> int | None:
927955
"""

ax/core/multi_type_experiment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def fetch_data(
252252
self,
253253
trial_indices: Iterable[int] | None = None,
254254
metrics: list[Metric] | None = None,
255-
combine_with_last_data: bool = False,
255+
combine_with_last_data: bool | None = None,
256256
overwrite_existing_data: bool = False,
257257
**kwargs: Any,
258258
) -> Data:
@@ -262,7 +262,12 @@ def fetch_data(
262262
return self.default_data_constructor.from_multiple_data(
263263
[
264264
(
265-
trial.fetch_data(**kwargs, metrics=metrics)
265+
trial.fetch_data(
266+
**kwargs,
267+
metrics=metrics,
268+
overwrite_existing_data=overwrite_existing_data,
269+
combine_with_last_data=combine_with_last_data,
270+
)
266271
if trial.status.expecting_data
267272
else Data()
268273
)

ax/core/tests/test_experiment.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def test_FetchAndStoreData(self) -> None:
520520
]
521521
)
522522
)
523-
t3 = exp.attach_data(new_data, combine_with_last_data=True)
523+
t3 = exp.attach_data(new_data)
524524
# still 6 data objs, since we combined last one
525525
self.assertEqual(len(full_dict[0]), 6)
526526
self.assertIn("z", exp.lookup_data_for_ts(t3).df["metric_name"].tolist())
@@ -557,12 +557,6 @@ def test_OverwriteExistingData(self) -> None:
557557
# automatically attaches data
558558
data = exp.fetch_data()
559559

560-
# can't set both combine_with_last_data and overwrite_existing_data
561-
with self.assertRaises(UnsupportedError):
562-
exp.attach_data(
563-
data, combine_with_last_data=True, overwrite_existing_data=True
564-
)
565-
566560
# data exists for two trials
567561
# data has been attached once for each trial
568562
self.assertEqual(len(exp._data_by_trial), 2)

ax/core/trial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def update_trial_data(
280280
raw_data: TEvaluationOutcome,
281281
metadata: dict[str, str | int] | None = None,
282282
sample_size: int | None = None,
283-
combine_with_last_data: bool = False,
283+
combine_with_last_data: bool | None = None,
284284
) -> str:
285285
"""Utility method that attaches data to a trial and
286286
returns an update message.
@@ -295,8 +295,8 @@ def update_trial_data(
295295
metadata: Additional metadata to track about this run, optional.
296296
sample_size: Number of samples collected for the underlying arm,
297297
optional.
298-
combine_with_last_data: Whether to combine the given data with the
299-
data that was previously attached to the trial. See
298+
combine_with_last_data [DEPRECATED]: Whether to combine the given data
299+
with the data that was previously attached to the trial. See
300300
`Experiment.attach_data` for a detailed explanation.
301301
302302
Returns:

ax/service/ax_client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,6 @@ def update_running_trial_with_intermediate_data(
743743
raw_data=raw_data,
744744
metadata=metadata,
745745
sample_size=sample_size,
746-
combine_with_last_data=True,
747746
)
748747
logger.info(f"Updated trial {trial_index} with data: " f"{data_update_repr}.")
749748

@@ -794,7 +793,6 @@ def complete_trial(
794793
metadata=metadata,
795794
sample_size=sample_size,
796795
complete_trial=True,
797-
combine_with_last_data=True,
798796
)
799797
logger.info(f"Completed trial {trial_index} with data: " f"{data_update_repr}.")
800798

@@ -839,7 +837,6 @@ def update_trial_data(
839837
raw_data=raw_data,
840838
metadata=metadata,
841839
sample_size=sample_size,
842-
combine_with_last_data=True,
843840
)
844841
logger.info(f"Added data: {data_update_repr} to trial {trial.index}.")
845842

@@ -1558,7 +1555,7 @@ def _update_trial_with_raw_data(
15581555
metadata: dict[str, str | int] | None = None,
15591556
sample_size: int | None = None,
15601557
complete_trial: bool = False,
1561-
combine_with_last_data: bool = False,
1558+
combine_with_last_data: bool | None = None,
15621559
) -> str:
15631560
"""Helper method attaches data to a trial, returns a str of update."""
15641561
# Format the data to save.

ax/service/orchestrator.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,11 +1890,6 @@ def _fetch_and_process_trials_data_results(
18901890
kwargs = deepcopy(self.options.fetch_kwargs)
18911891
for k, v in self.DEFAULT_FETCH_KWARGS.items():
18921892
kwargs.setdefault(k, v)
1893-
if kwargs.get("overwrite_existing_data") and kwargs.get(
1894-
"combine_with_last_data"
1895-
):
1896-
# to avoid error https://fburl.com/code/ilix4okj
1897-
kwargs["overwrite_existing_data"] = False
18981893
if self.trial_type is not None:
18991894
metrics = assert_is_instance(
19001895
self.experiment, MultiTypeExperiment

ax/service/tests/test_orchestrator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,9 +2072,6 @@ def test_it_does_not_overwrite_data_with_combine_fetch_kwarg(self) -> None:
20722072
experiment=self.branin_experiment, # Has runner and metrics.
20732073
generation_strategy=gs,
20742074
options=OrchestratorOptions(
2075-
fetch_kwargs={
2076-
"combine_with_last_data": True,
2077-
},
20782075
**self.orchestrator_options_kwargs,
20792076
),
20802077
db_settings=self.db_settings_if_always_needed,

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -877,9 +877,7 @@ def test_ExperimentUpdateTrial(self) -> None:
877877
self.assertEqual(self.experiment, loaded_experiment)
878878

879879
# Update a trial by attaching data again
880-
self.experiment.attach_data(
881-
get_data(trial_index=trial.index), combine_with_last_data=True
882-
)
880+
self.experiment.attach_data(get_data(trial_index=trial.index))
883881
save_or_update_trial(experiment=self.experiment, trial=trial)
884882

885883
loaded_experiment = load_experiment(self.experiment.name)

0 commit comments

Comments
 (0)