Skip to content

Commit 2ca0f24

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Improve Summary Analysis by Relativize the metric results if there is a status quo to relativize against (facebook#4342)
Summary: Pull Request resolved: facebook#4342 Differential Revision: D82658357
1 parent fa14162 commit 2ca0f24

File tree

5 files changed

+153
-3
lines changed

5 files changed

+153
-3
lines changed

ax/analysis/summary.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def compute(
6363
if experiment is None:
6464
raise UserInputError("`Summary` analysis requires an `Experiment` input")
6565

66+
# Determine if we should relativize based on:
67+
# (1) experiment has metrics and (2) experiment has status quo
68+
should_relativize = (
69+
len(experiment.metrics) > 0 and experiment.status_quo is not None
70+
)
71+
6672
return self._create_analysis_card(
6773
title=(
6874
"Summary for "
@@ -73,5 +79,6 @@ def compute(
7379
trial_indices=self.trial_indices,
7480
omit_empty_columns=self.omit_empty_columns,
7581
trial_statuses=self.trial_statuses,
82+
relativize=should_relativize,
7683
),
7784
)

ax/analysis/tests/test_summary.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,64 @@ def test_default_excludes_stale_trials(self) -> None:
275275
# Verify that no trials in the output have STALE status
276276
stale_statuses = card.df[card.df["trial_status"] == "STALE"]
277277
self.assertEqual(len(stale_statuses), 0)
278+
279+
def test_metrics_relativized_with_status_quo(self) -> None:
280+
"""Test that Summary relativizes metrics by default when status quo is
281+
present."""
282+
client = Client()
283+
client.configure_experiment(
284+
name="test_experiment_relativize",
285+
parameters=[
286+
RangeParameterConfig(
287+
name="x1",
288+
parameter_type="float",
289+
bounds=(0, 1),
290+
),
291+
],
292+
)
293+
client.configure_optimization(objective="metric1")
294+
295+
# Add status quo
296+
baseline_trial_index = client.attach_baseline({"x1": 0.5})
297+
client.complete_trial(
298+
trial_index=baseline_trial_index, raw_data={"metric1": 90.0}
299+
)
300+
301+
# Get trials and complete with metric data
302+
client.get_next_trials(max_trials=2)
303+
304+
# Complete trials with different metric values
305+
client.complete_trial(
306+
trial_index=baseline_trial_index + 1, raw_data={"metric1": 100.0}
307+
)
308+
client.complete_trial(
309+
trial_index=baseline_trial_index + 2, raw_data={"metric1": 80.0}
310+
)
311+
312+
experiment = client._experiment
313+
314+
# Test that Summary works and produces results
315+
# (relativization happens internally)
316+
analysis = Summary()
317+
318+
card = analysis.compute(experiment=experiment)
319+
320+
# Verify basic structure
321+
self.assertEqual(card.name, "Summary")
322+
self.assertEqual(card.title, "Summary for test_experiment_relativize")
323+
self.assertTrue("metric1" in card.df.columns)
324+
self.assertEqual(len(card.df), 3)
325+
326+
# Verify all trials are present (baseline at index 0,
327+
# regular trials at indices 1 and 2)
328+
trial_indices = set(card.df["trial_index"].values)
329+
330+
self.assertEqual(trial_indices, {0, 1, 2})
331+
332+
# Check that metric values are present (actual relativization values depend on
333+
# the underlying experiment.to_df implementation with relativize=True)
334+
# Some values might be NaN due to relativization, but not all should be NaN
335+
metric_values = card.df["metric1"].values
336+
non_na_count = sum(~pd.isna(metric_values))
337+
# At least some trials should have non-NaN metric values
338+
self.assertGreater(non_na_count, 0, "All metric values are NaN")

ax/core/data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,15 @@ def relativize(
447447
axis=1,
448448
)
449449
)
450+
if not dfs:
451+
raise ValueError(
452+
f"Relativization not possible: status quo arm '{status_quo_name}' "
453+
f"not found or dataset contains no data."
454+
)
450455
df_rel = pd.concat(dfs, axis=0)
451456
if include_sq:
457+
# Set status quo to exactly 0 mean and 0 SEM to avoid negative zero display
458+
df_rel.loc[df_rel["arm_name"] == status_quo_name, "mean"] = 0.0
452459
df_rel.loc[df_rel["arm_name"] == status_quo_name, "sem"] = 0.0
453460
return Data(df_rel)
454461

ax/core/experiment.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,7 @@ def to_df(
20392039
trial_indices: Iterable[int] | None = None,
20402040
trial_statuses: Sequence[TrialStatus] | None = None,
20412041
omit_empty_columns: bool = True,
2042+
relativize: bool = False,
20422043
) -> pd.DataFrame:
20432044
"""
20442045
High-level summary of the Experiment with one row per arm. Any values missing at
@@ -2060,10 +2061,23 @@ def to_df(
20602061
trial_indices: If specified, only include these trial indices.
20612062
omit_empty_columns: If True, omit columns where every value is None.
20622063
trial_status: If specified, only include trials with this status.
2064+
relativize: If True and experiment has a status quo, relativize metrics
20632065
"""
20642066

20652067
records = []
2066-
data_df = self.lookup_data(trial_indices=trial_indices).df
2068+
data = self.lookup_data(trial_indices=trial_indices)
2069+
2070+
# Relativize metrics if requested
2071+
data_df = (
2072+
data.relativize(
2073+
status_quo_name=none_throws(self.status_quo).name,
2074+
as_percent=True,
2075+
include_sq=True,
2076+
).df
2077+
if relativize
2078+
else data.df
2079+
)
2080+
20672081
trials = (
20682082
self.get_trials_by_indices(trial_indices=trial_indices)
20692083
if trial_indices
@@ -2123,6 +2137,7 @@ def to_df(
21232137
records.append(record)
21242138

21252139
df = pd.DataFrame(records)
2140+
21262141
if omit_empty_columns:
21272142
df = df.loc[:, df.notnull().any()]
21282143
return df

ax/core/tests/test_experiment.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
)
7676
from ax.utils.testing.mock import mock_botorch_optimize
7777
from pandas.testing import assert_frame_equal
78-
from pyre_extensions import assert_is_instance
78+
from pyre_extensions import assert_is_instance, none_throws
7979

8080
DUMMY_RUN_METADATA_KEY_1 = "test_run_metadata_key_1"
8181
DUMMY_RUN_METADATA_KEY_2 = "test_run_metadata_key_2"
@@ -471,7 +471,7 @@ def test_StatusQuoSetter(self) -> None:
471471
sq_parameters["w"] = 3.5
472472
self.experiment.status_quo = Arm(sq_parameters)
473473
self.assertEqual(self.experiment.status_quo.parameters["w"], 3.5)
474-
self.assertEqual(self.experiment.status_quo.name, "status_quo_e0")
474+
self.assertEqual(none_throws(self.experiment.status_quo).name, "status_quo_e0")
475475

476476
# Verify all None values
477477
self.experiment.status_quo = Arm({n: None for n in sq_parameters.keys()})
@@ -1640,6 +1640,66 @@ def test_to_df(self) -> None:
16401640
)
16411641
self.assertTrue(df_completed.equals(expected_completed_df))
16421642

1643+
def test_to_df_with_relativize(self) -> None:
1644+
"""Test the relativize flag in to_df method with status quo."""
1645+
# Create an experiment with status quo and completed trials
1646+
experiment = get_branin_experiment(with_status_quo=True)
1647+
1648+
# Create two completed trials
1649+
for _ in range(2):
1650+
sobol_run = get_sobol(search_space=experiment.search_space).gen(n=1)
1651+
trial = experiment.new_trial(generator_run=sobol_run)
1652+
trial.mark_running(no_runner_required=True)
1653+
trial.mark_completed()
1654+
1655+
# Fetch and add status quo data
1656+
experiment.fetch_data()
1657+
sq_data = Data(
1658+
df=pd.DataFrame(
1659+
[
1660+
{
1661+
"trial_index": i,
1662+
"arm_name": "status_quo",
1663+
"metric_name": "branin",
1664+
"metric_signature": "branin",
1665+
"mean": 10.0,
1666+
"sem": 0.1,
1667+
}
1668+
for i in range(2)
1669+
]
1670+
)
1671+
)
1672+
experiment.attach_data(sq_data)
1673+
1674+
# Test without relativization
1675+
df_no_rel = experiment.to_df(relativize=False)
1676+
1677+
# Test with relativization
1678+
df_with_rel = experiment.to_df(relativize=True)
1679+
1680+
# Basic structure should be the same
1681+
self.assertEqual(len(df_with_rel), len(df_no_rel))
1682+
self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns))
1683+
1684+
# Find metric columns and verify relativization occurred
1685+
metric_cols = [
1686+
col
1687+
for col in df_no_rel.columns
1688+
if col
1689+
not in ["trial_index", "arm_name", "trial_status", "name", "x1", "x2"]
1690+
]
1691+
1692+
if metric_cols:
1693+
metric_name = metric_cols[0]
1694+
orig_values = df_no_rel[metric_name].dropna()
1695+
rel_values = df_with_rel[metric_name].dropna()
1696+
1697+
# Values should change for non-status-quo trials
1698+
non_sq_changed = any(
1699+
abs(o - r) > 1e-10 for o, r in zip(orig_values, rel_values) if o != 10.0
1700+
)
1701+
self.assertTrue(non_sq_changed, "Relativization should change some values")
1702+
16431703

16441704
class ExperimentWithMapDataTest(TestCase):
16451705
def setUp(self) -> None:

0 commit comments

Comments
 (0)