Skip to content

Commit 129ecaf

Browse files
kayweenfacebook-github-bot
authored andcommitted
Raise errors for unsupported hierarchical search space functionalities (#4374)
Summary: Raise errors for unsupported HSS functionalities. Differential Revision: D83626637
1 parent 146168f commit 129ecaf

File tree

6 files changed

+96
-3
lines changed

6 files changed

+96
-3
lines changed

ax/adapter/transforms/choice_encode.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ax.adapter.transforms.utils import ClosestLookupDict, construct_new_search_space
1414
from ax.core.observation import ObservationFeatures
1515
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
16-
from ax.core.search_space import SearchSpace
16+
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
1717
from ax.core.types import TParamValue
1818
from ax.generators.types import TConfig
1919

@@ -69,6 +69,15 @@ def __init__(
6969
zip(transformed_values, p.values)
7070
)
7171

72+
if isinstance(search_space, HierarchicalSearchSpace):
73+
for p_name, p in search_space.parameters.items():
74+
if p.is_hierarchical and p.parameter_type == ParameterType.FLOAT:
75+
raise RuntimeError(
76+
f"{p_name} is a float hierarchical parameter. However, "
77+
"hierarchical parameters have to be integer-valued after Ax "
78+
f"transformations, but {self.__class__.__name__} would skip it."
79+
)
80+
7281
def transform_observation_features(
7382
self, observation_features: list[ObservationFeatures]
7483
) -> list[ObservationFeatures]:

ax/adapter/transforms/one_hot.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ax.adapter.transforms.utils import construct_new_search_space
1717
from ax.core.observation import ObservationFeatures
1818
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
19-
from ax.core.search_space import SearchSpace
19+
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
2020
from ax.core.types import TParameterization, TParamValue
2121
from ax.generators.types import TConfig
2222
from pyre_extensions import assert_is_instance
@@ -119,6 +119,14 @@ def __init__(
119119
f"{p.name}{OH_PARAM_INFIX}_{i}" for i in range(encoded_len)
120120
]
121121

122+
if isinstance(search_space, HierarchicalSearchSpace):
123+
for p_name, p in search_space.parameters.items():
124+
if p.is_hierarchical and p_name in self.encoded_parameters:
125+
raise RuntimeError(
126+
f"Attempt to one-hot encode a hierarchical parameter {p_name}. "
127+
"This is not supported yet."
128+
)
129+
122130
def transform_observation_features(
123131
self, observation_features: list[ObservationFeatures]
124132
) -> list[ObservationFeatures]:

ax/adapter/transforms/tests/test_choice_encode_transform.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,28 @@ def setUp(self) -> None:
9797
def test_init(self) -> None:
9898
self.assertEqual(list(self.t.encoded_parameters.keys()), ["d", "e"])
9999

100+
hierarchical_search_space = HierarchicalSearchSpace(
101+
parameters=[
102+
ChoiceParameter(
103+
name="x0",
104+
parameter_type=ParameterType.FLOAT,
105+
values=[1.0, 2.0],
106+
dependents={1.0: ["x1"], 2.0: ["x2"]},
107+
),
108+
RangeParameter(
109+
name="x1", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT
110+
),
111+
RangeParameter(
112+
name="x2", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT
113+
),
114+
]
115+
)
116+
with self.assertRaisesRegex(
117+
expected_exception=RuntimeError,
118+
expected_regex=("x0 is a float hierarchical parameter*"),
119+
):
120+
ChoiceToNumericChoice(search_space=hierarchical_search_space.clone())
121+
100122
def test_transform_observation_features(self) -> None:
101123
observation_features = self.observation_features
102124
obs_ft2 = deepcopy(observation_features)

ax/adapter/transforms/tests/test_one_hot_transform.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ax.core.observation import ObservationFeatures
1717
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
1818
from ax.core.parameter_constraint import ParameterConstraint
19-
from ax.core.search_space import RobustSearchSpace, SearchSpace
19+
from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace
2020
from ax.utils.common.testutils import TestCase
2121
from ax.utils.testing.core_stubs import (
2222
get_experiment_with_observations,
@@ -80,6 +80,35 @@ def test_Init(self) -> None:
8080
self.assertEqual(list(self.t.encoded_parameters.keys()), ["b"])
8181
self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b"])
8282

83+
hierarchical_search_space = HierarchicalSearchSpace(
84+
parameters=[
85+
ChoiceParameter(
86+
name="x0",
87+
parameter_type=ParameterType.STRING,
88+
values=["a", "b", "c"],
89+
is_ordered=False,
90+
dependents={"a": ["x1"], "b": ["x2"], "c": ["x3"]},
91+
),
92+
RangeParameter(
93+
name="x1", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT
94+
),
95+
RangeParameter(
96+
name="x2", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT
97+
),
98+
RangeParameter(
99+
name="x3", lower=0.0, upper=1.0, parameter_type=ParameterType.FLOAT
100+
),
101+
]
102+
)
103+
with self.assertRaisesRegex(
104+
expected_exception=RuntimeError,
105+
expected_regex=(
106+
"Attempt to one-hot encode a hierarchical parameter x0. "
107+
"This is not supported yet."
108+
),
109+
):
110+
OneHot(search_space=hierarchical_search_space.clone())
111+
83112
def test_TransformObservationFeatures(self) -> None:
84113
observation_features = [self.observation_features]
85114
obs_ft2 = deepcopy(observation_features)

ax/core/parameter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,13 @@ def __init__(
724724
f"Value {value} in `dependents` "
725725
f"argument is not among the parameter values: {self.values}."
726726
)
727+
if len(dependents) >= 3:
728+
raise NotImplementedError(
729+
f"Found {len(dependents)} differet values in the hierarchical "
730+
f"parameter {name}. However, as of now, only <= 2 different values "
731+
"are supported for hierarchical parameters."
732+
)
733+
727734
# NOTE: We don't need to check that dependent parameters actually exist as
728735
# that is done in `HierarchicalSearchSpace` constructor.
729736
self._dependents = dependents

ax/core/tests/test_parameter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,24 @@ def test_Hierarchical(self) -> None:
479479
dependents={"c": "other_param"},
480480
)
481481

482+
# Test case that raises an error when there are more than 2 values in a
483+
# hierarchical choice parameter. This test case should be removed once we add
484+
# support for it.
485+
with self.assertRaisesRegex(
486+
NotImplementedError,
487+
"Found 3 differet values in the hierarchical parameter x.*",
488+
):
489+
ChoiceParameter(
490+
name="x",
491+
parameter_type=ParameterType.STRING,
492+
values=["a", "b", "c"],
493+
dependents={
494+
"a": ["1st_child"],
495+
"b": ["2nd_child"],
496+
"c": ["3rd_child"],
497+
},
498+
)
499+
482500
def test_available_flags(self) -> None:
483501
choice_flags = [
484502
"is_fidelity",

0 commit comments

Comments
 (0)