Skip to content

Commit af0a072

Browse files
shrutipatel31 authored and facebook-github-bot committed
Improve Summary Analysis by relativizing the metric results if there is a status quo to relativize against (#4342)
Summary: Pull Request resolved: #4342 Differential Revision: D82658357
1 parent 5f06378 commit af0a072

File tree

4 files changed

+223
-37
lines changed

4 files changed

+223
-37
lines changed

ax/analysis/summary.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ax.analysis.analysis import Analysis
1414
from ax.analysis.analysis_card import AnalysisCard
15+
from ax.analysis.utils import filter_trials_by_indices_and_statuses
1516
from ax.core.experiment import Experiment
1617
from ax.core.trial_status import NON_STALE_STATUSES, TrialStatus
1718
from ax.exceptions.core import UserInputError
@@ -63,15 +64,46 @@ def compute(
6364
if experiment is None:
6465
raise UserInputError("`Summary` analysis requires an `Experiment` input")
6566

67+
# Get the trials that will be included in the summary
68+
trials = filter_trials_by_indices_and_statuses(
69+
experiment=experiment,
70+
trial_indices=self.trial_indices,
71+
trial_statuses=self.trial_statuses,
72+
)
73+
74+
# Check if all trials have status quo in their arms
75+
all_trials_have_status_quo = all(
76+
experiment.status_quo is not None
77+
and experiment.status_quo.name in trial.arms_by_name
78+
for trial in trials
79+
)
80+
81+
# Determine if we should relativize based on:
82+
# (1) experiment has metrics and (2) experiment has status quo
83+
# (3) all trials being used have status quo
84+
should_relativize = (
85+
len(experiment.metrics) > 0
86+
and experiment.status_quo is not None
87+
and all_trials_have_status_quo
88+
)
89+
6690
return self._create_analysis_card(
6791
title=(
6892
"Summary for "
6993
f"{experiment.name if experiment.has_name else 'Experiment'}"
7094
),
71-
subtitle="High-level summary of the `Trial`-s in this `Experiment`",
95+
subtitle=(
96+
"High-level summary of the `Trial`-s in this `Experiment`"
97+
if not should_relativize
98+
else (
99+
"High-level summary of the `Trial`-s in this `Experiment` "
100+
"Metric results are relativized against status quo."
101+
)
102+
),
72103
df=experiment.to_df(
73104
trial_indices=self.trial_indices,
74105
omit_empty_columns=self.omit_empty_columns,
75106
trial_statuses=self.trial_statuses,
107+
relativize=should_relativize,
76108
),
77109
)

ax/analysis/tests/test_summary.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,29 @@
77

88
import numpy as np
99
import pandas as pd
10+
11+
from ax.adapter.factory import get_sobol
1012
from ax.analysis.summary import Summary
1113
from ax.api.client import Client
1214
from ax.api.configs import RangeParameterConfig
1315
from ax.core.base_trial import TrialStatus
1416
from ax.core.trial import Trial
1517
from ax.exceptions.core import UserInputError
1618
from ax.utils.common.testutils import TestCase
17-
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
19+
from ax.utils.testing.core_stubs import (
20+
get_branin_data_batch,
21+
get_branin_experiment,
22+
get_offline_experiments,
23+
get_online_experiments,
24+
)
1825
from pyre_extensions import assert_is_instance, none_throws
1926

2027

2128
class TestSummary(TestCase):
22-
def test_compute(self) -> None:
23-
client = Client()
24-
client.configure_experiment(
29+
def setUp(self) -> None:
30+
super().setUp()
31+
self.client = Client()
32+
self.client.configure_experiment(
2533
name="test_experiment",
2634
parameters=[
2735
RangeParameterConfig(
@@ -36,7 +44,10 @@ def test_compute(self) -> None:
3644
),
3745
],
3846
)
39-
client.configure_optimization(objective="foo, bar")
47+
self.client.configure_optimization(objective="foo, bar")
48+
49+
def test_compute(self) -> None:
50+
client = self.client
4051

4152
# Get two trials and fail one, giving us a ragged structure
4253
client.get_next_trials(max_trials=2)
@@ -142,23 +153,7 @@ def test_offline(self) -> None:
142153

143154
def test_trial_indices_filter(self) -> None:
144155
"""Test that Client.summarize correctly uses Summary."""
145-
client = Client()
146-
client.configure_experiment(
147-
name="test_experiment",
148-
parameters=[
149-
RangeParameterConfig(
150-
name="x1",
151-
parameter_type="float",
152-
bounds=(0, 1),
153-
),
154-
RangeParameterConfig(
155-
name="x2",
156-
parameter_type="float",
157-
bounds=(0, 1),
158-
),
159-
],
160-
)
161-
client.configure_optimization(objective="foo")
156+
client = self.client
162157

163158
# Get a trial
164159
client.get_next_trials(max_trials=1)
@@ -228,19 +223,7 @@ def test_trial_status_filter(self) -> None:
228223

229224
def test_default_excludes_stale_trials(self) -> None:
230225
"""Test that Summary defaults to excluding STALE trials."""
231-
# Set up experiment with basic configuration
232-
client = Client()
233-
client.configure_experiment(
234-
name="test_experiment",
235-
parameters=[
236-
RangeParameterConfig(
237-
name="x1",
238-
parameter_type="float",
239-
bounds=(0, 1),
240-
),
241-
],
242-
)
243-
client.configure_optimization(objective="foo")
226+
client = self.client
244227

245228
# Create 3 trials with different statuses to test default filtering behavior
246229
client.get_next_trials(max_trials=3)
@@ -275,3 +258,54 @@ def test_default_excludes_stale_trials(self) -> None:
275258
# Verify that no trials in the output have STALE status
276259
stale_statuses = card.df[card.df["trial_status"] == "STALE"]
277260
self.assertEqual(len(stale_statuses), 0)
261+
262+
def test_metrics_relativized_with_status_quo(self) -> None:
    """Check that ``Summary`` relativizes metric values when the status quo
    arm is present on the summarized trials."""
    experiment = get_branin_experiment(with_status_quo=True, named=True)
    experiment.name = "test_experiment_relativize"

    # Two completed batch trials, each holding the status quo arm plus one
    # Sobol-generated arm, with data attached for every experiment metric.
    for _ in range(2):
        generator = get_sobol(search_space=experiment.search_space)
        batch = experiment.new_batch_trial(should_add_status_quo_arm=True)
        batch.add_generator_run(generator.gen(n=1))
        batch.mark_running(no_runner_required=True)
        batch_data = get_branin_data_batch(
            batch=batch, metrics=list(experiment.metrics.keys())
        )
        experiment.attach_data(batch_data)
        batch.mark_completed()

    card = Summary().compute(experiment=experiment)

    with self.subTest("subtitle_indicates_relativization"):
        self.assertIn("relativized", card.subtitle.lower())

    with self.subTest("metric_values_formatted_as_percentages"):
        branin_cells = card.df["branin"].dropna()
        self.assertGreater(len(branin_cells), 0)
        for cell in branin_cells:
            self.assertIsInstance(cell, str)
            self.assertTrue(cell.endswith("%"))

    with self.subTest("relativization_calculation_correct"):
        sq_name = none_throws(experiment.status_quo).name
        observations = experiment.lookup_data().df
        first_trial_obs = observations[observations["trial_index"] == 0]

        # First arm of trial 0 that is not the status quo.
        non_sq_arms = [
            arm for arm in experiment.trials[0].arms if arm.name != sq_name
        ]
        treatment_name = non_sq_arms[0].name

        sq_rows = first_trial_obs[first_trial_obs["arm_name"] == sq_name]
        sq_mean = sq_rows["mean"].values[0]
        arm_rows = first_trial_obs[first_trial_obs["arm_name"] == treatment_name]
        arm_mean = arm_rows["mean"].values[0]

        # Percent change of the treatment arm relative to the status quo.
        expected = ((arm_mean - sq_mean) / sq_mean) * 100

        # The summary reports the value as a string such as "12.34%".
        summary_rows = card.df[card.df["arm_name"] == treatment_name]
        reported = float(summary_rows["branin"].values[0].rstrip("%"))
        self.assertAlmostEqual(reported, expected, places=1)

ax/core/experiment.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,7 @@ def to_df(
20392039
trial_indices: Iterable[int] | None = None,
20402040
trial_statuses: Sequence[TrialStatus] | None = None,
20412041
omit_empty_columns: bool = True,
2042+
relativize: bool = False,
20422043
) -> pd.DataFrame:
20432044
"""
20442045
High-level summary of the Experiment with one row per arm. Any values missing at
@@ -2060,10 +2061,32 @@ def to_df(
20602061
trial_indices: If specified, only include these trial indices.
20612062
omit_empty_columns: If True, omit columns where every value is None.
20622063
trial_statuses: If specified, only include trials with these statuses.
2064+
relativize: If True and:
2065+
* experiment has a status quo on all of its ``BatchTrial``-s
2066+
* OR a status quo trial among its ``Trial``-s,
2067+
relativize metrics against the status quo.
20632068
"""
20642069

20652070
records = []
2066-
data_df = self.lookup_data(trial_indices=trial_indices).df
2071+
data = self.lookup_data(trial_indices=trial_indices)
2072+
2073+
# Relativize metrics if requested
2074+
if relativize:
2075+
if self.status_quo is None:
2076+
raise UserInputError(
2077+
"Attempting to relativize the experiment data, however, "
2078+
"the experiment status quo is None. Please set the experiment "
2079+
"status quo, or set `relativize` = False"
2080+
)
2081+
2082+
data_df = data.relativize(
2083+
status_quo_name=self.status_quo.name,
2084+
as_percent=True,
2085+
include_sq=True,
2086+
).df
2087+
else:
2088+
data_df = data.df
2089+
20672090
trials = (
20682091
self.get_trials_by_indices(trial_indices=trial_indices)
20692092
if trial_indices
@@ -2125,6 +2148,20 @@ def to_df(
21252148
df = pd.DataFrame(records)
21262149
if omit_empty_columns:
21272150
df = df.loc[:, df.notnull().any()]
2151+
2152+
# Format metric columns as percentages with 4 significant figures when
2153+
# relativized
2154+
if relativize:
2155+
for metric_name in self.metrics.keys():
2156+
if metric_name in df.columns:
2157+
df[metric_name] = df[metric_name].apply(
2158+
lambda x: (
2159+
f"{x:.4g}%"
2160+
if pd.notna(x) and x != 0.0
2161+
else ("0%" if pd.notna(x) else None)
2162+
)
2163+
)
2164+
21282165
return df
21292166

21302167
def add_auxiliary_experiment(

ax/core/tests/test_experiment.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,89 @@ def test_to_df(self) -> None:
16411641
)
16421642
self.assertTrue(df_completed.equals(expected_completed_df))
16431643

1644+
def test_to_df_with_relativize(self) -> None:
    """Test the ``relativize`` flag of ``to_df`` on an experiment with a
    status quo.

    With ``relativize=False`` metric values must not be strings. With
    ``relativize=True`` the frame keeps its shape, status quo rows read
    ``"0%"``, and other arms hold percentage strings of which at least one
    is non-zero.
    """
    # Create an experiment with status quo and completed trials
    experiment = get_branin_experiment(
        with_status_quo=True, with_completed_batch=True
    )

    # Take metric names from the experiment itself. Guessing them by
    # excluding a fixed list of column names silently treats every unlisted
    # non-metric column as a metric.
    metric_names = list(experiment.metrics)

    with self.subTest("without relativization"):
        df_no_rel = experiment.to_df(relativize=False)

        # Verify dataframe has expected structure
        self.assertGreater(len(df_no_rel), 0)
        self.assertIn("trial_index", df_no_rel.columns)
        self.assertIn("arm_name", df_no_rel.columns)

        metric_cols = [m for m in metric_names if m in df_no_rel.columns]
        self.assertGreater(len(metric_cols), 0, "Should have at least one metric")

        # Verify metric values are numeric, not percentage strings
        for metric_name in metric_cols:
            for val in df_no_rel[metric_name].dropna():
                self.assertNotIsInstance(
                    val, str, "Non-relativized values should not be strings"
                )

    with self.subTest("with relativization"):
        df_with_rel = experiment.to_df(relativize=True)
        df_no_rel = experiment.to_df(relativize=False)

        # Verify structure is preserved
        self.assertEqual(len(df_with_rel), len(df_no_rel))
        self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns))

        metric_cols = [m for m in metric_names if m in df_with_rel.columns]
        # Without this guard the per-metric loop below could pass vacuously.
        self.assertGreater(len(metric_cols), 0, "Should have at least one metric")

        # Verify relativization for each metric
        self.assertIsNotNone(experiment.status_quo)
        status_quo_name = experiment.status_quo.name
        is_status_quo = df_with_rel["arm_name"] == status_quo_name
        for metric_name in metric_cols:
            # Status quo should be 0% after relativization (using .4g format)
            sq_rel_values = df_with_rel[is_status_quo][metric_name].dropna()
            # Require at least one status quo value so the equality check
            # below actually runs.
            self.assertGreater(
                len(sq_rel_values),
                0,
                "Status quo should have a relativized value",
            )
            for val in sq_rel_values:
                self.assertEqual(
                    val, "0%", "Status quo should be relativized to 0%"
                )

            # Non-status-quo arms should have percentage strings
            non_sq_rel_values = df_with_rel[~is_status_quo][metric_name].dropna()
            for val in non_sq_rel_values:
                self.assertIsInstance(
                    val, str, "Relativized values should be strings"
                )
                self.assertTrue(
                    val.endswith("%"), "Relativized values should end with %"
                )

            # Verify at least one non-status-quo value is non-zero
            has_nonzero = any(
                float(v.rstrip("%")) != 0.0 for v in non_sq_rel_values
            )
            self.assertTrue(
                has_nonzero,
                "At least one non-status-quo arm should have non-zero "
                "relativized value",
            )
1726+
16441727

16451728
class ExperimentWithMapDataTest(TestCase):
16461729
def setUp(self) -> None:

0 commit comments

Comments
 (0)