Skip to content

Commit 291a6ab

Browse files
shrutipatel31 authored and
facebook-github-bot committed
Improve Summary Analysis by relativizing the metric results if there is a status quo to relativize against (#4342)
Summary: Pull Request resolved: #4342 Differential Revision: D82658357
1 parent 965da5a commit 291a6ab

File tree

5 files changed

+190
-42
lines changed

5 files changed

+190
-42
lines changed

ax/analysis/summary.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,15 +78,29 @@ def compute(
7878
if experiment is None:
7979
raise UserInputError("`Summary` analysis requires an `Experiment` input")
8080

81+
# Determine if we should relativize based on:
82+
# (1) experiment has metrics and (2) experiment has status quo
83+
should_relativize = (
84+
len(experiment.metrics) > 0 and experiment.status_quo is not None
85+
)
86+
8187
return self._create_analysis_card(
8288
title=(
8389
"Summary for "
8490
f"{experiment.name if experiment.has_name else 'Experiment'}"
8591
),
86-
subtitle="High-level summary of the `Trial`-s in this `Experiment`",
92+
subtitle=(
93+
"High-level summary of the `Trial`-s in this `Experiment`"
94+
if not should_relativize
95+
else (
96+
"High-level summary of the `Trial`-s in this `Experiment` "
97+
"Metric results are relativized against status quo."
98+
)
99+
),
87100
df=experiment.to_df(
88101
trial_indices=self.trial_indices,
89102
omit_empty_columns=self.omit_empty_columns,
90103
trial_statuses=self.trial_statuses,
104+
relativize=should_relativize,
91105
),
92106
)

ax/analysis/tests/test_summary.py

Lines changed: 56 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,25 +3,29 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-strict
7-
86
import numpy as np
97
import pandas as pd
8+
109
from ax.analysis.summary import Summary
1110
from ax.api.client import Client
1211
from ax.api.configs import RangeParameterConfig
1312
from ax.core.base_trial import TrialStatus
1413
from ax.core.trial import Trial
1514
from ax.exceptions.core import UserInputError
1615
from ax.utils.common.testutils import TestCase
17-
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
16+
from ax.utils.testing.core_stubs import (
17+
get_branin_experiment_with_status_quo_trials,
18+
get_offline_experiments,
19+
get_online_experiments,
20+
)
1821
from pyre_extensions import assert_is_instance, none_throws
1922

2023

2124
class TestSummary(TestCase):
22-
def test_compute(self) -> None:
23-
client = Client()
24-
client.configure_experiment(
25+
def setUp(self) -> None:
26+
super().setUp()
27+
self.client = Client()
28+
self.client.configure_experiment(
2529
name="test_experiment",
2630
parameters=[
2731
RangeParameterConfig(
@@ -36,7 +40,10 @@ def test_compute(self) -> None:
3640
),
3741
],
3842
)
39-
client.configure_optimization(objective="foo, bar")
43+
self.client.configure_optimization(objective="foo, bar")
44+
45+
def test_compute(self) -> None:
46+
client = self.client
4047

4148
# Get two trials and fail one, giving us a ragged structure
4249
client.get_next_trials(max_trials=2)
@@ -142,23 +149,7 @@ def test_offline(self) -> None:
142149

143150
def test_trial_indices_filter(self) -> None:
144151
"""Test that Client.summarize correctly uses Summary."""
145-
client = Client()
146-
client.configure_experiment(
147-
name="test_experiment",
148-
parameters=[
149-
RangeParameterConfig(
150-
name="x1",
151-
parameter_type="float",
152-
bounds=(0, 1),
153-
),
154-
RangeParameterConfig(
155-
name="x2",
156-
parameter_type="float",
157-
bounds=(0, 1),
158-
),
159-
],
160-
)
161-
client.configure_optimization(objective="foo")
152+
client = self.client
162153

163154
# Get a trial
164155
client.get_next_trials(max_trials=1)
@@ -228,19 +219,7 @@ def test_trial_status_filter(self) -> None:
228219

229220
def test_default_excludes_stale_trials(self) -> None:
230221
"""Test that Summary defaults to excluding STALE trials."""
231-
# Set up experiment with basic configuration
232-
client = Client()
233-
client.configure_experiment(
234-
name="test_experiment",
235-
parameters=[
236-
RangeParameterConfig(
237-
name="x1",
238-
parameter_type="float",
239-
bounds=(0, 1),
240-
),
241-
],
242-
)
243-
client.configure_optimization(objective="foo")
222+
client = self.client
244223

245224
# Create 3 trials with different statuses to test default filtering behavior
246225
client.get_next_trials(max_trials=3)
@@ -275,3 +254,43 @@ def test_default_excludes_stale_trials(self) -> None:
275254
# Verify that no trials in the output have STALE status
276255
stale_statuses = card.df[card.df["trial_status"] == "STALE"]
277256
self.assertEqual(len(stale_statuses), 0)
257+
258+
def test_metrics_relativized_with_status_quo(self) -> None:
259+
"""Test that Summary relativizes metrics by default when status
260+
quos are present."""
261+
# Use helper function that creates batch trials with status quo
262+
experiment = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
263+
264+
analysis = Summary()
265+
card = analysis.compute(experiment=experiment)
266+
267+
with self.subTest("subtitle_indicates_relativization"):
268+
self.assertIn("relativized", card.subtitle.lower())
269+
270+
with self.subTest("metric_values_formatted_as_percentages"):
271+
metric_values = card.df["branin"].dropna()
272+
self.assertGreater(len(metric_values), 0)
273+
for val in metric_values:
274+
self.assertIsInstance(val, str)
275+
self.assertTrue(val.endswith("%"))
276+
277+
with self.subTest("relativization_calculation_correct"):
278+
raw_data = experiment.lookup_data().df
279+
sq_name = none_throws(experiment.status_quo).name
280+
trial_0_data = raw_data[raw_data["trial_index"] == 0]
281+
treatment_arm = [a for a in experiment.trials[0].arms if a.name != sq_name][
282+
0
283+
]
284+
285+
sq_val = trial_0_data[trial_0_data["arm_name"] == sq_name]["mean"].values[0]
286+
arm_val = trial_0_data[trial_0_data["arm_name"] == treatment_arm.name][
287+
"mean"
288+
].values[0]
289+
expected = ((arm_val - sq_val) / sq_val) * 100
290+
291+
actual = float(
292+
card.df[card.df["arm_name"] == treatment_arm.name]["branin"]
293+
.values[0]
294+
.rstrip("%")
295+
)
296+
self.assertAlmostEqual(actual, expected, places=1)

ax/core/data.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -410,9 +410,21 @@ def relativize_dataframe(
410410
for grp in grouped_df.groups.keys():
411411
subgroup_df = grouped_df.get_group(grp)
412412
is_sq = subgroup_df["arm_name"] == status_quo_name
413-
sq_mean, sq_sem = (
414-
subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values.flatten()
415-
)
413+
414+
# Check if status quo exists in this subgroup (trial)
415+
sq_data = subgroup_df[is_sq][["mean", "sem"]].drop_duplicates().values
416+
if len(sq_data) == 0:
417+
# No status quo in this trial - skip relativization and include raw data
418+
logger.debug(
419+
"Status quo '%s' not found in trial group %s - "
420+
"skipping relativization for this group",
421+
status_quo_name,
422+
grp,
423+
)
424+
dfs.append(subgroup_df)
425+
continue
426+
427+
sq_mean, sq_sem = sq_data.flatten()
416428

417429
# rm status quo from final df to relativize
418430
if not include_sq:

ax/core/experiment.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2130,6 +2130,7 @@ def to_df(
21302130
trial_indices: Iterable[int] | None = None,
21312131
trial_statuses: Sequence[TrialStatus] | None = None,
21322132
omit_empty_columns: bool = True,
2133+
relativize: bool = False,
21332134
) -> pd.DataFrame:
21342135
"""
21352136
High-level summary of the Experiment with one row per arm. Any values missing at
@@ -2151,10 +2152,32 @@ def to_df(
21512152
trial_indices: If specified, only include these trial indices.
21522153
omit_empty_columns: If True, omit columns where every value is None.
21532154
trial_statuses: If specified, only include trials with these statuses.
2155+
relativize: If True and the experiment has either:
2156+
* a status quo on all of its ``BatchTrial``-s,
2157+
* OR a status quo trial among its ``Trial``-s,
2158+
then relativize metrics against the status quo.
21542159
"""
21552160

21562161
records = []
2157-
data_df = self.lookup_data(trial_indices=trial_indices).df
2162+
data = self.lookup_data(trial_indices=trial_indices)
2163+
2164+
# Relativize metrics if requested
2165+
if relativize:
2166+
if self.status_quo is None:
2167+
raise UserInputError(
2168+
"Attempting to relativize the experiment data, however, "
2169+
"the experiment status quo is None. Please set the experiment "
2170+
"status quo, or set `relativize` = False"
2171+
)
2172+
2173+
data_df = data.relativize(
2174+
status_quo_name=self.status_quo.name,
2175+
as_percent=True,
2176+
include_sq=True,
2177+
).df
2178+
else:
2179+
data_df = data.df
2180+
21582181
trials = (
21592182
self.get_trials_by_indices(trial_indices=trial_indices)
21602183
if trial_indices
@@ -2216,6 +2239,20 @@ def to_df(
22162239
df = pd.DataFrame(records)
22172240
if omit_empty_columns:
22182241
df = df.loc[:, df.notnull().any()]
2242+
2243+
# Format metric columns as percentages with 4 significant figures when
2244+
# relativized
2245+
if relativize:
2246+
for metric_name in self.metrics.keys():
2247+
if metric_name in df.columns:
2248+
df[metric_name] = df[metric_name].apply(
2249+
lambda x: (
2250+
f"{x:.4g}%"
2251+
if pd.notna(x) and x != 0.0
2252+
else ("0%" if pd.notna(x) else None)
2253+
)
2254+
)
2255+
22192256
return df
22202257

22212258
def add_auxiliary_experiment(

ax/core/tests/test_experiment.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,72 @@ def test_to_df(self) -> None:
16791679
)
16801680
self.assertTrue(df_completed.equals(expected_completed_df))
16811681

1682+
def test_to_df_with_relativize(self) -> None:
1683+
"""Test the relativize flag in to_df method with status quo."""
1684+
# Create an experiment with status quo and completed trials
1685+
experiment = get_branin_experiment(
1686+
with_status_quo=True, with_completed_batch=True
1687+
)
1688+
1689+
with self.subTest("without relativization"):
1690+
df_no_rel = experiment.to_df(relativize=False)
1691+
1692+
# Verify dataframe has expected structure
1693+
self.assertGreater(len(df_no_rel), 0)
1694+
self.assertIn("trial_index", df_no_rel.columns)
1695+
self.assertIn("arm_name", df_no_rel.columns)
1696+
1697+
# Branin experiment has a single metric named "branin"
1698+
metric_name = "branin"
1699+
self.assertIn(metric_name, df_no_rel.columns)
1700+
1701+
# Verify metric values are numeric, not percentage strings
1702+
values = df_no_rel[metric_name]
1703+
for val in values:
1704+
self.assertIsInstance(
1705+
val, float, "Non-relativized values should be floats"
1706+
)
1707+
1708+
with self.subTest("with relativization"):
1709+
df_with_rel = experiment.to_df(relativize=True)
1710+
df_no_rel = experiment.to_df(relativize=False)
1711+
1712+
# Verify structure is preserved
1713+
self.assertEqual(len(df_with_rel), len(df_no_rel))
1714+
self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns))
1715+
1716+
# Branin experiment has a single metric named "branin"
1717+
metric_name = "branin"
1718+
1719+
# Verify relativization for the metric
1720+
self.assertIsNotNone(experiment.status_quo)
1721+
status_quo_name = experiment.status_quo.name
1722+
1723+
# Status quo should be 0% after relativization (using .4g format)
1724+
sq_rel_values = df_with_rel[df_with_rel["arm_name"] == status_quo_name][
1725+
metric_name
1726+
]
1727+
for val in sq_rel_values:
1728+
self.assertEqual(val, "0%", "Status quo should be relativized to 0%")
1729+
1730+
# Non-status-quo arms should have percentage strings
1731+
non_sq_rel_values = df_with_rel[df_with_rel["arm_name"] != status_quo_name][
1732+
metric_name
1733+
]
1734+
for val in non_sq_rel_values:
1735+
self.assertIsInstance(val, str, "Relativized values should be strings")
1736+
self.assertTrue(
1737+
val.endswith("%"), "Relativized values should end with %"
1738+
)
1739+
1740+
# Verify at least one non-status-quo value is non-zero
1741+
has_nonzero = any(float(v.rstrip("%")) != 0.0 for v in non_sq_rel_values)
1742+
self.assertTrue(
1743+
has_nonzero,
1744+
"At least one non-status-quo arm should have non-zero "
1745+
"relativized value",
1746+
)
1747+
16821748

16831749
class ExperimentWithMapDataTest(TestCase):
16841750
def setUp(self) -> None:

0 commit comments

Comments
 (0)