77
88import numpy as np
99import pandas as pd
10+
1011from ax .analysis .summary import Summary
1112from ax .api .client import Client
1213from ax .api .configs import RangeParameterConfig
1314from ax .core .base_trial import TrialStatus
15+ from ax .core .map_data import MapData
1416from ax .core .trial import Trial
1517from ax .exceptions .core import UserInputError
1618from ax .utils .common .testutils import TestCase
17- from ax .utils .testing .core_stubs import get_offline_experiments , get_online_experiments
19+ from ax .utils .testing .core_stubs import (
20+ get_branin_experiment_with_status_quo_trials ,
21+ get_offline_experiments ,
22+ get_online_experiments ,
23+ )
1824from pyre_extensions import assert_is_instance , none_throws
1925
2026
2127class TestSummary (TestCase ):
22- def test_compute (self ) -> None :
23- client = Client ()
24- client .configure_experiment (
28+ def setUp (self ) -> None :
29+ super ().setUp ()
30+ self .client = Client ()
31+ self .client .configure_experiment (
2532 name = "test_experiment" ,
2633 parameters = [
2734 RangeParameterConfig (
@@ -36,7 +43,10 @@ def test_compute(self) -> None:
3643 ),
3744 ],
3845 )
39- client .configure_optimization (objective = "foo, bar" )
46+ self .client .configure_optimization (objective = "foo, bar" )
47+
48+ def test_compute (self ) -> None :
49+ client = self .client
4050
4151 # Get two trials and fail one, giving us a ragged structure
4252 client .get_next_trials (max_trials = 2 )
@@ -142,23 +152,7 @@ def test_offline(self) -> None:
142152
143153 def test_trial_indices_filter (self ) -> None :
144154 """Test that Client.summarize correctly uses Summary."""
145- client = Client ()
146- client .configure_experiment (
147- name = "test_experiment" ,
148- parameters = [
149- RangeParameterConfig (
150- name = "x1" ,
151- parameter_type = "float" ,
152- bounds = (0 , 1 ),
153- ),
154- RangeParameterConfig (
155- name = "x2" ,
156- parameter_type = "float" ,
157- bounds = (0 , 1 ),
158- ),
159- ],
160- )
161- client .configure_optimization (objective = "foo" )
155+ client = self .client
162156
163157 # Get a trial
164158 client .get_next_trials (max_trials = 1 )
@@ -228,19 +222,7 @@ def test_trial_status_filter(self) -> None:
228222
229223 def test_default_excludes_stale_trials (self ) -> None :
230224 """Test that Summary defaults to excluding STALE trials."""
231- # Set up experiment with basic configuration
232- client = Client ()
233- client .configure_experiment (
234- name = "test_experiment" ,
235- parameters = [
236- RangeParameterConfig (
237- name = "x1" ,
238- parameter_type = "float" ,
239- bounds = (0 , 1 ),
240- ),
241- ],
242- )
243- client .configure_optimization (objective = "foo" )
225+ client = self .client
244226
245227 # Create 3 trials with different statuses to test default filtering behavior
246228 client .get_next_trials (max_trials = 3 )
@@ -275,3 +257,72 @@ def test_default_excludes_stale_trials(self) -> None:
275257 # Verify that no trials in the output have STALE status
276258 stale_statuses = card .df [card .df ["trial_status" ] == "STALE" ]
277259 self .assertEqual (len (stale_statuses ), 0 )
260+
def test_metrics_relativized_with_status_quo(self) -> None:
    """Check that Summary relativizes metrics by default when status quos
    are present.

    Three things are verified, each in its own sub-test: the card subtitle
    mentions relativization, every non-null ``branin`` cell is a string
    ending in ``%``, and the reported value for one treatment arm matches
    the percent change recomputed from the experiment's raw data.
    """
    # Stub experiment whose trials include a status quo arm.
    exp = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
    card = Summary().compute(experiment=exp)

    with self.subTest("subtitle_indicates_relativization"):
        subtitle = card.subtitle.lower()
        self.assertIn("relativized", subtitle)

    with self.subTest("metric_values_formatted_as_percentages"):
        branin_cells = card.df["branin"].dropna()
        # Guard against a vacuous pass when the column is entirely null.
        self.assertGreater(len(branin_cells), 0)
        for cell in branin_cells:
            self.assertIsInstance(cell, str)
            self.assertTrue(cell.endswith("%"))

    with self.subTest("relativization_calculation_correct"):
        observations = exp.lookup_data().df
        sq_name = none_throws(exp.status_quo).name
        first_trial_obs = observations[observations["trial_index"] == 0]

        # First arm of trial 0 that is not the status quo.
        non_sq_arms = [arm for arm in exp.trials[0].arms if arm.name != sq_name]
        treatment_arm = non_sq_arms[0]

        is_sq_row = first_trial_obs["arm_name"] == sq_name
        is_treatment_row = first_trial_obs["arm_name"] == treatment_arm.name
        sq_val = first_trial_obs.loc[is_sq_row, "mean"].iloc[0]
        arm_val = first_trial_obs.loc[is_treatment_row, "mean"].iloc[0]

        # Percent change of the treatment arm relative to the status quo.
        expected = (arm_val - sq_val) / sq_val * 100

        summary_df = card.df
        in_summary = summary_df["arm_name"] == treatment_arm.name
        reported = summary_df.loc[in_summary, "branin"].iloc[0]
        actual = float(reported.rstrip("%"))

        self.assertAlmostEqual(actual, expected, places=1)
300+
def test_mapdata_not_relativized(self) -> None:
    """Check that Summary does not attempt relativization when the data is
    MapData, even when a status quo is present.

    The stub experiment's data is re-attached trial by trial as MapData;
    the resulting card must not mention relativization and must keep the
    ``branin`` column as floats rather than percentage strings.
    """
    exp = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)

    # Rebuild the experiment's data as MapData by adding a constant
    # ``step`` column to the existing observations.
    stepped_df = exp.lookup_data().df.assign(step=1.0)
    map_data = MapData(df=stepped_df)

    # Attach each trial's slice of the MapData back onto the experiment.
    for trial in exp.trials.values():
        per_trial = map_data.filter(trial_indices=[trial.index])
        exp.attach_data(per_trial, combine_with_last_data=False)

    card = Summary().compute(experiment=exp)

    with self.subTest("subtitle_does_not_indicate_relativization"):
        subtitle = card.subtitle.lower()
        self.assertNotIn("relativized", subtitle)

    with self.subTest("metric_values_not_formatted_as_percentages"):
        branin_cells = card.df["branin"].dropna()
        # Guard against a vacuous pass when the column is entirely null.
        self.assertGreater(len(branin_cells), 0)
        # Cells should be raw floats, not percentage strings.
        float_types = (float, np.floating)
        for cell in branin_cells:
            self.assertIsInstance(cell, float_types)
0 commit comments