Skip to content

Commit 8761cf7

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Improve Summary Analysis by relativizing the metric results if there is a status quo to relativize against (#4342)
Summary: Pull Request resolved: #4342 Reviewed By: mpolson64 Differential Revision: D82658357
1 parent 12990bb commit 8761cf7

File tree

5 files changed

+226
-40
lines changed

5 files changed

+226
-40
lines changed

ax/analysis/summary.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ax.analysis.analysis_card import AnalysisCard
1515
from ax.analysis.utils import validate_experiment
1616
from ax.core.experiment import Experiment
17+
from ax.core.map_data import MapData
1718
from ax.core.trial_status import NON_STALE_STATUSES, TrialStatus
1819
from ax.exceptions.core import UserInputError
1920
from ax.generation_strategy.generation_strategy import GenerationStrategy
@@ -78,15 +79,34 @@ def compute(
7879
if experiment is None:
7980
raise UserInputError("`Summary` analysis requires an `Experiment` input")
8081

82+
# Determine if we should relativize based on:
83+
# (1) experiment has metrics and (2) experiment has status quo
84+
# (3) experiment data is not MapData (MapData doesn't support relativization
85+
# due to time-series step alignment complexities.)
86+
data = experiment.lookup_data(trial_indices=self.trial_indices)
87+
should_relativize = (
88+
len(experiment.metrics) > 0
89+
and experiment.status_quo is not None
90+
and not isinstance(data, MapData)
91+
)
92+
8193
return self._create_analysis_card(
8294
title=(
8395
"Summary for "
8496
f"{experiment.name if experiment.has_name else 'Experiment'}"
8597
),
86-
subtitle="High-level summary of the `Trial`-s in this `Experiment`",
98+
subtitle=(
99+
"High-level summary of the `Trial`-s in this `Experiment`"
100+
if not should_relativize
101+
else (
102+
"High-level summary of the `Trial`-s in this `Experiment` "
103+
"Metric results are relativized against status quo."
104+
)
105+
),
87106
df=experiment.to_df(
88107
trial_indices=self.trial_indices,
89108
omit_empty_columns=self.omit_empty_columns,
90109
trial_statuses=self.trial_statuses,
110+
relativize=should_relativize,
91111
),
92112
)

ax/analysis/tests/test_summary.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,28 @@
77

88
import numpy as np
99
import pandas as pd
10+
1011
from ax.analysis.summary import Summary
1112
from ax.api.client import Client
1213
from ax.api.configs import RangeParameterConfig
1314
from ax.core.base_trial import TrialStatus
15+
from ax.core.map_data import MapData
1416
from ax.core.trial import Trial
1517
from ax.exceptions.core import UserInputError
1618
from ax.utils.common.testutils import TestCase
17-
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
19+
from ax.utils.testing.core_stubs import (
20+
get_branin_experiment_with_status_quo_trials,
21+
get_offline_experiments,
22+
get_online_experiments,
23+
)
1824
from pyre_extensions import assert_is_instance, none_throws
1925

2026

2127
class TestSummary(TestCase):
22-
def test_compute(self) -> None:
23-
client = Client()
24-
client.configure_experiment(
28+
def setUp(self) -> None:
29+
super().setUp()
30+
self.client = Client()
31+
self.client.configure_experiment(
2532
name="test_experiment",
2633
parameters=[
2734
RangeParameterConfig(
@@ -36,7 +43,10 @@ def test_compute(self) -> None:
3643
),
3744
],
3845
)
39-
client.configure_optimization(objective="foo, bar")
46+
self.client.configure_optimization(objective="foo, bar")
47+
48+
def test_compute(self) -> None:
49+
client = self.client
4050

4151
# Get two trials and fail one, giving us a ragged structure
4252
client.get_next_trials(max_trials=2)
@@ -142,23 +152,7 @@ def test_offline(self) -> None:
142152

143153
def test_trial_indices_filter(self) -> None:
144154
"""Test that Client.summarize correctly uses Summary."""
145-
client = Client()
146-
client.configure_experiment(
147-
name="test_experiment",
148-
parameters=[
149-
RangeParameterConfig(
150-
name="x1",
151-
parameter_type="float",
152-
bounds=(0, 1),
153-
),
154-
RangeParameterConfig(
155-
name="x2",
156-
parameter_type="float",
157-
bounds=(0, 1),
158-
),
159-
],
160-
)
161-
client.configure_optimization(objective="foo")
155+
client = self.client
162156

163157
# Get a trial
164158
client.get_next_trials(max_trials=1)
@@ -228,19 +222,7 @@ def test_trial_status_filter(self) -> None:
228222

229223
def test_default_excludes_stale_trials(self) -> None:
230224
"""Test that Summary defaults to excluding STALE trials."""
231-
# Set up experiment with basic configuration
232-
client = Client()
233-
client.configure_experiment(
234-
name="test_experiment",
235-
parameters=[
236-
RangeParameterConfig(
237-
name="x1",
238-
parameter_type="float",
239-
bounds=(0, 1),
240-
),
241-
],
242-
)
243-
client.configure_optimization(objective="foo")
225+
client = self.client
244226

245227
# Create 3 trials with different statuses to test default filtering behavior
246228
client.get_next_trials(max_trials=3)
@@ -275,3 +257,72 @@ def test_default_excludes_stale_trials(self) -> None:
275257
# Verify that no trials in the output have STALE status
276258
stale_statuses = card.df[card.df["trial_status"] == "STALE"]
277259
self.assertEqual(len(stale_statuses), 0)
260+
261+
def test_metrics_relativized_with_status_quo(self) -> None:
262+
"""Test that Summary relativizes metrics by default when status
263+
quos are present."""
264+
# Use helper function that creates batch trials with status quo
265+
experiment = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
266+
267+
analysis = Summary()
268+
card = analysis.compute(experiment=experiment)
269+
270+
with self.subTest("subtitle_indicates_relativization"):
271+
self.assertIn("relativized", card.subtitle.lower())
272+
273+
with self.subTest("metric_values_formatted_as_percentages"):
274+
metric_values = card.df["branin"].dropna()
275+
self.assertGreater(len(metric_values), 0)
276+
for val in metric_values:
277+
self.assertIsInstance(val, str)
278+
self.assertTrue(val.endswith("%"))
279+
280+
with self.subTest("relativization_calculation_correct"):
281+
raw_data = experiment.lookup_data().df
282+
sq_name = none_throws(experiment.status_quo).name
283+
trial_0_data = raw_data[raw_data["trial_index"] == 0]
284+
treatment_arm = [a for a in experiment.trials[0].arms if a.name != sq_name][
285+
0
286+
]
287+
288+
sq_val = trial_0_data[trial_0_data["arm_name"] == sq_name]["mean"].values[0]
289+
arm_val = trial_0_data[trial_0_data["arm_name"] == treatment_arm.name][
290+
"mean"
291+
].values[0]
292+
expected = ((arm_val - sq_val) / sq_val) * 100
293+
294+
actual = float(
295+
card.df[card.df["arm_name"] == treatment_arm.name]["branin"]
296+
.values[0]
297+
.rstrip("%")
298+
)
299+
self.assertAlmostEqual(actual, expected, places=1)
300+
301+
def test_mapdata_not_relativized(self) -> None:
302+
"""Test that Summary does not attempt relativization when data is MapData,
303+
even when status quo is present."""
304+
# Create an experiment with MapData
305+
experiment = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
306+
307+
# Replace the experiment's data with MapData
308+
map_data = MapData(
309+
df=experiment.lookup_data().df.assign(step=1.0) # Add step column
310+
)
311+
# Store the MapData in the experiment
312+
for trial in experiment.trials.values():
313+
trial_data = map_data.filter(trial_indices=[trial.index])
314+
experiment.attach_data(trial_data, combine_with_last_data=False)
315+
316+
# Compute the summary
317+
analysis = Summary()
318+
card = analysis.compute(experiment=experiment)
319+
320+
with self.subTest("subtitle_does_not_indicate_relativization"):
321+
self.assertNotIn("relativized", card.subtitle.lower())
322+
323+
with self.subTest("metric_values_not_formatted_as_percentages"):
324+
metric_values = card.df["branin"].dropna()
325+
self.assertGreater(len(metric_values), 0)
326+
# Values should be raw floats, not percentage strings
327+
for val in metric_values:
328+
self.assertIsInstance(val, (float, np.floating))

ax/core/data.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,21 @@ def relativize_dataframe(
410410
for grp in grouped_df.groups.keys():
411411
subgroup_df = grouped_df.get_group(grp)
412412
is_sq = subgroup_df["arm_name"] == status_quo_name
413-
sq_mean, sq_sem = (
414-
subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values.flatten()
415-
)
413+
414+
# Check if status quo exists in this subgroup (trial)
415+
sq_data = subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values
416+
if len(sq_data) == 0:
417+
# No status quo in this trial - skip relativization and include raw data
418+
logger.debug(
419+
"Status quo '%s' not found in trial group %s - "
420+
"skipping relativization for this group",
421+
status_quo_name,
422+
grp,
423+
)
424+
dfs.append(subgroup_df)
425+
continue
426+
427+
sq_mean, sq_sem = sq_data.flatten()
416428

417429
# rm status quo from final df to relativize
418430
if not include_sq:

ax/core/experiment.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2079,6 +2079,7 @@ def to_df(
20792079
trial_indices: Iterable[int] | None = None,
20802080
trial_statuses: Sequence[TrialStatus] | None = None,
20812081
omit_empty_columns: bool = True,
2082+
relativize: bool = False,
20822083
) -> pd.DataFrame:
20832084
"""
20842085
High-level summary of the Experiment with one row per arm. Any values missing at
@@ -2100,10 +2101,32 @@ def to_df(
21002101
trial_indices: If specified, only include these trial indices.
21012102
omit_empty_columns: If True, omit columns where every value is None.
21022103
trial_status: If specified, only include trials with this status.
2104+
relativize: If True and:
2105+
* experiment has a status quo on all of its ``BatchTrial``-s
2106+
* OR a status quo trial among its ``Trial``-s,
2107+
, relativize metrics against the status quo.
21032108
"""
21042109

21052110
records = []
2106-
data_df = self.lookup_data(trial_indices=trial_indices).df
2111+
data = self.lookup_data(trial_indices=trial_indices)
2112+
2113+
# Relativize metrics if requested
2114+
if relativize:
2115+
if self.status_quo is None:
2116+
raise UserInputError(
2117+
"Attempting to relativize the experiment data, however, "
2118+
"the experiment status quo is None. Please set the experiment "
2119+
"status quo, or set `relativize` = False"
2120+
)
2121+
2122+
data_df = data.relativize(
2123+
status_quo_name=self.status_quo.name,
2124+
as_percent=True,
2125+
include_sq=True,
2126+
).df
2127+
else:
2128+
data_df = data.df
2129+
21072130
trials = (
21082131
self.get_trials_by_indices(trial_indices=trial_indices)
21092132
if trial_indices
@@ -2165,6 +2188,20 @@ def to_df(
21652188
df = pd.DataFrame(records)
21662189
if omit_empty_columns:
21672190
df = df.loc[:, df.notnull().any()]
2191+
2192+
# Format metric columns as percentages with 4 significant figures when
2193+
# relativized
2194+
if relativize:
2195+
for metric_name in self.metrics.keys():
2196+
if metric_name in df.columns:
2197+
df[metric_name] = df[metric_name].apply(
2198+
lambda x: (
2199+
f"{x:.4g}%"
2200+
if pd.notna(x) and x != 0.0
2201+
else ("0%" if pd.notna(x) else None)
2202+
)
2203+
)
2204+
21682205
return df
21692206

21702207
def add_auxiliary_experiment(

ax/core/tests/test_experiment.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,72 @@ def test_to_df(self) -> None:
16791679
)
16801680
self.assertTrue(df_completed.equals(expected_completed_df))
16811681

1682+
def test_to_df_with_relativize(self) -> None:
1683+
"""Test the relativize flag in to_df method with status quo."""
1684+
# Create an experiment with status quo and completed trials
1685+
experiment = get_branin_experiment(
1686+
with_status_quo=True, with_completed_batch=True
1687+
)
1688+
1689+
with self.subTest("without relativization"):
1690+
df_no_rel = experiment.to_df(relativize=False)
1691+
1692+
# Verify dataframe has expected structure
1693+
self.assertGreater(len(df_no_rel), 0)
1694+
self.assertIn("trial_index", df_no_rel.columns)
1695+
self.assertIn("arm_name", df_no_rel.columns)
1696+
1697+
# Branin experiment has a single metric named "branin"
1698+
metric_name = "branin"
1699+
self.assertIn(metric_name, df_no_rel.columns)
1700+
1701+
# Verify metric values are numeric, not percentage strings
1702+
values = df_no_rel[metric_name]
1703+
for val in values:
1704+
self.assertIsInstance(
1705+
val, float, "Non-relativized values should be floats"
1706+
)
1707+
1708+
with self.subTest("with relativization"):
1709+
df_with_rel = experiment.to_df(relativize=True)
1710+
df_no_rel = experiment.to_df(relativize=False)
1711+
1712+
# Verify structure is preserved
1713+
self.assertEqual(len(df_with_rel), len(df_no_rel))
1714+
self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns))
1715+
1716+
# Branin experiment has a single metric named "branin"
1717+
metric_name = "branin"
1718+
1719+
# Verify relativization for the metric
1720+
self.assertIsNotNone(experiment.status_quo)
1721+
status_quo_name = experiment.status_quo.name
1722+
1723+
# Status quo should be 0% after relativization (using .4g format)
1724+
sq_rel_values = df_with_rel[df_with_rel["arm_name"] == status_quo_name][
1725+
metric_name
1726+
]
1727+
for val in sq_rel_values:
1728+
self.assertEqual(val, "0%", "Status quo should be relativized to 0%")
1729+
1730+
# Non-status-quo arms should have percentage strings
1731+
non_sq_rel_values = df_with_rel[df_with_rel["arm_name"] != status_quo_name][
1732+
metric_name
1733+
]
1734+
for val in non_sq_rel_values:
1735+
self.assertIsInstance(val, str, "Relativized values should be strings")
1736+
self.assertTrue(
1737+
val.endswith("%"), "Relativized values should end with %"
1738+
)
1739+
1740+
# Verify at least one non-status-quo value is non-zero
1741+
has_nonzero = any(float(v.rstrip("%")) != 0.0 for v in non_sq_rel_values)
1742+
self.assertTrue(
1743+
has_nonzero,
1744+
"At least one non-status-quo arm should have non-zero "
1745+
"relativized value",
1746+
)
1747+
16821748

16831749
class ExperimentWithMapDataTest(TestCase):
16841750
def setUp(self) -> None:

0 commit comments

Comments
 (0)