|
16 | 16 | from ax.core.observation import ObservationFeatures |
17 | 17 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter |
18 | 18 | from ax.core.parameter_constraint import ParameterConstraint |
19 | | -from ax.core.search_space import RobustSearchSpace, SearchSpace |
| 19 | +from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace |
20 | 20 | from ax.utils.common.testutils import TestCase |
21 | 21 | from ax.utils.testing.core_stubs import ( |
22 | 22 | get_experiment_with_observations, |
@@ -80,6 +80,35 @@ def test_Init(self) -> None: |
80 | 80 | self.assertEqual(list(self.t.encoded_parameters.keys()), ["b"]) |
81 | 81 | self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b"]) |
82 | 82 |
|
| 83 | + hierarchical_search_space = HierarchicalSearchSpace( |
| 84 | + parameters=[ |
| 85 | + ChoiceParameter( |
| 86 | + name="x0", |
| 87 | + parameter_type=ParameterType.STRING, |
| 88 | + values=["a", "b", "c"], |
| 89 | + is_ordered=False, |
| 90 | + dependents={"a": ["x1"], "b": ["x2"], "c": ["x3"]}, |
| 91 | + ), |
| 92 | + RangeParameter( |
| 93 | + name="x1", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT |
| 94 | + ), |
| 95 | + RangeParameter( |
| 96 | + name="x2", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT |
| 97 | + ), |
| 98 | + RangeParameter( |
| 99 | + name="x3", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT |
| 100 | + ), |
| 101 | + ] |
| 102 | + ) |
| 103 | + with self.assertRaisesRegex( |
| 104 | + expected_exception=RuntimeError, |
| 105 | + expected_regex=( |
| 106 | + "Attempt to one-hot encode a hierarchical parameter x0. " |
| 107 | + "This is not supported yet." |
| 108 | + ), |
| 109 | + ): |
| 110 | + OneHot(search_space=hierarchical_search_space.clone()) |
| 111 | + |
83 | 112 | def test_TransformObservationFeatures(self) -> None: |
84 | 113 | observation_features = [self.observation_features] |
85 | 114 | obs_ft2 = deepcopy(observation_features) |
|
0 commit comments