33# This source code is licensed under the MIT license found in the 
44# LICENSE file in the root directory of this source tree. 
55
6- # pyre-strict 
7- 
86import  numpy  as  np 
97import  pandas  as  pd 
8+ 
109from  ax .analysis .summary  import  Summary 
1110from  ax .api .client  import  Client 
1211from  ax .api .configs  import  RangeParameterConfig 
1312from  ax .core .base_trial  import  TrialStatus 
13+ from  ax .core .map_data  import  MapData 
1414from  ax .core .trial  import  Trial 
1515from  ax .exceptions .core  import  UserInputError 
1616from  ax .utils .common .testutils  import  TestCase 
17- from  ax .utils .testing .core_stubs  import  get_offline_experiments , get_online_experiments 
17+ from  ax .utils .testing .core_stubs  import  (
18+     get_branin_experiment_with_status_quo_trials ,
19+     get_offline_experiments ,
20+     get_online_experiments ,
21+ )
1822from  pyre_extensions  import  assert_is_instance , none_throws 
1923
2024
2125class  TestSummary (TestCase ):
22-     def  test_compute (self ) ->  None :
23-         client  =  Client ()
24-         client .configure_experiment (
26+     def  setUp (self ) ->  None :
27+         super ().setUp ()
28+         self .client  =  Client ()
29+         self .client .configure_experiment (
2530            name = "test_experiment" ,
2631            parameters = [
2732                RangeParameterConfig (
@@ -36,7 +41,10 @@ def test_compute(self) -> None:
3641                ),
3742            ],
3843        )
39-         client .configure_optimization (objective = "foo, bar" )
44+         self .client .configure_optimization (objective = "foo, bar" )
45+ 
46+     def  test_compute (self ) ->  None :
47+         client  =  self .client 
4048
4149        # Get two trials and fail one, giving us a ragged structure 
4250        client .get_next_trials (max_trials = 2 )
@@ -142,23 +150,7 @@ def test_offline(self) -> None:
142150
143151    def  test_trial_indices_filter (self ) ->  None :
144152        """Test that Client.summarize correctly uses Summary.""" 
145-         client  =  Client ()
146-         client .configure_experiment (
147-             name = "test_experiment" ,
148-             parameters = [
149-                 RangeParameterConfig (
150-                     name = "x1" ,
151-                     parameter_type = "float" ,
152-                     bounds = (0 , 1 ),
153-                 ),
154-                 RangeParameterConfig (
155-                     name = "x2" ,
156-                     parameter_type = "float" ,
157-                     bounds = (0 , 1 ),
158-                 ),
159-             ],
160-         )
161-         client .configure_optimization (objective = "foo" )
153+         client  =  self .client 
162154
163155        # Get a trial 
164156        client .get_next_trials (max_trials = 1 )
@@ -228,19 +220,7 @@ def test_trial_status_filter(self) -> None:
228220
229221    def  test_default_excludes_stale_trials (self ) ->  None :
230222        """Test that Summary defaults to excluding STALE trials.""" 
231-         # Set up experiment with basic configuration 
232-         client  =  Client ()
233-         client .configure_experiment (
234-             name = "test_experiment" ,
235-             parameters = [
236-                 RangeParameterConfig (
237-                     name = "x1" ,
238-                     parameter_type = "float" ,
239-                     bounds = (0 , 1 ),
240-                 ),
241-             ],
242-         )
243-         client .configure_optimization (objective = "foo" )
223+         client  =  self .client 
244224
245225        # Create 3 trials with different statuses to test default filtering behavior 
246226        client .get_next_trials (max_trials = 3 )
@@ -275,3 +255,72 @@ def test_default_excludes_stale_trials(self) -> None:
275255        # Verify that no trials in the output have STALE status 
276256        stale_statuses  =  card .df [card .df ["trial_status" ] ==  "STALE" ]
277257        self .assertEqual (len (stale_statuses ), 0 )
258+ 
259+     def  test_metrics_relativized_with_status_quo (self ) ->  None :
260+         """Test that Summary relativizes metrics by default when status 
261+         quos are present.""" 
262+         # Use helper function that creates batch trials with status quo 
263+         experiment  =  get_branin_experiment_with_status_quo_trials (num_sobol_trials = 2 )
264+ 
265+         analysis  =  Summary ()
266+         card  =  analysis .compute (experiment = experiment )
267+ 
268+         with  self .subTest ("subtitle_indicates_relativization" ):
269+             self .assertIn ("relativized" , card .subtitle .lower ())
270+ 
271+         with  self .subTest ("metric_values_formatted_as_percentages" ):
272+             metric_values  =  card .df ["branin" ].dropna ()
273+             self .assertGreater (len (metric_values ), 0 )
274+             for  val  in  metric_values :
275+                 self .assertIsInstance (val , str )
276+                 self .assertTrue (val .endswith ("%" ))
277+ 
278+         with  self .subTest ("relativization_calculation_correct" ):
279+             raw_data  =  experiment .lookup_data ().df 
280+             sq_name  =  none_throws (experiment .status_quo ).name 
281+             trial_0_data  =  raw_data [raw_data ["trial_index" ] ==  0 ]
282+             treatment_arm  =  [a  for  a  in  experiment .trials [0 ].arms  if  a .name  !=  sq_name ][
283+                 0 
284+             ]
285+ 
286+             sq_val  =  trial_0_data [trial_0_data ["arm_name" ] ==  sq_name ]["mean" ].values [0 ]
287+             arm_val  =  trial_0_data [trial_0_data ["arm_name" ] ==  treatment_arm .name ][
288+                 "mean" 
289+             ].values [0 ]
290+             expected  =  ((arm_val  -  sq_val ) /  sq_val ) *  100 
291+ 
292+             actual  =  float (
293+                 card .df [card .df ["arm_name" ] ==  treatment_arm .name ]["branin" ]
294+                 .values [0 ]
295+                 .rstrip ("%" )
296+             )
297+             self .assertAlmostEqual (actual , expected , places = 1 )
298+ 
299+     def  test_mapdata_not_relativized (self ) ->  None :
300+         """Test that Summary does not attempt relativization when data is MapData, 
301+         even when status quo is present.""" 
302+         # Create an experiment with MapData 
303+         experiment  =  get_branin_experiment_with_status_quo_trials (num_sobol_trials = 2 )
304+ 
305+         # Replace the experiment's data with MapData 
306+         map_data  =  MapData (
307+             df = experiment .lookup_data ().df .assign (step = 1.0 )  # Add step column 
308+         )
309+         # Store the MapData in the experiment 
310+         for  trial  in  experiment .trials .values ():
311+             trial_data  =  map_data .filter (trial_indices = [trial .index ])
312+             experiment .attach_data (trial_data , combine_with_last_data = False )
313+ 
314+         # Compute the summary 
315+         analysis  =  Summary ()
316+         card  =  analysis .compute (experiment = experiment )
317+ 
318+         with  self .subTest ("subtitle_does_not_indicate_relativization" ):
319+             self .assertNotIn ("relativized" , card .subtitle .lower ())
320+ 
321+         with  self .subTest ("metric_values_not_formatted_as_percentages" ):
322+             metric_values  =  card .df ["branin" ].dropna ()
323+             self .assertGreater (len (metric_values ), 0 )
324+             # Values should be raw floats, not percentage strings 
325+             for  val  in  metric_values :
326+                 self .assertIsInstance (val , (float , np .floating ))
0 commit comments