33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6- # pyre-strict
7-
86import numpy as np
97import pandas as pd
8+
109from ax .analysis .summary import Summary
1110from ax .api .client import Client
1211from ax .api .configs import RangeParameterConfig
1312from ax .core .base_trial import TrialStatus
13+ from ax .core .map_data import MapData
1414from ax .core .trial import Trial
1515from ax .exceptions .core import UserInputError
1616from ax .utils .common .testutils import TestCase
17- from ax .utils .testing .core_stubs import get_offline_experiments , get_online_experiments
17+ from ax .utils .testing .core_stubs import (
18+ get_branin_experiment_with_status_quo_trials ,
19+ get_offline_experiments ,
20+ get_online_experiments ,
21+ )
1822from pyre_extensions import assert_is_instance , none_throws
1923
2024
2125class TestSummary (TestCase ):
22- def test_compute (self ) -> None :
23- client = Client ()
24- client .configure_experiment (
26+ def setUp (self ) -> None :
27+ super ().setUp ()
28+ self .client = Client ()
29+ self .client .configure_experiment (
2530 name = "test_experiment" ,
2631 parameters = [
2732 RangeParameterConfig (
@@ -36,7 +41,10 @@ def test_compute(self) -> None:
3641 ),
3742 ],
3843 )
39- client .configure_optimization (objective = "foo, bar" )
44+ self .client .configure_optimization (objective = "foo, bar" )
45+
46+ def test_compute (self ) -> None :
47+ client = self .client
4048
4149 # Get two trials and fail one, giving us a ragged structure
4250 client .get_next_trials (max_trials = 2 )
@@ -142,23 +150,7 @@ def test_offline(self) -> None:
142150
143151 def test_trial_indices_filter (self ) -> None :
144152 """Test that Client.summarize correctly uses Summary."""
145- client = Client ()
146- client .configure_experiment (
147- name = "test_experiment" ,
148- parameters = [
149- RangeParameterConfig (
150- name = "x1" ,
151- parameter_type = "float" ,
152- bounds = (0 , 1 ),
153- ),
154- RangeParameterConfig (
155- name = "x2" ,
156- parameter_type = "float" ,
157- bounds = (0 , 1 ),
158- ),
159- ],
160- )
161- client .configure_optimization (objective = "foo" )
153+ client = self .client
162154
163155 # Get a trial
164156 client .get_next_trials (max_trials = 1 )
@@ -228,19 +220,7 @@ def test_trial_status_filter(self) -> None:
228220
229221 def test_default_excludes_stale_trials (self ) -> None :
230222 """Test that Summary defaults to excluding STALE trials."""
231- # Set up experiment with basic configuration
232- client = Client ()
233- client .configure_experiment (
234- name = "test_experiment" ,
235- parameters = [
236- RangeParameterConfig (
237- name = "x1" ,
238- parameter_type = "float" ,
239- bounds = (0 , 1 ),
240- ),
241- ],
242- )
243- client .configure_optimization (objective = "foo" )
223+ client = self .client
244224
245225 # Create 3 trials with different statuses to test default filtering behavior
246226 client .get_next_trials (max_trials = 3 )
@@ -275,3 +255,72 @@ def test_default_excludes_stale_trials(self) -> None:
275255 # Verify that no trials in the output have STALE status
276256 stale_statuses = card .df [card .df ["trial_status" ] == "STALE" ]
277257 self .assertEqual (len (stale_statuses ), 0 )
258+
def test_metrics_relativized_with_status_quo(self) -> None:
    """Check that Summary reports relativized metrics when a status quo exists.

    Three things are verified on the computed card: the subtitle mentions
    relativization, every non-null "branin" entry is a string ending in
    "%", and the value shown for one treatment arm of trial 0 agrees (to one
    decimal place) with the percent change against the status quo arm
    recomputed from the experiment's raw data.
    """
    # The stub builds an experiment whose trials include a status quo arm.
    exp = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)
    card = Summary().compute(experiment=exp)

    with self.subTest("subtitle_indicates_relativization"):
        subtitle = card.subtitle.lower()
        self.assertIn("relativized", subtitle)

    with self.subTest("metric_values_formatted_as_percentages"):
        branin_cells = card.df["branin"].dropna()
        self.assertGreater(len(branin_cells), 0)
        for cell in branin_cells:
            self.assertIsInstance(cell, str)
            self.assertTrue(cell.endswith("%"))

    with self.subTest("relativization_calculation_correct"):
        observations = exp.lookup_data().df
        status_quo_name = none_throws(exp.status_quo).name
        first_trial_rows = observations[observations["trial_index"] == 0]
        non_sq_arms = [
            arm for arm in exp.trials[0].arms if arm.name != status_quo_name
        ]
        treatment = non_sq_arms[0]

        def first_mean(arm_name):
            # First "mean" entry recorded for the given arm within trial 0.
            is_arm = first_trial_rows["arm_name"] == arm_name
            return first_trial_rows.loc[is_arm, "mean"].values[0]

        baseline = first_mean(status_quo_name)
        treated = first_mean(treatment.name)
        expected_pct = ((treated - baseline) / baseline) * 100

        # The card stores the value as text such as "12.3%"; strip the sign
        # before converting back to a number.
        shown = card.df.loc[card.df["arm_name"] == treatment.name, "branin"]
        actual_pct = float(shown.values[0].rstrip("%"))
        self.assertAlmostEqual(actual_pct, expected_pct, places=1)
def test_mapdata_not_relativized(self) -> None:
    """Check that Summary leaves metrics un-relativized for MapData.

    Even though the experiment has a status quo, once its data is attached
    as MapData the card subtitle must not mention relativization and the
    "branin" column must hold plain floats rather than percent strings.
    """
    exp = get_branin_experiment_with_status_quo_trials(num_sobol_trials=2)

    # Build MapData from the existing observations by adding a "step" column.
    frame_with_step = exp.lookup_data().df.assign(step=1.0)
    full_map_data = MapData(df=frame_with_step)

    # Attach the MapData slice belonging to each trial back onto the
    # experiment, one trial at a time.
    for trial in exp.trials.values():
        per_trial = full_map_data.filter(trial_indices=[trial.index])
        exp.attach_data(per_trial, combine_with_last_data=False)

    card = Summary().compute(experiment=exp)

    with self.subTest("subtitle_does_not_indicate_relativization"):
        subtitle = card.subtitle.lower()
        self.assertNotIn("relativized", subtitle)

    with self.subTest("metric_values_not_formatted_as_percentages"):
        branin_cells = card.df["branin"].dropna()
        self.assertGreater(len(branin_cells), 0)
        # Raw numeric values are expected here, not "NN%" strings.
        numeric_types = (float, np.floating)
        for cell in branin_cells:
            self.assertIsInstance(cell, numeric_types)
0 commit comments