Skip to content

Commit 76fb428

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Improve Summary Analysis by Relativize the metric results if there is a status quo to relativize against (#4342)
Summary: Pull Request resolved: #4342 Reviewed By: mpolson64 Differential Revision: D82658357
1 parent 9c34ec6 commit 76fb428

File tree

5 files changed

+227
-42
lines changed

5 files changed

+227
-42
lines changed

ax/analysis/summary.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,36 @@ def compute(
7878
if experiment is None:
7979
raise UserInputError("`Summary` analysis requires an `Experiment` input")
8080

81+
# Determine if we should relativize based on:
82+
# (1) experiment has metrics and (2) experiment has status quo
83+
# (3) data type supports relativization (checked via class attribute)
84+
should_relativize = (
85+
len(experiment.metrics) > 0 and experiment.status_quo is not None
86+
)
87+
88+
# Check if the data type supports relativization
89+
if should_relativize:
90+
data = experiment.lookup_data(trial_indices=self.trial_indices)
91+
if not data.supports_relativization:
92+
should_relativize = False
93+
8194
return self._create_analysis_card(
8295
title=(
8396
"Summary for "
8497
f"{experiment.name if experiment.has_name else 'Experiment'}"
8598
),
86-
subtitle="High-level summary of the `Trial`-s in this `Experiment`",
99+
subtitle=(
100+
"High-level summary of the `Trial`-s in this `Experiment`"
101+
if not should_relativize
102+
else (
103+
"High-level summary of the `Trial`-s in this `Experiment` "
104+
"Metric results are relativized against status quo."
105+
)
106+
),
87107
df=experiment.to_df(
88108
trial_indices=self.trial_indices,
89109
omit_empty_columns=self.omit_empty_columns,
90110
trial_statuses=self.trial_statuses,
111+
relativize=should_relativize,
91112
),
92113
)

ax/analysis/tests/test_summary.py

Lines changed: 86 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,30 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-strict
7-
86
import numpy as np
97
import pandas as pd
8+
109
from ax.analysis.summary import Summary
1110
from ax.api.client import Client
1211
from ax.api.configs import RangeParameterConfig
1312
from ax.core.base_trial import TrialStatus
13+
from ax.core.map_data import MapData
1414
from ax.core.trial import Trial
1515
from ax.exceptions.core import UserInputError
1616
from ax.utils.common.testutils import TestCase
17-
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
17+
from ax.utils.testing.core_stubs import (
18+
get_branin_experiment_with_status_quo_trials,
19+
get_offline_experiments,
20+
get_online_experiments,
21+
)
1822
from pyre_extensions import assert_is_instance, none_throws
1923

2024

2125
class TestSummary(TestCase):
22-
def test_compute(self) -> None:
23-
client = Client()
24-
client.configure_experiment(
26+
def setUp(self) -> None:
27+
super().setUp()
28+
self.client = Client()
29+
self.client.configure_experiment(
2530
name="test_experiment",
2631
parameters=[
2732
RangeParameterConfig(
@@ -36,7 +41,10 @@ def test_compute(self) -> None:
3641
),
3742
],
3843
)
39-
client.configure_optimization(objective="foo, bar")
44+
self.client.configure_optimization(objective="foo, bar")
45+
46+
def test_compute(self) -> None:
47+
client = self.client
4048

4149
# Get two trials and fail one, giving us a ragged structure
4250
client.get_next_trials(max_trials=2)
@@ -142,23 +150,7 @@ def test_offline(self) -> None:
142150

143151
def test_trial_indices_filter(self) -> None:
144152
"""Test that Client.summarize correctly uses Summary."""
145-
client = Client()
146-
client.configure_experiment(
147-
name="test_experiment",
148-
parameters=[
149-
RangeParameterConfig(
150-
name="x1",
151-
parameter_type="float",
152-
bounds=(0, 1),
153-
),
154-
RangeParameterConfig(
155-
name="x2",
156-
parameter_type="float",
157-
bounds=(0, 1),
158-
),
159-
],
160-
)
161-
client.configure_optimization(objective="foo")
153+
client = self.client
162154

163155
# Get a trial
164156
client.get_next_trials(max_trials=1)
@@ -228,19 +220,7 @@ def test_trial_status_filter(self) -> None:
228220

229221
def test_default_excludes_stale_trials(self) -> None:
230222
"""Test that Summary defaults to excluding STALE trials."""
231-
# Set up experiment with basic configuration
232-
client = Client()
233-
client.configure_experiment(
234-
name="test_experiment",
235-
parameters=[
236-
RangeParameterConfig(
237-
name="x1",
238-
parameter_type="float",
239-
bounds=(0, 1),
240-
),
241-
],
242-
)
243-
client.configure_optimization(objective="foo")
223+
client = self.client
244224

245225
# Create 3 trials with different statuses to test default filtering behavior
246226
client.get_next_trials(max_trials=3)
@@ -275,3 +255,72 @@ def test_default_excludes_stale_trials(self) -> None:
275255
# Verify that no trials in the output have STALE status
276256
stale_statuses = card.df[card.df["trial_status"] == "STALE"]
277257
self.assertEqual(len(stale_statuses), 0)
258+
259+
def test_metrics_relativized_with_status_quo(self) -> None:
260+
"""Test that Summary relativizes metrics by default when status
261+
quos are present."""
262+
# Use helper function that creates batch trials with status quo
263+
experiment = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
264+
265+
analysis = Summary()
266+
card = analysis.compute(experiment=experiment)
267+
268+
with self.subTest("subtitle_indicates_relativization"):
269+
self.assertIn("relativized", card.subtitle.lower())
270+
271+
with self.subTest("metric_values_formatted_as_percentages"):
272+
metric_values = card.df["branin"].dropna()
273+
self.assertGreater(len(metric_values), 0)
274+
for val in metric_values:
275+
self.assertIsInstance(val, str)
276+
self.assertTrue(val.endswith("%"))
277+
278+
with self.subTest("relativization_calculation_correct"):
279+
raw_data = experiment.lookup_data().df
280+
sq_name = none_throws(experiment.status_quo).name
281+
trial_0_data = raw_data[raw_data["trial_index"] == 0]
282+
treatment_arm = [a for a in experiment.trials[0].arms if a.name != sq_name][
283+
0
284+
]
285+
286+
sq_val = trial_0_data[trial_0_data["arm_name"] == sq_name]["mean"].values[0]
287+
arm_val = trial_0_data[trial_0_data["arm_name"] == treatment_arm.name][
288+
"mean"
289+
].values[0]
290+
expected = ((arm_val - sq_val) / sq_val) * 100
291+
292+
actual = float(
293+
card.df[card.df["arm_name"] == treatment_arm.name]["branin"]
294+
.values[0]
295+
.rstrip("%")
296+
)
297+
self.assertAlmostEqual(actual, expected, places=1)
298+
299+
def test_mapdata_not_relativized(self) -> None:
300+
"""Test that Summary does not attempt relativization when data is MapData,
301+
even when status quo is present."""
302+
# Create an experiment with MapData
303+
experiment = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
304+
305+
# Replace the experiment's data with MapData
306+
map_data = MapData(
307+
df=experiment.lookup_data().df.assign(step=1.0) # Add step column
308+
)
309+
# Store the MapData in the experiment
310+
for trial in experiment.trials.values():
311+
trial_data = map_data.filter(trial_indices=[trial.index])
312+
experiment.attach_data(trial_data, combine_with_last_data=False)
313+
314+
# Compute the summary
315+
analysis = Summary()
316+
card = analysis.compute(experiment=experiment)
317+
318+
with self.subTest("subtitle_does_not_indicate_relativization"):
319+
self.assertNotIn("relativized", card.subtitle.lower())
320+
321+
with self.subTest("metric_values_not_formatted_as_percentages"):
322+
metric_values = card.df["branin"].dropna()
323+
self.assertGreater(len(metric_values), 0)
324+
# Values should be raw floats, not percentage strings
325+
for val in metric_values:
326+
self.assertIsInstance(val, (float, np.floating))

ax/core/data.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,21 @@ def relativize_dataframe(
415415
for grp in grouped_df.groups.keys():
416416
subgroup_df = grouped_df.get_group(grp)
417417
is_sq = subgroup_df["arm_name"] == status_quo_name
418-
sq_mean, sq_sem = (
419-
subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values.flatten()
420-
)
418+
419+
# Check if status quo exists in this subgroup (trial)
420+
sq_data = subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values
421+
if len(sq_data) == 0:
422+
# No status quo in this trial - skip relativization and include raw data
423+
logger.debug(
424+
"Status quo '%s' not found in trial group %s - "
425+
"skipping relativization for this group",
426+
status_quo_name,
427+
grp,
428+
)
429+
dfs.append(subgroup_df)
430+
continue
431+
432+
sq_mean, sq_sem = sq_data.flatten()
421433

422434
# rm status quo from final df to relativize
423435
if not include_sq:

ax/core/experiment.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2089,6 +2089,7 @@ def to_df(
20892089
trial_indices: Iterable[int] | None = None,
20902090
trial_statuses: Sequence[TrialStatus] | None = None,
20912091
omit_empty_columns: bool = True,
2092+
relativize: bool = False,
20922093
) -> pd.DataFrame:
20932094
"""
20942095
High-level summary of the Experiment with one row per arm. Any values missing at
@@ -2110,10 +2111,32 @@ def to_df(
21102111
trial_indices: If specified, only include these trial indices.
21112112
omit_empty_columns: If True, omit columns where every value is None.
21122113
trial_status: If specified, only include trials with this status.
2114+
relativize: If True and:
2115+
* experiment has a status quo on all of its ``BatchTrial``-s
2116+
* OR a status quo trial among its ``Trial``-s,
2117+
, relativize metrics against the status quo.
21132118
"""
21142119

21152120
records = []
2116-
data_df = self.lookup_data(trial_indices=trial_indices).df
2121+
data = self.lookup_data(trial_indices=trial_indices)
2122+
2123+
# Relativize metrics if requested
2124+
if relativize:
2125+
if self.status_quo is None:
2126+
raise UserInputError(
2127+
"Attempting to relativize the experiment data, however, "
2128+
"the experiment status quo is None. Please set the experiment "
2129+
"status quo, or set `relativize` = False"
2130+
)
2131+
2132+
data_df = data.relativize(
2133+
status_quo_name=self.status_quo.name,
2134+
as_percent=True,
2135+
include_sq=True,
2136+
).df
2137+
else:
2138+
data_df = data.df
2139+
21172140
trials = (
21182141
self.get_trials_by_indices(trial_indices=trial_indices)
21192142
if trial_indices
@@ -2175,6 +2198,20 @@ def to_df(
21752198
df = pd.DataFrame(records)
21762199
if omit_empty_columns:
21772200
df = df.loc[:, df.notnull().any()]
2201+
2202+
# Format metric columns as percentages with 4 significant figures when
2203+
# relativized
2204+
if relativize:
2205+
for metric_name in self.metrics.keys():
2206+
if metric_name in df.columns:
2207+
df[metric_name] = df[metric_name].apply(
2208+
lambda x: (
2209+
f"{x:.4g}%"
2210+
if pd.notna(x) and x != 0.0
2211+
else ("0%" if pd.notna(x) else None)
2212+
)
2213+
)
2214+
21782215
return df
21792216

21802217
def add_auxiliary_experiment(

ax/core/tests/test_experiment.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,72 @@ def test_to_df(self) -> None:
16791679
)
16801680
self.assertTrue(df_completed.equals(expected_completed_df))
16811681

1682+
def test_to_df_with_relativize(self) -> None:
1683+
"""Test the relativize flag in to_df method with status quo."""
1684+
# Create an experiment with status quo and completed trials
1685+
experiment = get_branin_experiment(
1686+
with_status_quo=True, with_completed_batch=True
1687+
)
1688+
1689+
with self.subTest("without relativization"):
1690+
df_no_rel = experiment.to_df(relativize=False)
1691+
1692+
# Verify dataframe has expected structure
1693+
self.assertGreater(len(df_no_rel), 0)
1694+
self.assertIn("trial_index", df_no_rel.columns)
1695+
self.assertIn("arm_name", df_no_rel.columns)
1696+
1697+
# Branin experiment has a single metric named "branin"
1698+
metric_name = "branin"
1699+
self.assertIn(metric_name, df_no_rel.columns)
1700+
1701+
# Verify metric values are numeric, not percentage strings
1702+
values = df_no_rel[metric_name]
1703+
for val in values:
1704+
self.assertIsInstance(
1705+
val, float, "Non-relativized values should be floats"
1706+
)
1707+
1708+
with self.subTest("with relativization"):
1709+
df_with_rel = experiment.to_df(relativize=True)
1710+
df_no_rel = experiment.to_df(relativize=False)
1711+
1712+
# Verify structure is preserved
1713+
self.assertEqual(len(df_with_rel), len(df_no_rel))
1714+
self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns))
1715+
1716+
# Branin experiment has a single metric named "branin"
1717+
metric_name = "branin"
1718+
1719+
# Verify relativization for the metric
1720+
self.assertIsNotNone(experiment.status_quo)
1721+
status_quo_name = experiment.status_quo.name
1722+
1723+
# Status quo should be 0% after relativization (using .4g format)
1724+
sq_rel_values = df_with_rel[df_with_rel["arm_name"] == status_quo_name][
1725+
metric_name
1726+
]
1727+
for val in sq_rel_values:
1728+
self.assertEqual(val, "0%", "Status quo should be relativized to 0%")
1729+
1730+
# Non-status-quo arms should have percentage strings
1731+
non_sq_rel_values = df_with_rel[df_with_rel["arm_name"] != status_quo_name][
1732+
metric_name
1733+
]
1734+
for val in non_sq_rel_values:
1735+
self.assertIsInstance(val, str, "Relativized values should be strings")
1736+
self.assertTrue(
1737+
val.endswith("%"), "Relativized values should end with %"
1738+
)
1739+
1740+
# Verify at least one non-status-quo value is non-zero
1741+
has_nonzero = any(float(v.rstrip("%")) != 0.0 for v in non_sq_rel_values)
1742+
self.assertTrue(
1743+
has_nonzero,
1744+
"At least one non-status-quo arm should have non-zero "
1745+
"relativized value",
1746+
)
1747+
16821748

16831749
class ExperimentWithMapDataTest(TestCase):
16841750
def setUp(self) -> None:

0 commit comments

Comments
 (0)