50 changes: 39 additions & 11 deletions botorch/models/gpytorch.py
@@ -36,9 +36,11 @@
)
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.utils.multitask import separate_mtmvn
from botorch.utils.transforms import is_ensemble
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
from torch import Tensor

if TYPE_CHECKING:
@@ -101,6 +103,7 @@ def _validate_tensor_args(
"following error would have been raised with strict enforcement: "
f"{message}",
BotorchTensorDimensionWarning,
stacklevel=2,
)
# Yvar may not have the same batch dimensions, but the trailing dimensions
# of Yvar should be the same as the trailing dimensions of Y.
@@ -559,7 +562,8 @@ class ModelListGPyTorchModel(ModelList, GPyTorchModel, ABC):
r"""Abstract base class for models based on multi-output GPyTorch models.

This is meant to be used with a gpytorch ModelList wrapper for independent
evaluation of submodels.
evaluation of submodels. Those submodels can themselves be multi-output
models, in which case the task covariances will be ignored.

:meta private:
"""
@@ -582,7 +586,7 @@ def batch_shape(self) -> torch.Size:
)
try:
broadcast_shape = torch.broadcast_shapes(*batch_shapes)
warnings.warn(msg + ". Broadcasting batch shapes.")
warnings.warn(msg + ". Broadcasting batch shapes.", stacklevel=2)
return broadcast_shape
except RuntimeError:
raise NotImplementedError(msg + " that are not broadcastable.")
@@ -598,6 +602,9 @@ def posterior(
**kwargs: Any,
) -> Union[GPyTorchPosterior, PosteriorList]:
r"""Computes the posterior over model outputs at the provided points.
If any model returns a MultitaskMultivariateNormal posterior, then that
will be split into individual MVNs per task, with inter-task covariance
ignored.

Args:
X: A `b x q x d`-dim Tensor, where `d` is the dimension of the
@@ -648,20 +655,41 @@
)
if not returns_untransformed:
mvns = [p.distribution for p in posterior.posteriors]
# Combining MTMVNs into a single MTMVN is currently not supported.
if not any(isinstance(m, MultitaskMultivariateNormal) for m in mvns):
# Return the result as a GPyTorchPosterior/GaussianMixturePosterior.
if any(isinstance(m, MultitaskMultivariateNormal) for m in mvns):
mvn_list = []
for mvn in mvns:
if len(mvn.event_shape) == 2:
# We separate MTMVNs into independent-across-task MVNs for
# the convenience of using BlockDiagLinearOperator below.
# (b x q x m x m) -> list of m (b x q x 1 x 1)
mvn_list.extend(separate_mtmvn(mvn))
else:
mvn_list.append(mvn)
mean = torch.stack([mvn.mean for mvn in mvn_list], dim=-1)
covars = CatLinearOperator(
*[mvn.lazy_covariance_matrix.unsqueeze(-3) for mvn in mvn_list],
dim=-3,
)  # list of m (b x q x q) -> (b x m x q x q)
mvn = MultitaskMultivariateNormal(
mean=mean,
covariance_matrix=BlockDiagLinearOperator(covars, block_dim=-3).to(
X
),  # block-diag: (b x m x q x q) -> (b x (m * q) x (m * q))
interleaved=False,
)
else:
mvn = (
mvns[0]
if len(mvns) == 1
else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
)
if any(is_ensemble(m) for m in self.models):
# Mixing fully Bayesian and other GP models is currently
# not supported.
posterior = GaussianMixturePosterior(distribution=mvn)
else:
posterior = GPyTorchPosterior(distribution=mvn)
# Return the result as a GPyTorchPosterior/GaussianMixturePosterior.
if any(is_ensemble(m) for m in self.models):
# Mixing fully Bayesian and other GP models is currently
# not supported.
posterior = GaussianMixturePosterior(distribution=mvn)
else:
posterior = GPyTorchPosterior(distribution=mvn)
if posterior_transform is not None:
return posterior_transform(posterior)
return posterior
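
For intuition, here is a self-contained sketch of the recombination step above: two independent MVNs over q = 3 points are stacked into a single non-interleaved MultitaskMultivariateNormal whose covariance is block-diagonal across tasks (toy covariances, purely illustrative):

import torch
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from linear_operator import to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator

q = 3
mvns = [
    MultivariateNormal(torch.zeros(q), to_linear_operator(torch.eye(q))),
    MultivariateNormal(torch.ones(q), to_linear_operator(2 * torch.eye(q))),
]
mean = torch.stack([m.mean for m in mvns], dim=-1)  # (q x m)
covars = CatLinearOperator(
    *[m.lazy_covariance_matrix.unsqueeze(-3) for m in mvns], dim=-3
)  # list of m (1 x q x q) -> (m x q x q)
mtmvn = MultitaskMultivariateNormal(
    mean=mean,
    covariance_matrix=BlockDiagLinearOperator(covars, block_dim=-3),  # (m*q x m*q)
    interleaved=False,
)
print(mtmvn.event_shape)  # torch.Size([3, 2])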
52 changes: 38 additions & 14 deletions botorch/models/model_list_gp_regression.py
@@ -13,6 +13,8 @@
from copy import deepcopy
from typing import Any, List

import torch

from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.gpytorch import GPyTorchModel, ModelListGPyTorchModel
from botorch.models.model import FantasizeMixin
@@ -87,34 +89,56 @@ def condition_on_observations(
f"{Y.shape[-1]} observation outputs, but model has "
f"{self.num_outputs} outputs."
)
targets = [Y[..., i] for i in range(Y.shape[-1])]
for i, model in enumerate(self.models):
if hasattr(model, "outcome_transform"):
noise = kwargs.get("noise")
targets[i], noise = model.outcome_transform(targets[i], noise)

# This should never trigger, posterior call would fail.
assert len(targets) == len(X)
if len(X) != self.num_outputs:
raise BotorchTensorDimensionError(
"Incorrect number of inputs for observations. Received "
f"{len(X)} observation inputs, but model has "
f"{self.num_outputs} outputs."
)
if "noise" in kwargs:
noise = kwargs.pop("noise")
if noise.shape != Y.shape[-noise.dim() :]:
raise BotorchTensorDimensionError(
"The shape of observation noise does not agree with the outcomes. "
f"Received {noise.shape} noise with {Y.shape} outcomes."
)
kwargs_ = {**kwargs, "noise": [noise[..., i] for i in range(Y.shape[-1])]}

else:
kwargs_ = kwargs
return super().get_fantasy_model(X, targets, **kwargs_)
noise = None
targets = []
inputs = []
noises = []
i = 0
for model in self.models:
j = i + model.num_outputs
y_i = torch.cat([Y[..., k] for k in range(i, j)], dim=-1)
X_i = torch.cat([X[k] for k in range(i, j)], dim=-2)
if noise is None:
noise_i = None
else:
noise_i = torch.cat([noise[..., k] for k in range(i, j)], dim=-1)
if hasattr(model, "outcome_transform"):
y_i, noise_i = model.outcome_transform(y_i, noise_i)
if noise_i is not None:
noise_i = noise_i.squeeze(0)
targets.append(y_i)
inputs.append(X_i)
noises.append(noise_i)
i += model.num_outputs

kwargs_ = {**kwargs, "noise": noises} if noise is not None else kwargs
return super().get_fantasy_model(inputs, targets, **kwargs_)
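
A usage sketch of the per-submodel splitting above, under the same hypothetical setup as the earlier sketch (one single-output SingleTaskGP plus one two-output MultiTaskGP): X is a list with one entry per output, where the MultiTaskGP entries carry the task feature, and Y has one column per output:

import torch
from botorch.models import ModelListGP, MultiTaskGP, SingleTaskGP

train_X = torch.rand(10, 1)
train_Y = torch.rand(10, 1)
task = torch.cat([torch.zeros(5, 1), torch.ones(5, 1)])
m1 = SingleTaskGP(train_X=train_X, train_Y=train_Y)
m2 = MultiTaskGP(
    train_X=torch.cat([train_X, task], dim=-1),
    train_Y=train_Y,
    task_feature=-1,
)
model = ModelListGP(m1, m2)
model.posterior(train_X)  # evaluate once so the prediction caches exist

new_X = torch.rand(5, 1)
X_t0 = torch.cat([new_X, torch.zeros(5, 1)], dim=-1)  # task-0 rows for m2
X_t1 = torch.cat([new_X, torch.ones(5, 1)], dim=-1)  # task-1 rows for m2
new_Y = torch.rand(5, 3)  # one column per output: m1, m2/task 0, m2/task 1
cm = model.condition_on_observations([new_X, X_t0, X_t1], new_Y)
print(cm.num_outputs)  # 3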

def subset_output(self, idcs: List[int]) -> ModelListGP:
r"""Subset the model along the output dimension.
r"""Subset the model along the submodel dimension.

Args:
idcs: The output indices to subset the model to.
idcs: The indices of submodels to subset the model to.

Returns:
The current model, subset to the specified output indices.
The current model, subset to the specified submodels. If each model
is single-output, this will correspond to that subset of the
outputs.
"""
return self.__class__(*[deepcopy(self.models[i]) for i in idcs])
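
Continuing the hypothetical `model` from the sketch above (one single-output submodel plus one two-output submodel), subsetting selects whole submodels rather than individual outputs:

subset = model.subset_output([1])  # keeps only the two-output MultiTaskGP
print(subset.num_outputs)  # 2, since the retained submodel has two outputs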

48 changes: 48 additions & 0 deletions botorch/utils/multitask.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Helpers for multitask modeling.
"""

from __future__ import annotations

from typing import List

import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from linear_operator import to_linear_operator


def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> List[MultivariateNormal]:
"""
Separate an MTMVN into a list of MVNs, where the covariance across data within
each task is preserved, while the covariance across tasks is dropped.
"""
# TODO T150340766 Upstream this into a class method on gpytorch MultitaskMultivariateNormal.
full_covar = mvn.lazy_covariance_matrix
num_data, num_tasks = mvn.mean.shape[-2:]
if mvn._interleaved:
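# Interleaved layout: tasks vary fastest, so the row for data point i
# and task t of the joint covariance is i * num_tasks + t.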
data_indices = torch.arange(
0, num_data * num_tasks, num_tasks, device=full_covar.device
).view(-1, 1, 1)
task_indices = torch.arange(num_tasks, device=full_covar.device)
else:
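# Non-interleaved layout: data vary fastest, so the row for data point i
# and task t is t * num_data + i.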
data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
task_indices = torch.arange(
0, num_data * num_tasks, num_data, device=full_covar.device
)
slice_ = (data_indices + task_indices).transpose(-1, -3)
data_covars = full_covar[..., slice_, slice_.transpose(-1, -2)]
mvns = []
for c in range(num_tasks):
mvns.append(
MultivariateNormal(
mvn.mean[..., c], to_linear_operator(data_covars[..., c, :, :])
)
)
return mvns
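
A quick usage sketch (toy, randomly generated numbers): build an interleaved MTMVN over q = 4 points and m = 2 tasks with nonzero inter-task covariance, then check that separate_mtmvn recovers per-task marginals of the right shape:

import torch
from gpytorch.distributions import MultitaskMultivariateNormal

q, m = 4, 2
mean = torch.randn(q, m)
root = torch.randn(q * m, q * m)
covar = root @ root.transpose(-1, -2) + 1e-2 * torch.eye(q * m)  # PSD joint covariance
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=True)
mvns = separate_mtmvn(mtmvn)
assert len(mvns) == m
for t, mvn in enumerate(mvns):
    assert torch.allclose(mvn.mean, mtmvn.mean[..., t])
    assert mvn.covariance_matrix.shape == torch.Size([q, q])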
45 changes: 42 additions & 3 deletions test/models/test_gpytorch.py
@@ -23,13 +23,14 @@
ModelListGPyTorchModel,
)
from botorch.models.model import FantasizeMixin
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms import Standardize
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
from botorch.models.utils import fantasize
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.test_helpers import SimpleGPyTorchModel
from botorch.utils.testing import BotorchTestCase
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
@@ -441,7 +442,44 @@ def test_model_list_gpytorch_model(self):
posterior = model.posterior(test_X)
self.assertIsInstance(posterior, GPyTorchPosterior)
self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
# test multioutput
train_x_raw, train_y = _get_random_data(
batch_shape=torch.Size(), m=1, n=10, **tkwargs
)
task_idx = torch.cat(
[torch.ones(5, 1, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=0
)
train_x = torch.cat([train_x_raw, task_idx], dim=-1)
model_mt = MultiTaskGP(
train_X=train_x,
train_Y=train_y,
task_feature=-1,
)
mt_posterior = model_mt.posterior(test_X)
model = SimpleModelListGPyTorchModel(m1, model_mt, m2)
posterior2 = model.posterior(test_X)
expected_mean = torch.cat(
(
posterior.mean[:, 0].unsqueeze(-1),
mt_posterior.mean,
posterior.mean[:, 1].unsqueeze(-1),
),
dim=1,
)
self.assertTrue(torch.allclose(expected_mean, posterior2.mean))
expected_covariance = torch.block_diag(
posterior.covariance_matrix[:2, :2],
mt_posterior.covariance_matrix[:2, :2],
mt_posterior.covariance_matrix[-2:, -2:],
posterior.covariance_matrix[-2:, -2:],
)
self.assertTrue(
torch.allclose(
expected_covariance, posterior2.covariance_matrix, atol=1e-5
)
)
# test output indices
posterior = model.posterior(test_X)
for output_indices in ([0], [1], [0, 1]):
posterior_subset = model.posterior(
test_X, output_indices=output_indices
@@ -451,17 +489,18 @@
posterior_subset.mean.shape, torch.Size([2, len(output_indices)])
)
self.assertTrue(
torch.equal(
torch.allclose(
posterior_subset.mean, posterior.mean[..., output_indices]
)
)
self.assertTrue(
torch.equal(
torch.allclose(
posterior_subset.variance,
posterior.variance[..., output_indices],
)
)
# test observation noise
model = SimpleModelListGPyTorchModel(m1, m2)
posterior = model.posterior(test_X, observation_noise=True)
self.assertIsInstance(posterior, GPyTorchPosterior)
self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
51 changes: 43 additions & 8 deletions test/models/test_model_list_gp_regression.py
@@ -6,6 +6,7 @@

import itertools
import warnings
from copy import deepcopy
from typing import Optional

import torch
@@ -206,7 +207,7 @@ def _base_test_ModelListGP(
)

# test X having wrong size
with self.assertRaises(AssertionError):
with self.assertRaises(BotorchTensorDimensionError):
model.condition_on_observations(f_x[:1], f_y)

# test posterior transform
@@ -336,12 +337,46 @@ def test_ModelListGP_multi_task(self):
model_list_gp_mean = model_list_gp.posterior(train_x_raw).mean
self.assertAllClose(model2_mean, model_list_gp_mean)
# Mix of multi-output and single-output MTGPs.
model_list_gp = ModelListGP(model, model2)
self.assertEqual(model_list_gp.num_outputs, 3)
model_list_gp = ModelListGP(model, model2, deepcopy(model))
self.assertEqual(model_list_gp.num_outputs, 4)
with torch.no_grad():
model_list_gp_mean = model_list_gp.posterior(train_x_raw).mean
expected_mean = torch.cat([model_mean, model2_mean], dim=-1)
self.assertAllClose(expected_mean, model_list_gp_mean)
posterior = model_list_gp.posterior(train_x_raw)
expected_mean = torch.cat([model_mean, model2_mean, model_mean], dim=-1)
self.assertAllClose(expected_mean, posterior.mean)
C1 = model.posterior(train_x_raw).covariance_matrix
C2 = model2.posterior(train_x_raw).covariance_matrix[:10, :10]
C3 = model2.posterior(train_x_raw).covariance_matrix[-10:, -10:]
expected_covariance = torch.block_diag(C1, C2, C3, C1)
self.assertTrue(
torch.allclose(expected_covariance, posterior.covariance_matrix, atol=1e-5)
)
# test subset outputs
subset_model = model_list_gp.subset_output([1])
self.assertEqual(subset_model.num_outputs, 2)
subset_model = model_list_gp.subset_output([0, 1])
self.assertEqual(subset_model.num_outputs, 3)
self.assertEqual(len(subset_model.models), 2)
# Test condition on observations
model_s1 = SingleTaskGP(
train_X=train_x_raw,
train_Y=train_y,
)
model_list_gp = ModelListGP(model_s1, model2, deepcopy(model_s1))
model_list_gp.posterior(train_x_raw)
f_x = [torch.rand(5, 1, **tkwargs) for _ in range(2)]
C1 = torch.cat((f_x[0], torch.zeros(5, 1, **tkwargs)), dim=-1)
C2 = torch.cat((f_x[1], torch.ones(5, 1, **tkwargs)), dim=-1)
f_x2 = [f_x[0], C1, C2, f_x[1]]
f_y = torch.rand(5, 4, **tkwargs)
cm = model_list_gp.condition_on_observations(f_x2, f_y)
self.assertIsInstance(cm, ModelListGP)
self.assertEqual(cm.num_outputs, 4)
self.assertEqual(len(cm.models), 3)
for i in [0, 2]:
self.assertIsInstance(cm.models[i], SingleTaskGP)
self.assertEqual(cm.models[i].train_inputs[0].shape, torch.Size([15, 1]))
self.assertIsInstance(cm.models[1], MultiTaskGP)
self.assertEqual(cm.models[1].train_inputs[0].shape, torch.Size([20, 2]))

def test_transform_revert_train_inputs(self):
tkwargs = {"device": self.device, "dtype": torch.float}
@@ -513,11 +548,11 @@ def _get_fant_mean(
eval_mask: Optional[Tensor] = None,
) -> float:
fant = model.fantasize(
target_x,
target_x, # noqa
sampler=sampler,
evaluation_mask=eval_mask,
)
return fant.posterior(target_x).mean.mean(dim=(-2, -3))
return fant.posterior(target_x).mean.mean(dim=(-2, -3)) # noqa

# ~0
sampler = IIDNormalSampler(sample_shape=torch.Size([10]), seed=0)