33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6- # pyre-strict
7-
86import numpy as np
97import pandas as pd
8+
109from ax .analysis .summary import Summary
1110from ax .api .client import Client
1211from ax .api .configs import RangeParameterConfig
1312from ax .core .base_trial import TrialStatus
1413from ax .core .trial import Trial
1514from ax .exceptions .core import UserInputError
1615from ax .utils .common .testutils import TestCase
17- from ax .utils .testing .core_stubs import get_offline_experiments , get_online_experiments
16+ from ax .utils .testing .core_stubs import (
17+ get_branin_experiment_with_status_quo_trials ,
18+ get_offline_experiments ,
19+ get_online_experiments ,
20+ )
1821from pyre_extensions import assert_is_instance , none_throws
1922
2023
2124class TestSummary (TestCase ):
22- def test_compute (self ) -> None :
23- client = Client ()
24- client .configure_experiment (
25+ def setUp (self ) -> None :
26+ super ().setUp ()
27+ self .client = Client ()
28+ self .client .configure_experiment (
2529 name = "test_experiment" ,
2630 parameters = [
2731 RangeParameterConfig (
@@ -36,7 +40,10 @@ def test_compute(self) -> None:
3640 ),
3741 ],
3842 )
39- client .configure_optimization (objective = "foo, bar" )
43+ self .client .configure_optimization (objective = "foo, bar" )
44+
45+ def test_compute (self ) -> None :
46+ client = self .client
4047
4148 # Get two trials and fail one, giving us a ragged structure
4249 client .get_next_trials (max_trials = 2 )
@@ -142,23 +149,7 @@ def test_offline(self) -> None:
142149
143150 def test_trial_indices_filter (self ) -> None :
144151 """Test that Client.summarize correctly uses Summary."""
145- client = Client ()
146- client .configure_experiment (
147- name = "test_experiment" ,
148- parameters = [
149- RangeParameterConfig (
150- name = "x1" ,
151- parameter_type = "float" ,
152- bounds = (0 , 1 ),
153- ),
154- RangeParameterConfig (
155- name = "x2" ,
156- parameter_type = "float" ,
157- bounds = (0 , 1 ),
158- ),
159- ],
160- )
161- client .configure_optimization (objective = "foo" )
152+ client = self .client
162153
163154 # Get a trial
164155 client .get_next_trials (max_trials = 1 )
@@ -228,19 +219,7 @@ def test_trial_status_filter(self) -> None:
228219
229220 def test_default_excludes_stale_trials (self ) -> None :
230221 """Test that Summary defaults to excluding STALE trials."""
231- # Set up experiment with basic configuration
232- client = Client ()
233- client .configure_experiment (
234- name = "test_experiment" ,
235- parameters = [
236- RangeParameterConfig (
237- name = "x1" ,
238- parameter_type = "float" ,
239- bounds = (0 , 1 ),
240- ),
241- ],
242- )
243- client .configure_optimization (objective = "foo" )
222+ client = self .client
244223
245224 # Create 3 trials with different statuses to test default filtering behavior
246225 client .get_next_trials (max_trials = 3 )
@@ -275,3 +254,43 @@ def test_default_excludes_stale_trials(self) -> None:
275254 # Verify that no trials in the output have STALE status
276255 stale_statuses = card .df [card .df ["trial_status" ] == "STALE" ]
277256 self .assertEqual (len (stale_statuses ), 0 )
257+
258+ def test_metrics_relativized_with_status_quo (self ) -> None :
259+ """Test that Summary relativizes metrics by default when status
260+ quos are present."""
261+ # Use helper function that creates batch trials with status quo
262+ experiment = get_branin_experiment_with_status_quo_trials (num_sobol_trials = 2 )
263+
264+ analysis = Summary ()
265+ card = analysis .compute (experiment = experiment )
266+
267+ with self .subTest ("subtitle_indicates_relativization" ):
268+ self .assertIn ("relativized" , card .subtitle .lower ())
269+
270+ with self .subTest ("metric_values_formatted_as_percentages" ):
271+ metric_values = card .df ["branin" ].dropna ()
272+ self .assertGreater (len (metric_values ), 0 )
273+ for val in metric_values :
274+ self .assertIsInstance (val , str )
275+ self .assertTrue (val .endswith ("%" ))
276+
277+ with self .subTest ("relativization_calculation_correct" ):
278+ raw_data = experiment .lookup_data ().df
279+ sq_name = none_throws (experiment .status_quo ).name
280+ trial_0_data = raw_data [raw_data ["trial_index" ] == 0 ]
281+ treatment_arm = [a for a in experiment .trials [0 ].arms if a .name != sq_name ][
282+ 0
283+ ]
284+
285+ sq_val = trial_0_data [trial_0_data ["arm_name" ] == sq_name ]["mean" ].values [0 ]
286+ arm_val = trial_0_data [trial_0_data ["arm_name" ] == treatment_arm .name ][
287+ "mean"
288+ ].values [0 ]
289+ expected = ((arm_val - sq_val ) / sq_val ) * 100
290+
291+ actual = float (
292+ card .df [card .df ["arm_name" ] == treatment_arm .name ]["branin" ]
293+ .values [0 ]
294+ .rstrip ("%" )
295+ )
296+ self .assertAlmostEqual (actual , expected , places = 1 )
0 commit comments