|
75 | 75 | ) |
76 | 76 | from ax.utils.testing.mock import mock_botorch_optimize |
77 | 77 | from pandas.testing import assert_frame_equal |
78 | | -from pyre_extensions import assert_is_instance |
| 78 | +from pyre_extensions import assert_is_instance, none_throws |
79 | 79 |
|
80 | 80 | DUMMY_RUN_METADATA_KEY_1 = "test_run_metadata_key_1" |
81 | 81 | DUMMY_RUN_METADATA_KEY_2 = "test_run_metadata_key_2" |
@@ -471,7 +471,7 @@ def test_StatusQuoSetter(self) -> None: |
471 | 471 | sq_parameters["w"] = 3.5 |
472 | 472 | self.experiment.status_quo = Arm(sq_parameters) |
473 | 473 | self.assertEqual(self.experiment.status_quo.parameters["w"], 3.5) |
474 | | - self.assertEqual(self.experiment.status_quo.name, "status_quo_e0") |
| 474 | + self.assertEqual(none_throws(self.experiment.status_quo).name, "status_quo_e0") |
475 | 475 |
|
476 | 476 | # Verify all None values |
477 | 477 | self.experiment.status_quo = Arm({n: None for n in sq_parameters.keys()}) |
@@ -1640,6 +1640,66 @@ def test_to_df(self) -> None: |
1640 | 1640 | ) |
1641 | 1641 | self.assertTrue(df_completed.equals(expected_completed_df)) |
1642 | 1642 |
|
| 1643 | + def test_to_df_with_relativize(self) -> None: |
| 1644 | + """Test the relativize flag in to_df method with status quo.""" |
| 1645 | + # Create an experiment with status quo and completed trials |
| 1646 | + experiment = get_branin_experiment(with_status_quo=True) |
| 1647 | + |
| 1648 | + # Create two completed trials |
| 1649 | + for _ in range(2): |
| 1650 | + sobol_run = get_sobol(search_space=experiment.search_space).gen(n=1) |
| 1651 | + trial = experiment.new_trial(generator_run=sobol_run) |
| 1652 | + trial.mark_running(no_runner_required=True) |
| 1653 | + trial.mark_completed() |
| 1654 | + |
| 1655 | + # Fetch and add status quo data |
| 1656 | + experiment.fetch_data() |
| 1657 | + sq_data = Data( |
| 1658 | + df=pd.DataFrame( |
| 1659 | + [ |
| 1660 | + { |
| 1661 | + "trial_index": i, |
| 1662 | + "arm_name": "status_quo", |
| 1663 | + "metric_name": "branin", |
| 1664 | + "metric_signature": "branin", |
| 1665 | + "mean": 10.0, |
| 1666 | + "sem": 0.1, |
| 1667 | + } |
| 1668 | + for i in range(2) |
| 1669 | + ] |
| 1670 | + ) |
| 1671 | + ) |
| 1672 | + experiment.attach_data(sq_data) |
| 1673 | + |
| 1674 | + # Test without relativization |
| 1675 | + df_no_rel = experiment.to_df(relativize=False) |
| 1676 | + |
| 1677 | + # Test with relativization |
| 1678 | + df_with_rel = experiment.to_df(relativize=True) |
| 1679 | + |
| 1680 | + # Basic structure should be the same |
| 1681 | + self.assertEqual(len(df_with_rel), len(df_no_rel)) |
| 1682 | + self.assertEqual(set(df_with_rel.columns), set(df_no_rel.columns)) |
| 1683 | + |
| 1684 | + # Find metric columns and verify relativization occurred |
| 1685 | + metric_cols = [ |
| 1686 | + col |
| 1687 | + for col in df_no_rel.columns |
| 1688 | + if col |
| 1689 | + not in ["trial_index", "arm_name", "trial_status", "name", "x1", "x2"] |
| 1690 | + ] |
| 1691 | + |
| 1692 | + if metric_cols: |
| 1693 | + metric_name = metric_cols[0] |
| 1694 | + orig_values = df_no_rel[metric_name].dropna() |
| 1695 | + rel_values = df_with_rel[metric_name].dropna() |
| 1696 | + |
| 1697 | + # Values should change for non-status-quo trials |
| 1698 | + non_sq_changed = any( |
| 1699 | + abs(o - r) > 1e-10 for o, r in zip(orig_values, rel_values) if o != 10.0 |
| 1700 | + ) |
| 1701 | + self.assertTrue(non_sq_changed, "Relativization should change some values") |
| 1702 | + |
1643 | 1703 |
|
1644 | 1704 | class ExperimentWithMapDataTest(TestCase): |
1645 | 1705 | def setUp(self) -> None: |
|
0 commit comments