Commit ca5a0a4

bletham authored and facebook-github-bot committed
Add support for multitask models to ModelListGP (meta-pytorch#2154)
Summary: This upstreams behavior from MixedOutputModelListGP to enable computing a posterior on a model list that is a mix of single- and multi-task models. The ability to do this is necessary for using models that are multi-task across outcomes, such as LCE-M.

Reviewed By: saitcakmak

Differential Revision: D51906858
1 parent 859f63f

File tree: 5 files changed, +210 −36 lines
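
For orientation, here is a minimal sketch of the workflow this change enables: computing a joint posterior for a ModelListGP that mixes a single-task and a multi-task submodel. The data, hyperparameters, and shapes below are illustrative assumptions, not code from this commit.

import torch
from botorch.models import ModelListGP, MultiTaskGP, SingleTaskGP

# Illustrative training data: 10 points in 2 dims, one outcome column.
train_X = torch.rand(10, 2)
train_Y = torch.randn(10, 1)
# Task indicator appended as the last feature: two tasks, 0 and 1.
task_col = torch.cat([torch.zeros(5, 1), torch.ones(5, 1)], dim=0)

st_model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mt_model = MultiTaskGP(
    train_X=torch.cat([train_X, task_col], dim=-1),
    train_Y=train_Y,
    task_feature=-1,
)

# Before this commit, a list mixing single- and multi-task submodels could not
# be combined into a single joint posterior; now the multi-task posterior is
# split per task and assembled block-diagonally (inter-task covariance dropped).
model_list = ModelListGP(st_model, mt_model)
posterior = model_list.posterior(torch.rand(4, 2))
print(posterior.mean.shape)  # torch.Size([4, 3]): 1 output + 2 tasks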

botorch/models/gpytorch.py

Lines changed: 39 additions & 11 deletions

@@ -36,9 +36,11 @@
 )
 from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from botorch.posteriors.gpytorch import GPyTorchPosterior
+from botorch.utils.multitask import separate_mtmvn
 from botorch.utils.transforms import is_ensemble
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
 from torch import Tensor

 if TYPE_CHECKING:

@@ -101,6 +103,7 @@ def _validate_tensor_args(
                     "following error would have been raised with strict enforcement: "
                     f"{message}",
                     BotorchTensorDimensionWarning,
+                    stacklevel=2,
                 )
         # Yvar may not have the same batch dimensions, but the trailing dimensions
         # of Yvar should be the same as the trailing dimensions of Y.

@@ -559,7 +562,8 @@ class ModelListGPyTorchModel(ModelList, GPyTorchModel, ABC):
     r"""Abstract base class for models based on multi-output GPyTorch models.

     This is meant to be used with a gpytorch ModelList wrapper for independent
-    evaluation of submodels.
+    evaluation of submodels. Those submodels can themselves be multi-output
+    models, in which case the task covariances will be ignored.

     :meta private:
     """

@@ -582,7 +586,7 @@ def batch_shape(self) -> torch.Size:
             )
             try:
                 broadcast_shape = torch.broadcast_shapes(*batch_shapes)
-                warnings.warn(msg + ". Broadcasting batch shapes.")
+                warnings.warn(msg + ". Broadcasting batch shapes.", stacklevel=2)
                 return broadcast_shape
             except RuntimeError:
                 raise NotImplementedError(msg + " that are not broadcastable.")

@@ -598,6 +602,9 @@ def posterior(
         **kwargs: Any,
     ) -> Union[GPyTorchPosterior, PosteriorList]:
         r"""Computes the posterior over model outputs at the provided points.
+        If any model returns a MultitaskMultivariateNormal posterior, then that
+        will be split into individual MVNs per task, with inter-task covariance
+        ignored.

         Args:
             X: A `b x q x d`-dim Tensor, where `d` is the dimension of the

@@ -648,20 +655,41 @@
         )
         if not returns_untransformed:
             mvns = [p.distribution for p in posterior.posteriors]
-            # Combining MTMVNs into a single MTMVN is currently not supported.
-            if not any(isinstance(m, MultitaskMultivariateNormal) for m in mvns):
-                # Return the result as a GPyTorchPosterior/GaussianMixturePosterior.
+            if any(isinstance(m, MultitaskMultivariateNormal) for m in mvns):
+                mvn_list = []
+                for mvn in mvns:
+                    if len(mvn.event_shape) == 2:
+                        # We separate MTMVNs into independent-across-task MVNs for
+                        # the convenience of using BlockDiagLinearOperator below.
+                        # (b x q x m x m) -> list of m (b x q x 1 x 1)
+                        mvn_list.extend(separate_mtmvn(mvn))
+                    else:
+                        mvn_list.append(mvn)
+                mean = torch.stack([mvn.mean for mvn in mvn_list], dim=-1)
+                covars = CatLinearOperator(
+                    *[mvn.lazy_covariance_matrix.unsqueeze(-3) for mvn in mvn_list],
+                    dim=-3,
+                )  # List of m (b x q x 1 x 1) -> (b x q x m x 1 x 1)
+                mvn = MultitaskMultivariateNormal(
+                    mean=mean,
+                    covariance_matrix=BlockDiagLinearOperator(covars, block_dim=-3).to(
+                        X
+                    ),  # (b x q x m x 1 x 1) -> (b x q x m x m)
+                    interleaved=False,
+                )
+            else:
                 mvn = (
                     mvns[0]
                     if len(mvns) == 1
                     else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
                 )
-                if any(is_ensemble(m) for m in self.models):
-                    # Mixing fully Bayesian and other GP models is currently
-                    # not supported.
-                    posterior = GaussianMixturePosterior(distribution=mvn)
-                else:
-                    posterior = GPyTorchPosterior(distribution=mvn)
+            # Return the result as a GPyTorchPosterior/GaussianMixturePosterior.
+            if any(is_ensemble(m) for m in self.models):
+                # Mixing fully Bayesian and other GP models is currently
+                # not supported.
+                posterior = GaussianMixturePosterior(distribution=mvn)
+            else:
+                posterior = GPyTorchPosterior(distribution=mvn)
         if posterior_transform is not None:
             return posterior_transform(posterior)
         return posterior
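
To make the new MTMVN branch concrete, the following standalone sketch reproduces just the covariance assembly it performs: per-output covariance operators gain a block dimension via unsqueeze and CatLinearOperator, and BlockDiagLinearOperator then lays them out as diagonal blocks of the joint covariance, zeroing inter-output covariance. The toy matrices and shapes are assumptions for illustration, not code from the commit.

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator

q, m = 4, 3  # q test points, m outputs
# Toy per-output (q x q) covariance operators.
covs = [to_linear_operator(torch.eye(q) * (i + 1.0)) for i in range(m)]

# Each (q x q) -> (1 x q x q); concatenating along dim=-3 stacks them to (m x q x q).
cat = CatLinearOperator(*[c.unsqueeze(-3) for c in covs], dim=-3)
# The blocks along dim -3 become the diagonal blocks of an (m*q x m*q) operator.
block = BlockDiagLinearOperator(cat, block_dim=-3)

dense = block.to_dense()
print(dense.shape)  # torch.Size([12, 12])
# Off-diagonal (cross-output) blocks are zero, i.e. inter-task covariance is dropped:
print(torch.equal(dense[:q, q : 2 * q], torch.zeros(q, q)))  # True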

botorch/models/model_list_gp_regression.py

Lines changed: 38 additions & 14 deletions

@@ -13,6 +13,8 @@
 from copy import deepcopy
 from typing import Any, List

+import torch
+
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.gpytorch import GPyTorchModel, ModelListGPyTorchModel
 from botorch.models.model import FantasizeMixin

@@ -87,34 +89,56 @@ def condition_on_observations(
                 f"{Y.shape[-1]} observation outputs, but model has "
                 f"{self.num_outputs} outputs."
             )
-        targets = [Y[..., i] for i in range(Y.shape[-1])]
-        for i, model in enumerate(self.models):
-            if hasattr(model, "outcome_transform"):
-                noise = kwargs.get("noise")
-                targets[i], noise = model.outcome_transform(targets[i], noise)
-
-        # This should never trigger, posterior call would fail.
-        assert len(targets) == len(X)
+        if len(X) != self.num_outputs:
+            raise BotorchTensorDimensionError(
+                "Incorrect number of inputs for observations. Received "
+                f"{len(X)} observation inputs, but model has "
+                f"{self.num_outputs} outputs."
+            )
         if "noise" in kwargs:
             noise = kwargs.pop("noise")
             if noise.shape != Y.shape[-noise.dim() :]:
                 raise BotorchTensorDimensionError(
                     "The shape of observation noise does not agree with the outcomes. "
                     f"Received {noise.shape} noise with {Y.shape} outcomes."
                 )
-            kwargs_ = {**kwargs, "noise": [noise[..., i] for i in range(Y.shape[-1])]}
+
         else:
-            kwargs_ = kwargs
-        return super().get_fantasy_model(X, targets, **kwargs_)
+            noise = None
+        targets = []
+        inputs = []
+        noises = []
+        i = 0
+        for model in self.models:
+            j = i + model.num_outputs
+            y_i = torch.cat([Y[..., k] for k in range(i, j)], dim=-1)
+            X_i = torch.cat([X[k] for k in range(i, j)], dim=-2)
+            if noise is None:
+                noise_i = None
+            else:
+                noise_i = torch.cat([noise[..., k] for k in range(i, j)], dim=-1)
+            if hasattr(model, "outcome_transform"):
+                y_i, noise_i = model.outcome_transform(y_i, noise_i)
+                if noise_i is not None:
+                    noise_i = noise_i.squeeze(0)
+            targets.append(y_i)
+            inputs.append(X_i)
+            noises.append(noise_i)
+            i += model.num_outputs
+
+        kwargs_ = {**kwargs, "noise": noises} if noise is not None else kwargs
+        return super().get_fantasy_model(inputs, targets, **kwargs_)

     def subset_output(self, idcs: List[int]) -> ModelListGP:
-        r"""Subset the model along the output dimension.
+        r"""Subset the model along the submodel dimension.

         Args:
-            idcs: The output indices to subset the model to.
+            idcs: The indices of submodels to subset the model to.

         Returns:
-            The current model, subset to the specified output indices.
+            The current model, subset to the specified submodels. If each model
+            is single-output, this will correspond to that subset of the
+            outputs.
         """
         return self.__class__(*[deepcopy(self.models[i]) for i in idcs])
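
The reworked condition_on_observations expects one input tensor per output rather than per submodel: a multi-task submodel gets one entry per task, each carrying its task feature, and the loop above concatenates them along the data dimension. A hedged usage sketch mirroring the unit test further below, with illustrative names and shapes, assuming the BoTorch version containing this commit:

import torch
from botorch.models import ModelListGP, MultiTaskGP, SingleTaskGP

train_X = torch.rand(10, 2)
train_Y = torch.randn(10, 1)
task_col = torch.cat([torch.zeros(5, 1), torch.ones(5, 1)], dim=0)
st_model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mt_model = MultiTaskGP(
    train_X=torch.cat([train_X, task_col], dim=-1),
    train_Y=train_Y,
    task_feature=-1,
)
model_list = ModelListGP(st_model, mt_model)  # 1 + 2 = 3 outputs
# Per the updated subset_output docstring, subsetting selects submodels:
print(model_list.subset_output([1]).num_outputs)  # 2 (the whole MTGP)
model_list.posterior(train_X)  # evaluate once before conditioning, as in the test

new_x = torch.rand(5, 2)
X_list = [  # one entry per output; len(X_list) must equal num_outputs (3)
    new_x,                                          # st_model
    torch.cat([new_x, torch.zeros(5, 1)], dim=-1),  # mt_model, task 0
    torch.cat([new_x, torch.ones(5, 1)], dim=-1),   # mt_model, task 1
]
new_Y = torch.randn(5, 3)  # one observation column per output
cm = model_list.condition_on_observations(X_list, new_Y)
print(cm.num_outputs, len(cm.models))  # 3 2
# A mismatched len(X_list) now raises BotorchTensorDimensionError.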

botorch/utils/multitask.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""
+Helpers for multitask modeling.
+"""
+
+from __future__ import annotations
+
+from typing import List
+
+import torch
+from gpytorch.distributions import MultitaskMultivariateNormal
+from gpytorch.distributions.multivariate_normal import MultivariateNormal
+from linear_operator import to_linear_operator
+
+
+def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> List[MultivariateNormal]:
+    """
+    Separate a MTMVN into a list of MVNs, where covariance across data within each
+    task is preserved, while covariance across tasks is dropped.
+    """
+    # TODO T150340766 Upstream this into a class method on gpytorch MultitaskMultivariateNormal.
+    full_covar = mvn.lazy_covariance_matrix
+    num_data, num_tasks = mvn.mean.shape[-2:]
+    if mvn._interleaved:
+        data_indices = torch.arange(
+            0, num_data * num_tasks, num_tasks, device=full_covar.device
+        ).view(-1, 1, 1)
+        task_indices = torch.arange(num_tasks, device=full_covar.device)
+    else:
+        data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
+        task_indices = torch.arange(
+            0, num_data * num_tasks, num_data, device=full_covar.device
+        )
+    slice_ = (data_indices + task_indices).transpose(-1, -3)
+    data_covars = full_covar[..., slice_, slice_.transpose(-1, -2)]
+    mvns = []
+    for c in range(num_tasks):
+        mvns.append(
+            MultivariateNormal(
+                mvn.mean[..., c], to_linear_operator(data_covars[..., c, :, :])
+            )
+        )
+    return mvns
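
A quick numerical check of what separate_mtmvn computes: for an interleaved MTMVN (task index varying fastest), task c occupies rows and columns c, c+m, c+2m, ... of the joint covariance, and the c-th returned MVN keeps exactly that within-task block. This is a hedged sketch with random toy data, not part of the commit.

import torch
from gpytorch.distributions import MultitaskMultivariateNormal

from botorch.utils.multitask import separate_mtmvn

q, m = 3, 2  # 3 data points, 2 tasks
mean = torch.randn(q, m)
A = torch.randn(q * m, q * m)
covar = A @ A.T + q * m * torch.eye(q * m)  # SPD joint covariance over (data, task)
mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=True)

mvns = separate_mtmvn(mtmvn)
assert len(mvns) == m
for c in range(m):
    idx = torch.arange(c, q * m, m)  # positions of task c in the interleaved layout
    assert torch.allclose(mvns[c].mean, mean[..., c])
    assert torch.allclose(mvns[c].covariance_matrix, covar[idx][:, idx], atol=1e-5)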

test/models/test_gpytorch.py

Lines changed: 42 additions & 3 deletions

@@ -23,13 +23,14 @@
     ModelListGPyTorchModel,
 )
 from botorch.models.model import FantasizeMixin
+from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms import Standardize
 from botorch.models.transforms.input import ChainedInputTransform, InputTransform
 from botorch.models.utils import fantasize
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.test_helpers import SimpleGPyTorchModel
-from botorch.utils.testing import BotorchTestCase
+from botorch.utils.testing import _get_random_data, BotorchTestCase
 from gpytorch import ExactMarginalLogLikelihood
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import RBFKernel, ScaleKernel

@@ -441,7 +442,44 @@ def test_model_list_gpytorch_model(self):
         posterior = model.posterior(test_X)
         self.assertIsInstance(posterior, GPyTorchPosterior)
         self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
+        # test multioutput
+        train_x_raw, train_y = _get_random_data(
+            batch_shape=torch.Size(), m=1, n=10, **tkwargs
+        )
+        task_idx = torch.cat(
+            [torch.ones(5, 1, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=0
+        )
+        train_x = torch.cat([train_x_raw, task_idx], dim=-1)
+        model_mt = MultiTaskGP(
+            train_X=train_x,
+            train_Y=train_y,
+            task_feature=-1,
+        )
+        mt_posterior = model_mt.posterior(test_X)
+        model = SimpleModelListGPyTorchModel(m1, model_mt, m2)
+        posterior2 = model.posterior(test_X)
+        expected_mean = torch.cat(
+            (
+                posterior.mean[:, 0].unsqueeze(-1),
+                mt_posterior.mean,
+                posterior.mean[:, 1].unsqueeze(-1),
+            ),
+            dim=1,
+        )
+        self.assertTrue(torch.allclose(expected_mean, posterior2.mean))
+        expected_covariance = torch.block_diag(
+            posterior.covariance_matrix[:2, :2],
+            mt_posterior.covariance_matrix[:2, :2],
+            mt_posterior.covariance_matrix[-2:, -2:],
+            posterior.covariance_matrix[-2:, -2:],
+        )
+        self.assertTrue(
+            torch.allclose(
+                expected_covariance, posterior2.covariance_matrix, atol=1e-5
+            )
+        )
         # test output indices
+        posterior = model.posterior(test_X)
         for output_indices in ([0], [1], [0, 1]):
             posterior_subset = model.posterior(
                 test_X, output_indices=output_indices

@@ -451,17 +489,18 @@
                 posterior_subset.mean.shape, torch.Size([2, len(output_indices)])
             )
             self.assertTrue(
-                torch.equal(
+                torch.allclose(
                     posterior_subset.mean, posterior.mean[..., output_indices]
                 )
             )
             self.assertTrue(
-                torch.equal(
+                torch.allclose(
                     posterior_subset.variance,
                     posterior.variance[..., output_indices],
                 )
             )
         # test observation noise
+        model = SimpleModelListGPyTorchModel(m1, m2)
         posterior = model.posterior(test_X, observation_noise=True)
         self.assertIsInstance(posterior, GPyTorchPosterior)
         self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))

test/models/test_model_list_gp_regression.py

Lines changed: 43 additions & 8 deletions

@@ -6,6 +6,7 @@

 import itertools
 import warnings
+from copy import deepcopy
 from typing import Optional

 import torch

@@ -206,7 +207,7 @@ def _base_test_ModelListGP(
         )

         # test X having wrong size
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(BotorchTensorDimensionError):
             model.condition_on_observations(f_x[:1], f_y)

         # test posterior transform

@@ -336,12 +337,46 @@ def test_ModelListGP_multi_task(self):
         model_list_gp_mean = model_list_gp.posterior(train_x_raw).mean
         self.assertAllClose(model2_mean, model_list_gp_mean)
         # Mix of multi-output and single-output MTGPs.
-        model_list_gp = ModelListGP(model, model2)
-        self.assertEqual(model_list_gp.num_outputs, 3)
+        model_list_gp = ModelListGP(model, model2, deepcopy(model))
+        self.assertEqual(model_list_gp.num_outputs, 4)
         with torch.no_grad():
-            model_list_gp_mean = model_list_gp.posterior(train_x_raw).mean
-        expected_mean = torch.cat([model_mean, model2_mean], dim=-1)
-        self.assertAllClose(expected_mean, model_list_gp_mean)
+            posterior = model_list_gp.posterior(train_x_raw)
+        expected_mean = torch.cat([model_mean, model2_mean, model_mean], dim=-1)
+        self.assertAllClose(expected_mean, posterior.mean)
+        C1 = model.posterior(train_x_raw).covariance_matrix
+        C2 = model2.posterior(train_x_raw).covariance_matrix[:10, :10]
+        C3 = model2.posterior(train_x_raw).covariance_matrix[-10:, -10:]
+        expected_covariance = torch.block_diag(C1, C2, C3, C1)
+        self.assertTrue(
+            torch.allclose(expected_covariance, posterior.covariance_matrix, atol=1e-5)
+        )
+        # test subset outputs
+        subset_model = model_list_gp.subset_output([1])
+        self.assertEqual(subset_model.num_outputs, 2)
+        subset_model = model_list_gp.subset_output([0, 1])
+        self.assertEqual(subset_model.num_outputs, 3)
+        self.assertEqual(len(subset_model.models), 2)
+        # Test condition on observations
+        model_s1 = SingleTaskGP(
+            train_X=train_x_raw,
+            train_Y=train_y,
+        )
+        model_list_gp = ModelListGP(model_s1, model2, deepcopy(model_s1))
+        model_list_gp.posterior(train_x_raw)
+        f_x = [torch.rand(5, 1, **tkwargs) for _ in range(2)]
+        C1 = torch.cat((f_x[0], torch.zeros(5, 1, **tkwargs)), dim=-1)
+        C2 = torch.cat((f_x[1], torch.ones(5, 1, **tkwargs)), dim=-1)
+        f_x2 = [f_x[0], C1, C2, f_x[1]]
+        f_y = torch.rand(5, 4, **tkwargs)
+        cm = model_list_gp.condition_on_observations(f_x2, f_y)
+        self.assertIsInstance(cm, ModelListGP)
+        self.assertEqual(cm.num_outputs, 4)
+        self.assertEqual(len(cm.models), 3)
+        for i in [0, 2]:
+            self.assertIsInstance(cm.models[i], SingleTaskGP)
+            self.assertEqual(cm.models[i].train_inputs[0].shape, torch.Size([15, 1]))
+        self.assertIsInstance(cm.models[1], MultiTaskGP)
+        self.assertEqual(cm.models[1].train_inputs[0].shape, torch.Size([20, 2]))

     def test_transform_revert_train_inputs(self):
         tkwargs = {"device": self.device, "dtype": torch.float}

@@ -513,11 +548,11 @@ def _get_fant_mean(
             eval_mask: Optional[Tensor] = None,
         ) -> float:
             fant = model.fantasize(
-                target_x,
+                target_x,  # noqa
                 sampler=sampler,
                 evaluation_mask=eval_mask,
             )
-            return fant.posterior(target_x).mean.mean(dim=(-2, -3))
+            return fant.posterior(target_x).mean.mean(dim=(-2, -3))  # noqa

         # ~0
         sampler = IIDNormalSampler(sample_shape=torch.Size([10]), seed=0)
