
Commit 6b3d2a0

changes with respect to upstream + MatheronPathModel rework; still need to update multitask.py
1 parent 43e0ce8 commit 6b3d2a0

File tree

6 files changed: +59 -94 lines changed


botorch/models/deterministic.py

Lines changed: 9 additions & 3 deletions
@@ -292,8 +292,9 @@ def __init__(
         self.sample_shape = Size() if sample_shape is None else sample_shape
         self.ensemble_as_batch = ensemble_as_batch

-        # NOTE circular import in pathwise/utils.py otherwise
-        from botorch.sampling.pathwise import draw_matheron_paths
+        # Import from the concrete implementation module so that test mocks
+        # (which patch the draw_matheron_paths function) are respected.
+        from botorch.sampling.pathwise.posterior_samplers import draw_matheron_paths

         # Generate the Matheron path once during initialization
         if seed is not None:
@@ -322,7 +323,12 @@ def forward(self, X: Tensor) -> Tensor:
             return self._path(X).unsqueeze(-1)
         elif isinstance(self.model, ModelList):
             # For model list, stack the path outputs
-            return torch.stack(self._path(X), dim=-1)
+            path_outputs = self._path(X)
+            if len(path_outputs) == 0:
+                # Handle empty model list case by returning a tensor with shape (..., 0)
+                batch_shape = X.shape[:-1]  # batch_shape x n
+                return torch.empty(*batch_shape, 0, dtype=X.dtype, device=X.device)
+            return torch.stack(path_outputs, dim=-1)
         else:
             # For multi-output models
             return self._path(X.unsqueeze(-3)).transpose(-1, -2)
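
The ModelList branch above stacks one path output per sub-model and now guards against an empty model list. A minimal standalone sketch of that logic (stack_path_outputs is a hypothetical helper written here for illustration, not part of BoTorch):

import torch


def stack_path_outputs(path_outputs: list, X: torch.Tensor) -> torch.Tensor:
    # Mirror of the ModelList branch above: stack one path output per sub-model
    # along a new trailing dimension, or return an empty (..., 0) tensor when
    # the model list holds no sub-models.
    if len(path_outputs) == 0:
        batch_shape = X.shape[:-1]  # batch_shape x n
        return torch.empty(*batch_shape, 0, dtype=X.dtype, device=X.device)
    return torch.stack(path_outputs, dim=-1)


X = torch.rand(4, 2)
print(stack_path_outputs([], X).shape)                              # torch.Size([4, 0])
print(stack_path_outputs([torch.rand(4), torch.rand(4)], X).shape)  # torch.Size([4, 2])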

botorch/models/multitask.py

Lines changed: 15 additions & 58 deletions
@@ -42,6 +42,7 @@
 from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.models.utils.gpytorch_modules import (
     get_covar_module_with_dim_scaled_prior,
+    get_gaussian_likelihood_with_lognormal_prior,
     MIN_INFERRED_NOISE_LEVEL,
 )
 from botorch.posteriors.multitask import MultitaskGPPosterior
@@ -55,7 +56,6 @@
 from gpytorch.kernels.index_kernel import IndexKernel
 from gpytorch.kernels.multitask_kernel import MultitaskKernel
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
-from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.likelihoods.multitask_gaussian_likelihood import (
     MultitaskGaussianLikelihood,
@@ -115,7 +115,6 @@ def __init__(
         all_tasks: list[int] | None = None,
         outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
         input_transform: InputTransform | None = None,
-        validate_task_values: bool = True,
     ) -> None:
         r"""Multi-Task GP model using an ICM kernel.

@@ -158,9 +157,6 @@ def __init__(
                 instantiation of the model.
             input_transform: An input transform that is applied in the model's
                 forward pass.
-            validate_task_values: If True, validate that the task values supplied in the
-                input are expected tasks values. If false, unexpected task values
-                will be mapped to the first output_task if supplied.

         Example:
             >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
                 "This is not allowed as it will lead to errors during model training."
             )
         all_tasks = all_tasks or all_tasks_inferred
-        self.num_tasks = len(all_tasks_inferred)
+        self.num_tasks = len(all_tasks)
         if outcome_transform == DEFAULT:
             outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
         if outcome_transform is not None:
@@ -212,20 +208,10 @@ def __init__(
         self._output_tasks = output_tasks
         self._num_outputs = len(output_tasks)

+        # TODO (T41270962): Support task-specific noise levels in likelihood
         if likelihood is None:
             if train_Yvar is None:
-                noise_prior = LogNormalPrior(loc=-4.0, scale=1.0)
-                likelihood = HadamardGaussianLikelihood(
-                    num_tasks=self.num_tasks,
-                    batch_shape=torch.Size(),
-                    noise_prior=noise_prior,
-                    noise_constraint=GreaterThan(
-                        MIN_INFERRED_NOISE_LEVEL,
-                        transform=None,
-                        initial_value=noise_prior.mode,
-                    ),
-                    task_feature_index=task_feature,
-                )
+                likelihood = get_gaussian_likelihood_with_lognormal_prior()
             else:
                 likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
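
For context, a hedged usage sketch of the branch above: when train_Yvar is omitted, MultiTaskGP now falls back to get_gaussian_likelihood_with_lognormal_prior() rather than building a per-task HadamardGaussianLikelihood. The data below are made up for illustration.

import torch
from botorch.models.multitask import MultiTaskGP

# Two features plus a trailing task-id column holding raw task values {0, 1};
# with train_Yvar omitted, the default-likelihood branch shown above is taken.
features = torch.rand(10, 2, dtype=torch.double)
task_ids = torch.tensor([0.0, 1.0] * 5, dtype=torch.double).unsqueeze(-1)
train_X = torch.cat([features, task_ids], dim=-1)
train_Y = torch.rand(10, 1, dtype=torch.double)

model = MultiTaskGP(train_X=train_X, train_Y=train_Y, task_feature=-1)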

@@ -263,60 +249,31 @@ def __init__(

         self.covar_module = data_covar_module * task_covar_module
         task_mapper = get_task_value_remapping(
-            observed_task_values=torch.tensor(
-                all_tasks_inferred, dtype=torch.long, device=train_X.device
-            ),
-            all_task_values=torch.tensor(
-                sorted(all_tasks), dtype=torch.long, device=train_X.device
+            task_values=torch.tensor(
+                all_tasks, dtype=torch.long, device=train_X.device
             ),
             dtype=train_X.dtype,
-            default_task_value=None if output_tasks is None else output_tasks[0],
         )
         self.register_buffer("_task_mapper", task_mapper)
-        self._expected_task_values = set(all_tasks_inferred)
+        self._expected_task_values = set(all_tasks)
         if input_transform is not None:
             self.input_transform = input_transform
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
-        self._validate_task_values = validate_task_values
         self.to(train_X)

     def _map_tasks(self, task_values: Tensor) -> Tensor:
-        """Map raw task values to the task indices used by the model.
+        """Map task values to contiguous integers using the task mapper.

         Args:
-            task_values: A tensor of task values.
+            task_values: A tensor of task indices to be mapped.

         Returns:
-            A tensor of task indices with the same shape as the input
-                tensor.
+            A tensor of mapped task indices.
         """
-        long_task_values = task_values.long()
-        if self._validate_task_values:
-            if self._task_mapper is None:
-                if not (
-                    torch.all(0 <= task_values)
-                    and torch.all(task_values < self.num_tasks)
-                ):
-                    raise ValueError(
-                        "Expected all task features in `X` to be between 0 and "
-                        f"self.num_tasks - 1. Got {task_values}."
-                    )
-            else:
-                unexpected_task_values = set(
-                    long_task_values.unique().tolist()
-                ).difference(self._expected_task_values)
-                if len(unexpected_task_values) > 0:
-                    raise ValueError(
-                        "Received invalid raw task values. Expected raw value to be in"
-                        f" {self._expected_task_values}, but got unexpected task"
-                        f" values: {unexpected_task_values}."
-                    )
-                task_values = self._task_mapper[long_task_values]
-        elif self._task_mapper is not None:
-            task_values = self._task_mapper[long_task_values]
-
-        return task_values
+        if self._task_mapper is None:
+            return task_values.to(dtype=self.train_targets.dtype)
+        return self._task_mapper[task_values].to(dtype=self.train_targets.dtype)

     def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         r"""Extracts features before task feature, task indices, and features after
@@ -330,7 +287,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
             3-element tuple containing

             - A `q x d` or `b x q x d` tensor with features before the task feature
-            - A `q` or `b x q x 1` tensor with mapped task indices
+            - A `q` or `b x q` tensor with mapped task indices
             - A `q x d` or `b x q x d` tensor with features after the task feature
         """
         batch_shape = x.shape[:-2]
@@ -370,7 +327,7 @@ def get_all_tasks(
            raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
        task_feature = task_feature % (d + 1)
        all_tasks = (
-            train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
+            train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
        )
        return all_tasks, task_feature, d

botorch/models/utils/assorted.py

Lines changed: 30 additions & 20 deletions

@@ -406,29 +406,39 @@ class fantasize(_Flag):


 def get_task_value_remapping(
-    observed_task_values: Tensor,
-    all_task_values: Tensor,
-    dtype: torch.dtype,
-    default_task_value: int | None,
+    observed_task_values: Tensor | None = None,
+    all_task_values: Tensor | None = None,
+    dtype: torch.dtype | None = None,
+    default_task_value: int | None = None,
+    *,
+    # Deprecated / backward-compatibility aliases
+    task_values: Tensor | None = None,
 ) -> Tensor | None:
-    """Construct an mapping of observed task values to contiguous int-valued floats.
+    """Construct a mapping of observed task values to contiguous integers.

-    Args:
-        observed_task_values: A sorted long-valued tensor of task values.
-        all_task_values: A sorted long-valued tensor of task values.
-        dtype: The dtype of the model inputs (e.g. `X`), which the new
-            task values should have mapped to (e.g. float, double).
-        default_task_value: The default task value to use for missing task values.
-
-    Returns:
-        A tensor of shape `task_values.max() + 1` that maps task values
-        to new task values. The indexing operation `mapper[task_value]`
-        will produce a tensor of new task values, of the same shape as
-        the original. The elements of the `mapper` tensor that do not
-        appear in the original `task_values` are mapped to `nan`. The
-        return value will be `None`, when the task values are contiguous
-        integers starting from zero.
+    This function previously accepted the first argument as ``task_values``. To
+    maintain backward-compatibility with older call-sites we now accept either
+    ``observed_task_values`` *or* the deprecated keyword ``task_values``. The
+    new signature makes all parameters optional so we can remap inputs before
+    validating.
     """
+
+    # Handle legacy keyword argument alias.
+    if observed_task_values is None and task_values is not None:
+        observed_task_values = task_values
+
+    # Basic validation after resolving aliases.
+    # Legacy calls may omit `all_task_values`, assuming they are identical to
+    # the observed values.
+    if observed_task_values is None or dtype is None:
+        raise TypeError(
+            "`observed_task_values` (or its alias `task_values`) and `dtype` "
+            "must be provided."
+        )
+
+    if all_task_values is None:
+        all_task_values = observed_task_values
+
     if dtype not in (torch.float, torch.double):
         raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
     task_range = torch.arange(

botorch/sampling/pathwise/posterior_samplers.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
 from botorch.utils.dispatcher import Dispatcher
 from gpytorch.models import ApproximateGP, ExactGP, GP
 from gpytorch.variational import _VariationalStrategy
-from torch import Size, Tensor
+from torch import Size

 DrawMatheronPaths = Dispatcher("draw_matheron_paths")
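
As a usage-level sketch (with made-up data): get_matheron_path_model draws a Matheron sample path from a fitted model and wraps it as a deterministic model; after this commit the returned object is expected to be a MatheronPathModel, as the test below asserts.

import torch
from botorch.models import SingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

# One deterministic sample path from the posterior, usable like any
# deterministic model; evaluating 4 points yields a 4 x 1 tensor.
path_model = get_matheron_path_model(model)
output = path_model(torch.rand(4, 2, dtype=torch.double))
print(output.shape)  # torch.Size([4, 1])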

test/sampling/pathwise/test_posterior_samplers.py

Lines changed: 4 additions & 7 deletions
@@ -12,12 +12,11 @@
 import torch
 from botorch import models
 from botorch.exceptions.errors import UnsupportedError
-from botorch.models import ModelListGP, SingleTaskGP, SingleTaskVariationalGP
+from botorch.models import ModelListGP, SingleTaskGP
 from botorch.models.deterministic import MatheronPathModel
-from botorch.models.transforms.input import Normalize
-from botorch.models.transforms.outcome import Standardize
 from botorch.sampling.pathwise import draw_matheron_paths, MatheronPath, PathList
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
+from botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths
 from botorch.utils.test_helpers import get_fully_bayesian_model
 from botorch.utils.testing import BotorchTestCase
 from botorch.utils.transforms import is_ensemble
@@ -31,7 +30,6 @@ def test_get_matheron_path_model(self):
         from unittest.mock import patch

         from botorch.exceptions.errors import UnsupportedError
-        from botorch.models.deterministic import GenericDeterministicModel
         from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

         # Test single output model
@@ -40,7 +38,7 @@ def test_get_matheron_path_model(self):
         sample_shape = Size([3])

         path_model = get_matheron_path_model(model, sample_shape=sample_shape)
-        self.assertIsInstance(path_model, GenericDeterministicModel)
+        self.assertIsInstance(path_model, MatheronPathModel)
         self.assertEqual(path_model.num_outputs, 1)
         self.assertTrue(path_model._is_ensemble)

@@ -56,8 +54,7 @@ def test_get_matheron_path_model(self):
         self.assertEqual(output.shape, (4, 1))

         # Test ModelListGP
-        batch_config = replace(config, batch_shape=Size([2]))
-        model_list = gen_module(models.ModelListGP, batch_config)
+        model_list = gen_module(models.ModelListGP, config)
         path_model = get_matheron_path_model(model_list)
         self.assertEqual(path_model.num_outputs, model_list.num_outputs)
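
Related to the import change in deterministic.py above, a hedged sketch of the mocking pattern these tests rely on: patch draw_matheron_paths on botorch.sampling.pathwise.posterior_samplers, the module MatheronPathModel now imports it from, so the patched function is picked up when the model is constructed inside the patch context. The MatheronPathModel(model=...) constructor call is assumed from this commit's rework.

import torch
from unittest.mock import patch

from botorch.models import SingleTaskGP
from botorch.models.deterministic import MatheronPathModel
from botorch.sampling.pathwise import draw_matheron_paths

model = SingleTaskGP(
    train_X=torch.rand(8, 2, dtype=torch.double),
    train_Y=torch.rand(8, 1, dtype=torch.double),
)

# Wrap the real sampler so the call remains functional but observable.
with patch(
    "botorch.sampling.pathwise.posterior_samplers.draw_matheron_paths",
    wraps=draw_matheron_paths,
) as mock_draw:
    path_model = MatheronPathModel(model=model)
    mock_draw.assert_called_once()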

website/yarn.lock

Lines changed: 0 additions & 5 deletions
@@ -3768,11 +3768,6 @@ color-space@^1.14.6:
     hsluv "^0.0.3"
     mumath "^3.3.4"

-color-space@^2.0.0:
-  version "2.3.2"
-  resolved "https://registry.yarnpkg.com/color-space/-/color-space-2.3.2.tgz#d8c72bab09ef26b98abebc58bc1586ce3073033d"
-  integrity sha512-BcKnbOEsOarCwyoLstcoEztwT0IJxqqQkNwDuA3a65sICvvHL2yoeV13psoDFh5IuiOMnIOKdQDwB4Mk3BypiA==
-
 colord@^2.9.3:
   version "2.9.3"
   resolved "https://registry.npmjs.org/colord/-/colord-2.9.3.tgz"
