Skip to content

Commit 5a5173b

Browse files
Balandat and facebook-github-bot
authored and committed
MixedSingleTaskGP (#772)
Summary: Pull Request resolved: #772 A `MixedSingleTaskGP` that uses a combination of the categorical kernel and a kernel on the continuous inputs. Reviewed By: dme65 Differential Revision: D27419521 fbshipit-source-id: bb22623154b3b06d876d71605f1428db9d0f58cb
1 parent f402f2e commit 5a5173b

File tree

10 files changed

+512
-28
lines changed

10 files changed

+512
-28
lines changed

botorch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SingleTaskGP,
1616
)
1717
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
18+
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
1819
from botorch.models.higher_order_gp import HigherOrderGP
1920
from botorch.models.model_list_gp_regression import ModelListGP
2021
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
@@ -28,6 +29,7 @@
2829
"GenericDeterministicModel",
2930
"HeteroskedasticSingleTaskGP",
3031
"HigherOrderGP",
32+
"MixedSingleTaskGP",
3133
"ModelListGP",
3234
"MultiTaskGP",
3335
"PairwiseGP",

botorch/models/converter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from botorch.exceptions import UnsupportedError
1818
from botorch.models.gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP
1919
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
20+
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
2021
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
2122
from botorch.models.model_list_gp_regression import ModelListGP
2223
from botorch.models.transforms.input import InputTransform
@@ -207,6 +208,10 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
207208
raise NotImplementedError(
208209
"Conversion of HeteroskedasticSingleTaskGP currently not supported."
209210
)
211+
if isinstance(batch_model, MixedSingleTaskGP):
212+
raise NotImplementedError(
213+
"Conversion of MixedSingleTaskGP currently not supported."
214+
)
210215
input_transform = getattr(batch_model, "input_transform", None)
211216
batch_sd = batch_model.state_dict()
212217

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from typing import Callable
10+
from typing import Dict, List, Optional, Any
11+
12+
import torch
13+
from botorch.exceptions.errors import UnsupportedError
14+
from botorch.models.gp_regression import SingleTaskGP
15+
from botorch.models.kernels.categorical import CategoricalKernel
16+
from botorch.models.transforms.input import InputTransform
17+
from botorch.models.transforms.outcome import OutcomeTransform
18+
from botorch.utils.containers import TrainingData
19+
from botorch.utils.transforms import normalize_indices
20+
from gpytorch.constraints import GreaterThan
21+
from gpytorch.kernels.kernel import Kernel
22+
from gpytorch.kernels.matern_kernel import MaternKernel
23+
from gpytorch.kernels.scale_kernel import ScaleKernel
24+
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
25+
from gpytorch.likelihoods.likelihood import Likelihood
26+
from gpytorch.priors import GammaPrior
27+
from torch import Tensor
28+
29+
30+
class MixedSingleTaskGP(SingleTaskGP):
    r"""A single-task exact GP model for mixed search spaces.

    This model uses a kernel that combines a CategoricalKernel (based on
    Hamming distances) and a regular kernel into a kernel of the form

        K((x1, c1), (x2, c2)) =
            K_cont_1(x1, x2) + K_cat_1(c1, c2) +
            K_cont_2(x1, x2) * K_cat_2(c1, c2)

    where `xi` and `ci` are the continuous and categorical features of the
    input, respectively. The suffix `_i` indicates that we fit different
    lengthscales for the kernels in the sum and product terms.

    Since this model does not provide gradients for the categorical features,
    optimization of the acquisition function will need to be performed in
    a mixed fashion, i.e., treating the categorical features properly as
    discrete optimization variables.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        cat_dims: List[int],
        cont_kernel_factory: Optional[
            Callable[[torch.Size, int, List[int]], Kernel]
        ] = None,
        likelihood: Optional[Likelihood] = None,
        outcome_transform: Optional[OutcomeTransform] = None,  # TODO
        input_transform: Optional[InputTransform] = None,  # TODO
    ) -> None:
        r"""A single-task exact GP model supporting categorical parameters.

        Args:
            train_X: A `batch_shape x n x d` tensor of training features.
            train_Y: A `batch_shape x n x m` tensor of training observations.
            cat_dims: A non-empty list of indices corresponding to the columns
                of the input `X` that should be considered categorical
                features. Indices are passed through `normalize_indices`
                (the docstring example below uses `-1`).
            cont_kernel_factory: A method that accepts `batch_shape`,
                `ard_num_dims`, and `active_dims` arguments (it is called with
                these as keyword arguments) and returns an instantiated
                GPyTorch `Kernel` object to be used as the base kernel for the
                continuous dimensions. If omitted, this model uses a
                Matern-2.5 kernel as the kernel for the ordinal parameters.
            likelihood: A likelihood. If omitted, use a standard
                GaussianLikelihood with inferred noise level.
            outcome_transform: Not yet supported; must be `None`.
            input_transform: Not yet supported; must be `None`.

        Raises:
            UnsupportedError: If `outcome_transform` or `input_transform` is
                not `None`.
            ValueError: If `cat_dims` is empty.

        Example:
            >>> train_X = torch.cat(
                    [torch.rand(20, 2), torch.randint(3, (20, 1))], dim=-1
                )
            >>> train_Y = (
                    torch.sin(train_X[..., :-1]).sum(dim=1, keepdim=True)
                    + train_X[..., -1:]
                )
            >>> model = MixedSingleTaskGP(train_X, train_Y, cat_dims=[-1])
        """
        if outcome_transform is not None:
            raise UnsupportedError("outcome transforms not yet supported")
        if input_transform is not None:
            raise UnsupportedError("input transforms not yet supported")
        if len(cat_dims) == 0:
            raise ValueError(
                "Must specify categorical dimensions for MixedSingleTaskGP"
            )
        # Only the augmented batch shape is needed to size the kernels and
        # the default likelihood.
        _, aug_batch_shape = self.get_batch_dimensions(
            train_X=train_X, train_Y=train_Y
        )

        if cont_kernel_factory is None:

            def cont_kernel_factory(
                batch_shape: torch.Size, ard_num_dims: int, active_dims: List[int]
            ) -> MaternKernel:
                return MaternKernel(
                    nu=2.5,
                    batch_shape=batch_shape,
                    ard_num_dims=ard_num_dims,
                    active_dims=active_dims,
                )

        if likelihood is None:
            # Use a larger noise floor for single-precision training data.
            min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6
            likelihood = GaussianLikelihood(
                batch_shape=aug_batch_shape,
                noise_constraint=GreaterThan(
                    min_noise, transform=None, initial_value=1e-3
                ),
                # This Gamma prior is quite close to the Horseshoe prior
                noise_prior=GammaPrior(0.9, 10.0),
            )

        d = train_X.shape[-1]
        cat_dims = normalize_indices(indices=cat_dims, d=d)
        # Every column that is not categorical is treated as continuous/ordinal.
        ord_dims = sorted(set(range(d)) - set(cat_dims))
        if len(ord_dims) == 0:
            # All inputs are categorical: a single scaled categorical kernel
            # over every column (hence no `active_dims`).
            covar_module = ScaleKernel(
                CategoricalKernel(
                    batch_shape=aug_batch_shape,
                    ard_num_dims=len(cat_dims),
                )
            )
        else:
            # Separate kernel instances are built for the sum and the product
            # term so that each term gets its own lengthscales.
            sum_kernel = ScaleKernel(
                cont_kernel_factory(
                    batch_shape=aug_batch_shape,
                    ard_num_dims=len(ord_dims),
                    active_dims=ord_dims,
                )
                + ScaleKernel(
                    CategoricalKernel(
                        batch_shape=aug_batch_shape,
                        ard_num_dims=len(cat_dims),
                        active_dims=cat_dims,
                    )
                )
            )
            prod_kernel = ScaleKernel(
                cont_kernel_factory(
                    batch_shape=aug_batch_shape,
                    ard_num_dims=len(ord_dims),
                    active_dims=ord_dims,
                )
                * CategoricalKernel(
                    batch_shape=aug_batch_shape,
                    ard_num_dims=len(cat_dims),
                    active_dims=cat_dims,
                )
            )
            covar_module = sum_kernel + prod_kernel
        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=likelihood,
            covar_module=covar_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

    @classmethod
    def construct_inputs(
        cls, training_data: TrainingData, **kwargs: Any
    ) -> Dict[str, Any]:
        r"""Construct kwargs for the `Model` from `TrainingData` and other options.

        Args:
            training_data: `TrainingData` container with data for single outcome
                or for multiple outcomes for batched multi-output case.
            **kwargs: Must contain `categorical_features` (forwarded as
                `cat_dims`); may contain `likelihood` (`None` if absent).

        Returns:
            A dict with the keys `train_X`, `train_Y`, `cat_dims` and
            `likelihood`.

        Raises:
            KeyError: If `categorical_features` is not passed in `kwargs`.
        """
        return {
            "train_X": training_data.X,
            "train_Y": training_data.Y,
            "cat_dims": kwargs["categorical_features"],
            "likelihood": kwargs.get("likelihood"),
        }

botorch/utils/testing.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,24 +198,27 @@ def set_X_pending(self, X_pending: Optional[Tensor] = None):
198198

199199

200200
def _get_random_data(
201-
batch_shape: torch.Size, num_outputs: int, n: int = 10, **tkwargs
201+
batch_shape: torch.Size, m: int, d: int = 1, n: int = 10, **tkwargs
202202
) -> Tuple[Tensor, Tensor]:
203203
r"""Generate random data for testing pursposes.
204204
205205
Args:
206206
batch_shape: The batch shape of the data.
207-
num_outputs: The number of outputs.
207+
m: The number of outputs.
208+
d: The dimension of the input.
208209
n: The number of data points.
209210
tkwargs: `device` and `dtype` tensor constructor kwargs.
210211
211212
Returns:
212213
A tuple `(train_X, train_Y)` with randomly generated training data.
213214
"""
214215
rep_shape = batch_shape + torch.Size([1, 1])
215-
train_x = torch.linspace(0, 0.95, n, **tkwargs).unsqueeze(-1)
216-
train_x = train_x + 0.05 * torch.rand(n, 1, **tkwargs).repeat(rep_shape)
217-
train_y = torch.sin(train_x * (2 * math.pi))
218-
train_y = train_y + 0.2 * torch.randn(n, num_outputs, **tkwargs).repeat(rep_shape)
216+
train_x = torch.stack(
217+
[torch.linspace(0, 0.95, n, **tkwargs) for _ in range(d)], dim=-1
218+
)
219+
train_x = train_x + 0.05 * torch.rand_like(train_x).repeat(rep_shape)
220+
train_y = torch.sin(train_x[..., :1] * (2 * math.pi))
221+
train_y = train_y + 0.2 * torch.randn(n, m, **tkwargs).repeat(rep_shape)
219222
return train_x, train_y
220223

221224

sphinx/source/models.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ Multi-Fidelity GP Regression Models
4444
.. automodule:: botorch.models.gp_regression_fidelity
4545
:members:
4646

47+
GP Regression Models for Mixed Parameter Spaces
48+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
49+
.. automodule:: botorch.models.gp_regression_mixed
50+
:members:
51+
4752
Model List GP Regression Models
4853
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4954
.. automodule:: botorch.models.model_list_gp_regression

test/models/test_gp_regression.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ class TestSingleTaskGP(BotorchTestCase):
3939
def _get_model_and_data(
4040
self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs
4141
):
42-
train_X, train_Y = _get_random_data(
43-
batch_shape=batch_shape, num_outputs=m, **tkwargs
44-
)
42+
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
4543
model_kwargs = {
4644
"train_X": train_X,
4745
"train_Y": train_Y,
@@ -174,7 +172,7 @@ def test_condition_on_observations(self):
174172
fant_shape = torch.Size([2])
175173
# fantasize at different input points
176174
X_fant, Y_fant = _get_random_data(
177-
fant_shape + batch_shape, m, n=3, **tkwargs
175+
batch_shape=fant_shape + batch_shape, m=m, n=3, **tkwargs
178176
)
179177
c_kwargs = (
180178
{"noise": torch.full_like(Y_fant, 0.01)}
@@ -319,9 +317,7 @@ class TestFixedNoiseGP(TestSingleTaskGP):
319317
def _get_model_and_data(
320318
self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs
321319
):
322-
train_X, train_Y = _get_random_data(
323-
batch_shape=batch_shape, num_outputs=m, **tkwargs
324-
)
320+
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
325321
model_kwargs = {
326322
"train_X": train_X,
327323
"train_Y": train_Y,
@@ -381,9 +377,7 @@ def _get_model_and_data(
381377
self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs
382378
):
383379
with manual_seed(0):
384-
train_X, train_Y = _get_random_data(
385-
batch_shape=batch_shape, num_outputs=m, **tkwargs
386-
)
380+
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
387381
train_Yvar = (0.1 + 0.1 * torch.rand_like(train_Y)) ** 2
388382
model_kwargs = {
389383
"train_X": train_X,

test/models/test_gp_regression_fidelity.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import itertools
88
import warnings
9+
from typing import Tuple
910

1011
import torch
1112
from botorch import fit_gpytorch_model
@@ -25,13 +26,18 @@
2526
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
2627
from gpytorch.means import ConstantMean
2728
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
29+
from torch import Tensor
2830

2931

30-
def _get_random_data_with_fidelity(batch_shape, m, n_fidelity, n=10, **tkwargs):
32+
def _get_random_data_with_fidelity(
33+
batch_shape: torch.Size, m: int, n_fidelity: int, d: int = 1, n: int = 10, **tkwargs
34+
) -> Tuple[Tensor, Tensor]:
3135
r"""Construct test data.
3236
For this test, by convention the trailing dimensions are the fidelity dimensions
3337
"""
34-
train_x, train_y = _get_random_data(batch_shape, m, n, **tkwargs)
38+
train_x, train_y = _get_random_data(
39+
batch_shape=batch_shape, m=m, d=d, n=n, **tkwargs
40+
)
3541
s = torch.rand(n, n_fidelity, **tkwargs).repeat(batch_shape + torch.Size([1, 1]))
3642
train_x = torch.cat((train_x, s), dim=-1)
3743
train_y = train_y + (1 - s).pow(2).sum(dim=-1).unsqueeze(-1)

0 commit comments

Comments
 (0)